# Work from the parent directory so relative paths such as
# 'datasets/TwoGroupIQ.csv' resolve, and enable IPython's autoreload so edits
# to the project modules (numpyro_glm) are picked up before each cell runs.
%cd ..
%load_ext autoreload
%autoreload 2
/home/runner/work/numpyro-doing-bayesian/numpyro-doing-bayesian
import arviz as az
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
import numpy as np
import numpyro
from numpyro.infer import MCMC, NUTS
import numpyro_glm
import numpyro_glm.metric.models as glm_metric
import numpyro_glm.metric.plots as plots
import pandas as pd
from scipy.stats import norm, t
# Draw the graphical structure of the one-group model. A length-5 vector of
# ones serves only as dummy data so the model can be traced.
numpyro.render_model(
    glm_metric.one_group,
    model_args=(jnp.ones(5),),
    render_params=True,
)
# Ground-truth parameters of the synthetic normal data and its sample size.
MEAN = 5
STD = 3
N = 100
# Use a seeded generator so the synthetic data, and therefore the MCMC
# summary computed from it, is reproducible across notebook runs.
rng = np.random.default_rng(0)
y = rng.normal(MEAN, STD, N)
# Plot the histogram and true PDF of normal distribution.
fig, ax = plt.subplots()
ax.hist(y, density=True, label='Histogram of $y$')
xmin, xmax = ax.get_xlim()
# Evaluate the true PDF over the histogram's visible x-range.
p = np.linspace(xmin, xmax, 1000)
ax.plot(p, norm.pdf(p, loc=MEAN, scale=STD), label='Normal PDF')
ax.legend()
fig.tight_layout()
Now, we apply the metric model to that data to see whether it can recover the true parameters.
# Sample the posterior of the one-group model (parameters `mean` and `std`)
# with NUTS; the fixed PRNG key makes the sampler run reproducible.
kernel = NUTS(glm_metric.one_group)
mcmc_key = random.PRNGKey(0)
mcmc = MCMC(
    kernel,
    num_warmup=250,
    num_samples=750,
)
mcmc.run(mcmc_key, jnp.array(y))
# Print posterior mean/std/median/quantiles, n_eff and r_hat per parameter.
mcmc.print_summary()
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
mean std median 5.0% 95.0% n_eff r_hat
mean 4.96 0.31 4.97 4.53 5.51 517.16 1.00
std 3.05 0.22 3.03 2.73 3.45 746.19 1.00
Number of divergences: 0
Plot diagnostics to check whether the MCMC chains are well-behaved.
# Diagnostic plots for the sampled `mean` and `std` chains. The trailing
# semicolon stops IPython from echoing the returned object, which otherwise
# prints a very long repr of xarray Datasets under the figure.
numpyro_glm.plot_diagnostic(mcmc, ['mean', 'std']);
<xarray.Dataset>
Dimensions: (chain: 1, draw: 750)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
Data variables:
mean (chain, draw) float32 5.209 4.529 5.347 5.227 ... 4.773 4.991 4.968
std (chain, draw) float32 2.917 3.103 2.868 2.874 ... 2.672 3.41 3.339
Attributes:
created_at: 2023-05-29T23:56:31.752566
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[5.209106 , 4.528548 , 5.3471146, 5.227207 , 4.2945375, 4.9182925,
4.7712283, 5.2968564, 5.132386 , 4.844204 , 5.586605 , 4.487925 ,
5.3550844, 5.4572015, 4.448584 , 4.9535537, 5.02507 , 4.8676105,
5.031705 , 5.547575 , 5.0685067, 4.779368 , 5.002617 , 5.0431023,
5.4613724, 4.8222647, 5.0249243, 4.903239 , 5.109538 , 4.791892 ,
5.222782 , 5.2325306, 4.6903687, 4.693336 , 4.9970613, 4.7783775,
4.8024116, 5.4343495, 5.04768 , 4.604811 , 4.6258793, 4.529574 ,
4.5864644, 4.955139 , 4.8493094, 5.0941453, 4.96266 , 4.893078 ,
5.227385 , 4.9269013, 5.0402994, 5.046525 , 5.125822 , 4.6672425,
4.203383 , 4.907324 , 4.990067 , 4.1970205, 4.35923 , 4.966628 ,
4.88882 , 5.092228 , 5.1102777, 4.8794503, 5.1425686, 4.9433007,
4.9666333, 5.1734595, 4.7041483, 5.017935 , 4.9279866, 5.1911583,
5.1911583, 5.011348 , 5.0969815, 4.9910264, 4.695913 , 4.8254447,
4.7390966, 5.0573435, 4.6840773, 5.1652255, 5.024248 , 4.6640615,
4.5730023, 5.0042996, 4.979213 , 4.7948694, 4.718663 , 4.583088 ,
4.7011843, 5.0055323, 4.8422976, 4.8485436, 4.965094 , 4.7230763,
5.069906 , 4.9947624, 4.9111443, 4.2580767, 5.2471714, 4.5457106,
4.775864 , 5.025947 , 5.061445 , 4.8522716, 5.1862683, 5.0539246,
5.0416985, 5.168807 , 4.7731266, 4.703271 , 4.8041005, 4.589089 ,
4.500424 , 4.4929605, 4.4101706, 5.3536725, 5.31576 , 4.583191 ,
...
4.989824 , 4.7870226, 5.1676106, 4.773499 , 4.433352 , 5.527904 ,
5.376922 , 5.3109035, 4.531634 , 5.121908 , 5.2329516, 4.5748124,
4.9590473, 4.995907 , 4.777923 , 4.907775 , 5.110473 , 5.484889 ,
4.984773 , 4.9303713, 4.9617133, 4.9597883, 4.201158 , 4.108816 ,
5.5899467, 4.4740453, 5.1457696, 5.0359325, 5.559925 , 5.5884337,
5.177686 , 5.4082894, 4.67077 , 4.835922 , 5.263118 , 5.2465863,
5.2224298, 5.4455686, 5.6108904, 4.6170073, 4.7014694, 4.8181043,
5.2965465, 4.779739 , 4.9903846, 4.6685143, 5.001776 , 5.2450037,
4.6075706, 5.258145 , 5.197925 , 4.683494 , 5.085534 , 5.0441504,
5.057175 , 4.7345133, 5.0199018, 5.068657 , 4.5667806, 4.5667806,
5.1042504, 5.0103126, 5.3162837, 5.049756 , 5.0858436, 4.8877597,
4.871423 , 4.964435 , 4.952388 , 5.177528 , 4.7585597, 4.937545 ,
5.2067914, 5.404242 , 5.3792343, 5.4727798, 5.503424 , 3.9227598,
4.616737 , 5.349469 , 4.94926 , 5.276943 , 4.5606093, 4.350913 ,
4.8356533, 5.3012824, 4.9218187, 5.049454 , 5.0492573, 5.3923044,
5.154832 , 5.172335 , 4.881768 , 4.9017844, 4.2971644, 5.148047 ,
4.9634027, 4.969659 , 5.075037 , 4.83324 , 4.8330393, 5.2102723,
4.702924 , 4.691301 , 5.379582 , 4.30379 , 5.499221 , 4.7738557,
5.0708055, 4.6337047, 4.5571523, 4.7732024, 4.9907207, 4.967629 ]],
dtype=float32)array([[2.916654 , 3.1028943, 2.8679268, 2.8739176, 3.1552985, 3.1833546,
2.927499 , 2.9361997, 3.2647386, 2.7347465, 3.6316378, 3.186601 ,
2.9091659, 2.878644 , 3.0915058, 3.2480638, 2.8444803, 2.9026027,
3.0260732, 3.271102 , 3.2191224, 3.219358 , 2.9903677, 3.2239099,
3.1981618, 3.1861975, 2.8435102, 3.0665843, 2.7998087, 3.2023396,
3.0847368, 3.129275 , 2.832983 , 2.7954605, 3.4137855, 2.7664967,
2.963932 , 3.4414675, 2.8450003, 3.208023 , 3.1765056, 3.1218975,
3.1018298, 2.775923 , 3.06603 , 2.757947 , 3.0582092, 3.1054397,
2.8913946, 2.94138 , 3.5591953, 3.0489569, 3.0945673, 3.0281289,
2.956422 , 3.021427 , 3.1175191, 2.9662404, 2.9738827, 3.0137942,
3.0707273, 2.9698124, 3.1119697, 2.8752227, 3.112528 , 3.043435 ,
2.8680277, 3.205892 , 3.4510348, 2.9880588, 2.9658875, 2.6942782,
2.6942782, 2.6509972, 2.7107518, 3.024229 , 2.7974968, 2.8184597,
3.330964 , 3.1583602, 2.9045746, 2.8659198, 3.016166 , 3.070869 ,
3.4217649, 2.9372633, 3.2278461, 2.8106005, 2.9665637, 2.9308615,
2.7869065, 3.2926528, 2.7854984, 3.6578674, 3.8099601, 2.665423 ,
3.0543578, 2.9724834, 3.0155618, 3.133643 , 2.9342916, 3.0967343,
3.5509822, 3.3757555, 3.491986 , 3.4743314, 2.8963757, 3.2922297,
2.6979878, 3.2163372, 2.8820283, 2.8636553, 3.1171484, 3.261943 ,
3.330661 , 3.3311145, 3.3987849, 2.7803333, 2.9204037, 3.2823446,
...
2.9119968, 3.0894525, 3.6251438, 3.1715524, 3.0566335, 2.939902 ,
2.8716073, 2.9746158, 2.9994235, 3.2255704, 3.2961411, 2.9679601,
3.4174898, 3.3049793, 2.9214787, 3.075012 , 2.8656034, 3.1087933,
2.8138204, 3.1475034, 3.1343572, 3.1616294, 3.1369445, 3.1915271,
2.875889 , 3.1248076, 3.087011 , 2.9718072, 3.2590137, 3.2520576,
3.185291 , 2.859958 , 2.821579 , 2.781268 , 2.8592463, 2.9062448,
3.534457 , 2.9518278, 3.3549776, 2.7666702, 2.7543926, 3.254572 ,
2.9607222, 3.1747704, 3.569678 , 3.118485 , 2.9555767, 2.847376 ,
2.9998915, 3.1964002, 3.1241221, 3.0805674, 3.368725 , 2.869935 ,
2.7692757, 2.6986244, 3.442483 , 3.3160274, 2.842026 , 2.842026 ,
3.2840214, 3.4086637, 2.5889783, 2.8690841, 3.1804018, 3.1136734,
3.1699562, 2.8367765, 2.8882794, 3.1074224, 2.8305895, 2.9813702,
3.1293077, 2.8376162, 2.9135425, 3.3758152, 3.674394 , 2.779959 ,
3.299925 , 3.1810057, 3.2408876, 2.957203 , 3.4878654, 2.9024978,
3.3169794, 3.256874 , 3.084282 , 2.9575238, 2.7949817, 3.0721962,
3.303253 , 3.2028015, 2.9452498, 3.3317125, 3.3685954, 2.8577948,
3.2133477, 2.846444 , 3.1764314, 2.807072 , 2.778171 , 2.8153605,
3.2772346, 2.8095682, 3.1858308, 2.9680984, 2.7983553, 3.3925478,
2.7187293, 3.3872933, 3.0770774, 2.6717114, 3.4103987, 3.339314 ]],
dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions: (chain: 1, draw: 750, y_dim_0: 100)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
* y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99
Data variables:
y (chain, draw, y_dim_0) float32 -2.681 -1.99 ... -2.229 -2.194
Attributes:
created_at: 2023-05-29T23:56:31.843696
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
90, 91, 92, 93, 94, 95, 96, 97, 98, 99])array([[[-2.6805243, -1.9896044, -2.146286 , ..., -2.2449687,
-2.0858808, -2.048276 ],
[-2.4436066, -2.0799394, -2.3294597, ..., -2.4485612,
-2.2511694, -2.198129 ],
[-2.746058 , -1.9728756, -2.108558 , ..., -2.2030478,
-2.051998 , -2.0178077],
...,
[-2.5292451, -1.9190506, -2.2017465, ..., -2.3469207,
-2.1082296, -2.0462952],
[-2.5889425, -2.1491568, -2.2932622, ..., -2.3739157,
-2.2424622, -2.2096944],
[-2.5803223, -2.1288443, -2.2824044, ..., -2.3674629,
-2.2286887, -2.1939304]]], dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
dtype='int64', name='draw', length=750))PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99],
dtype='int64', name='y_dim_0'))<xarray.Dataset>
Dimensions: (chain: 1, draw: 750)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749
Data variables:
diverging (chain, draw) bool False False False False ... False False False
Attributes:
created_at: 2023-05-29T23:56:31.764392
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
...
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False]])PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions: (y_dim_0: 100)
Coordinates:
* y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99
Data variables:
y (y_dim_0) float32 1.78 5.272 6.843 6.101 ... 6.247 7.294 6.49 6.21
Attributes:
created_at: 2023-05-29T23:56:31.844516
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
90, 91, 92, 93, 94, 95, 96, 97, 98, 99])array([ 1.7799622 , 5.2715006 , 6.843007 , 6.101341 , 1.3784775 ,
3.5009139 , 7.0576353 , 7.4335885 , 8.715986 , 1.1754153 ,
4.9838333 , 1.7403774 , 5.123005 , 1.4687153 , 6.9200506 ,
6.5501447 , 7.0576353 , 4.713186 , 4.6818757 , 6.7534842 ,
1.7409545 , 3.115359 , 10.992229 , 3.3846278 , 2.3424428 ,
6.3324614 , 0.99764264, 1.5388296 , 10.105146 , 11.020432 ,
4.220744 , 2.1013582 , 6.6121535 , 5.773824 , 0.9355608 ,
7.78474 , 4.040793 , 2.8486586 , 2.0916374 , 4.882524 ,
5.1691275 , 3.7109454 , 6.0718546 , 4.425353 , 1.1271195 ,
3.8772562 , 10.373571 , 3.238866 , 7.975573 , 10.97725 ,
1.273873 , 1.2262683 , 2.7989383 , 4.9672947 , 0.7470391 ,
3.253747 , 3.8842838 , -0.34159485, 2.9436347 , -2.2525995 ,
9.075224 , 8.399537 , 3.4938996 , -1.7792448 , 1.692837 ,
10.415877 , 3.1210582 , 6.8077483 , 9.530612 , 4.725291 ,
9.805911 , 5.4888196 , 2.1056437 , 6.0538917 , 3.962473 ,
5.519883 , 4.078588 , 7.26097 , 4.400875 , 0.30348757,
3.18701 , 4.809211 , 8.965049 , 4.0829873 , 5.228887 ,
9.951787 , 5.6544847 , 0.64701927, 9.465937 , 7.890076 ,
4.776916 , 4.893428 , 10.160401 , 2.3604243 , 7.6586046 ,
7.2983284 , 6.2470894 , 7.294434 , 6.4904785 , 6.210163 ],
dtype=float32)PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99],
dtype='int64', name='y_dim_0'))fig = plots.plot_st(mcmc, y)  # presumably posterior summary plots of `mcmc` alongside the data `y` (helper lives in numpyro_glm.metric.plots)
# Load the two-group IQ dataset and keep only the rows of the 'Smart Drug' group.
iq_data = pd.read_csv('datasets/TwoGroupIQ.csv')
iq_data['Group'] = iq_data['Group'].astype('category')
smart_group_data = iq_data.loc[iq_data['Group'] == 'Smart Drug']
# Summary statistics (count, mean, std, quartiles) of the numeric columns.
smart_group_data.describe()
| Score | |
|---|---|
| count | 63.000000 |
| mean | 107.841270 |
| std | 25.445201 |
| min | 50.000000 |
| 25% | 96.000000 |
| 50% | 107.000000 |
| 75% | 119.000000 |
| max | 208.000000 |
Then, we apply the one-group model to the data and plot the results.
# Fit the same one-group model (parameters `mean` and `std`) to the
# Smart Drug scores, again with a fixed PRNG key for reproducibility.
kernel = NUTS(glm_metric.one_group)
mcmc_key = random.PRNGKey(0)
mcmc = MCMC(
    kernel,
    num_warmup=250,
    num_samples=750,
)
mcmc.run(mcmc_key, jnp.array(smart_group_data['Score'].values))
# Print posterior mean/std/median/quantiles, n_eff and r_hat per parameter.
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
mean 107.80 3.24 107.89 102.57 113.19 539.80 1.00
std 26.08 2.35 25.88 22.73 30.35 642.03 1.00
Number of divergences: 0
# Diagnostic plots for the sampled `mean` and `std` chains. The trailing
# semicolon stops IPython from echoing the returned object, which otherwise
# prints a very long repr of xarray Datasets under the figure.
numpyro_glm.plot_diagnostic(mcmc, ['mean', 'std']);
<xarray.Dataset>
Dimensions: (chain: 1, draw: 750)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
Data variables:
mean (chain, draw) float32 110.0 105.3 110.2 109.4 ... 106.8 108.7 108.2
std (chain, draw) float32 24.79 26.16 24.03 23.97 ... 22.13 29.33 28.75
Attributes:
created_at: 2023-05-29T23:56:35.414985
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[109.97527 , 105.29083 , 110.214516, 109.411125, 100.93915 ,
108.557785, 109.7286 , 107.717476, 110.33823 , 105.4407 ,
115.047325, 110.74487 , 109.90371 , 111.92637 , 103.67735 ,
109.43993 , 106.924866, 105.82306 , 109.33133 , 113.90438 ,
108.25385 , 105.84257 , 108.72531 , 109.02865 , 113.91567 ,
113.41177 , 103.41682 , 108.057 , 104.65977 , 110.4926 ,
113.583885, 102.19425 , 113.01051 , 111.27577 , 106.85726 ,
102.25837 , 107.078285, 113.12787 , 107.06807 , 104.69756 ,
104.80107 , 103.58743 , 104.322945, 108.6351 , 106.419586,
109.62201 , 107.34422 , 107.35962 , 110.8534 , 106.74869 ,
109.08335 , 109.06311 , 108.5331 , 104.78836 , 116.33449 ,
114.67767 , 112.495895, 104.07471 , 106.779434, 107.259605,
108.09704 , 108.10233 , 106.08486 , 109.5852 , 106.001076,
105.14359 , 109.282585, 110.47549 , 104.854195, 108.91195 ,
106.613556, 110.542786, 110.542786, 108.08629 , 109.27255 ,
107.772766, 104.430824, 103.29096 , 110.69535 , 112.892365,
106.9933 , 108.17603 , 108.825 , 105.24313 , 104.06096 ,
109.17792 , 107.244606, 108.34675 , 104.49413 , 106.83144 ,
104.965004, 110.35247 , 105.32731 , 109.82167 , 105.720985,
107.46112 , 110.59209 , 109.11344 , 106.73796 , 99.25905 ,
...
106.25765 , 107.159256, 107.43473 , 114.022736, 107.4383 ,
109.01808 , 103.7021 , 111.57398 , 110.96743 , 113.67769 ,
114.12441 , 109.39441 , 104.46808 , 105.55915 , 105.599365,
111.64049 , 110.684875, 109.979706, 113.008514, 113.94847 ,
103.412254, 104.3738 , 108.75197 , 111.8318 , 104.77148 ,
107.812775, 104.56448 , 108.718765, 105.21872 , 105.55679 ,
104.245285, 111.70955 , 104.24466 , 111.4618 , 110.510574,
107.00754 , 110.02091 , 104.797165, 106.90024 , 112.80083 ,
109.23521 , 105.043015, 105.97836 , 103.32183 , 108.15405 ,
107.78267 , 107.18314 , 106.99867 , 108.12765 , 107.89384 ,
109.307655, 112.93956 , 112.01034 , 110.8072 , 112.965576,
112.18863 , 112.52543 , 108.96955 , 102.037285, 107.1725 ,
104.57445 , 109.20819 , 111.1646 , 102.88633 , 108.055016,
104.86002 , 111.12093 , 106.58487 , 109.66695 , 109.28379 ,
113.10869 , 109.12012 , 109.58398 , 106.80441 , 107.53587 ,
100.24453 , 113.68477 , 106.12561 , 106.85108 , 108.66731 ,
106.49102 , 106.621 , 110.45593 , 104.96861 , 113.12082 ,
111.35625 , 102.31912 , 106.25812 , 110.87694 , 109.251045,
107.50152 , 105.6816 , 106.77369 , 108.67319 , 108.24811 ]],
dtype=float32)array([[24.786228, 26.155521, 24.029436, 23.965836, 27.113821, 27.424044,
25.806293, 26.273508, 27.535105, 22.97756 , 32.289055, 24.15046 ,
24.51352 , 24.034502, 26.246254, 27.881187, 23.841017, 24.366045,
25.547419, 28.463753, 27.826414, 27.78451 , 25.26547 , 27.71857 ,
27.465443, 26.184599, 24.367947, 28.29067 , 29.094795, 22.382898,
23.262806, 28.737497, 22.244843, 25.077694, 29.279648, 27.409065,
24.064144, 30.407253, 23.98416 , 27.507038, 27.295832, 26.6867 ,
26.46073 , 22.920782, 25.979725, 22.771954, 25.901804, 26.450947,
24.121893, 24.683409, 31.45096 , 26.017336, 26.34012 , 25.636229,
26.706406, 27.788372, 28.212412, 26.795696, 26.954515, 24.471197,
27.04578 , 24.21856 , 25.055971, 25.211184, 25.342659, 25.057606,
23.631056, 30.170628, 30.582378, 25.426617, 24.681257, 22.228506,
22.228506, 21.75323 , 22.333263, 25.517614, 23.16428 , 23.095354,
29.873755, 27.750286, 24.93584 , 23.626965, 25.444895, 26.3856 ,
30.052254, 24.859749, 22.761942, 28.185104, 23.586464, 25.712185,
22.451385, 29.302467, 22.616674, 22.743664, 23.170364, 23.876362,
27.080822, 25.806063, 24.978598, 26.404415, 27.514174, 26.977419,
31.662498, 29.875046, 31.077866, 25.236607, 24.985004, 28.296259,
22.199097, 27.613825, 24.109753, 23.913204, 26.575354, 28.296556,
29.036882, 29.066261, 29.816078, 23.089298, 24.582733, 26.92522 ,
...
24.014801, 27.789244, 32.343906, 27.632095, 26.07815 , 24.650728,
23.914785, 25.040188, 25.317043, 27.565416, 28.408096, 25.107414,
22.729567, 23.612532, 23.126312, 26.601177, 23.807423, 24.42951 ,
23.662622, 26.85912 , 26.753407, 27.049782, 25.479877, 26.435293,
23.96122 , 26.330206, 26.018835, 25.607841, 28.446865, 28.278666,
27.410109, 28.209177, 30.121782, 30.168917, 22.123705, 23.426945,
31.40853 , 25.059225, 29.351217, 27.96904 , 25.75235 , 31.656654,
25.066679, 27.139214, 31.495024, 26.959902, 25.11425 , 26.711693,
24.692608, 23.715153, 27.323866, 26.20297 , 24.107151, 24.007935,
24.023855, 26.436314, 25.4301 , 25.759718, 24.9387 , 31.076105,
22.513962, 25.891182, 29.503733, 23.962814, 27.266918, 26.573658,
27.158567, 23.648088, 24.148748, 28.249935, 30.239103, 29.819923,
26.691223, 23.647682, 24.47609 , 29.486397, 24.319372, 30.137695,
32.97399 , 23.286343, 28.740875, 24.984743, 30.649132, 34.445366,
27.562742, 27.380943, 25.81166 , 25.387672, 23.473827, 26.251698,
28.780266, 27.655264, 24.728743, 28.889326, 29.331318, 24.23243 ,
27.918295, 23.9464 , 27.325426, 23.267876, 22.971722, 23.432446,
28.220026, 27.862444, 27.050863, 25.492939, 29.31498 , 23.50533 ,
24.160336, 27.551594, 24.947176, 22.131338, 29.331263, 28.75127 ]],
dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions: (chain: 1, draw: 750, y_dim_0: 63)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
* y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62
Data variables:
y (chain, draw, y_dim_0) float32 -4.181 -4.136 ... -5.051 -4.286
Attributes:
created_at: 2023-05-29T23:56:35.499221
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62])array([[[-4.180992 , -4.136431 , -4.392193 , ..., -4.210211 ,
-5.0714164, -4.132563 ],
[-4.1909137, -4.185134 , -4.312105 , ..., -4.203458 ,
-5.278142 , -4.2158976],
[-4.1566496, -4.107166 , -4.3855066, ..., -4.1885657,
-5.086642 , -4.1009784],
...,
[-4.039196 , -4.0159855, -4.238742 , ..., -4.062772 ,
-5.430601 , -4.0438166],
[-4.3234735, -4.2992196, -4.459157 , ..., -4.341311 ,
-5.022892 , -4.3040247],
[-4.3012333, -4.2785625, -4.4373045, ..., -4.31877 ,
-5.050753 , -4.2861347]]], dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
dtype='int64', name='draw', length=750))PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62],
dtype='int64', name='y_dim_0'))<xarray.Dataset>
Dimensions: (chain: 1, draw: 750)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749
Data variables:
diverging (chain, draw) bool False False False False ... False False False
Attributes:
created_at: 2023-05-29T23:56:35.416760
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
...
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False]])PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions: (y_dim_0: 63)
Coordinates:
* y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62
Data variables:
y (y_dim_0) int32 102 107 92 101 110 68 119 ... 97 139 72 100 144 112
Attributes:
created_at: 2023-05-29T23:56:35.500056
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62])array([102, 107, 92, 101, 110, 68, 119, 106, 99, 103, 90, 93, 79,
89, 137, 119, 126, 110, 71, 114, 100, 95, 91, 99, 97, 106,
106, 129, 115, 124, 137, 73, 69, 95, 102, 116, 111, 134, 102,
110, 139, 112, 122, 84, 129, 112, 127, 106, 113, 109, 208, 114,
107, 50, 169, 133, 50, 97, 139, 72, 100, 144, 112], dtype=int32)PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62],
dtype='int64', name='y_dim_0'))fig = plots.plot_st(
mcmc, smart_group_data.Score.values,
# Reference values the posterior is compared against. Presumably these are
# the conventional IQ scale (mean 100, SD 15) and "no effect" (effect size 0);
# TODO confirm against numpyro_glm.metric.plots.plot_st.
mean_comp_val=100,
std_comp_val=15,
effsize_comp_val=0,
)
# Draw the graphical structure of the robust one-group model. A length-5
# vector of ones serves only as dummy data so the model can be traced.
numpyro.render_model(
    glm_metric.one_group_robust,
    model_args=(jnp.ones(5),),
    render_params=True,
)
# Ground-truth parameters of the synthetic Student-t data: location MEAN,
# scale SIGMA, degrees of freedom NORMALITY, and the sample size N.
MEAN = 5
SIGMA = 3
NORMALITY = 3
N = 1000
# Use a seeded generator so the synthetic data, and therefore the MCMC
# summary computed from it, is reproducible across notebook runs.
rng = np.random.default_rng(1)
# Scaling a standard Student-t draw by SIGMA and shifting by MEAN gives a
# location-scale Student-t sample.
y = rng.standard_t(NORMALITY, size=N) * SIGMA + MEAN
# Plot the histogram and the true Student-t PDF, plus, for contrast, a
# normal PDF fitted with the sample mean and standard deviation.
fig, ax = plt.subplots()
ax.hist(y, density=True, bins=100, label='Histogram of $y$')
xmin, xmax = ax.get_xlim()
p = np.linspace(xmin, xmax, 1000)
ax.plot(p, t.pdf(p, loc=MEAN, scale=SIGMA, df=NORMALITY), label='Student-$t$ PDF')
ax.plot(
    p,
    norm.pdf(p, loc=y.mean(), scale=y.std()),
    label='Normal PDF using\ndata mean and std',
)
ax.legend()
fig.tight_layout()
Next, we apply the robust metric model to our synthetic data to see whether it can recover the original parameters.
# Sample the posterior of the robust one-group model (parameters `mean`,
# `sigma` and `nu`) with NUTS, using a fixed PRNG key for reproducibility.
kernel = NUTS(glm_metric.one_group_robust)
mcmc_key = random.PRNGKey(0)
mcmc = MCMC(
    kernel,
    num_warmup=250,
    num_samples=750,
)
mcmc.run(mcmc_key, jnp.array(y))
# Print posterior mean/std/median/quantiles, n_eff and r_hat per parameter.
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
mean 5.19 0.12 5.19 5.01 5.38 528.65 1.01
nu 3.02 0.32 3.00 2.48 3.54 389.25 1.00
sigma 2.97 0.12 2.97 2.78 3.17 351.15 1.00
Number of divergences: 0
# Diagnostic plots for the sampled `mean`, `sigma` and `nu` chains. The
# trailing semicolon stops IPython from echoing the returned object, which
# otherwise prints a very long repr of xarray Datasets under the figure.
numpyro_glm.plot_diagnostic(mcmc, ['mean', 'sigma', 'nu']);
<xarray.Dataset>
Dimensions: (chain: 1, draw: 750)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
Data variables:
mean (chain, draw) float32 5.225 5.124 5.198 5.033 ... 5.201 5.154 5.112
nu (chain, draw) float32 3.03 2.811 3.201 2.794 ... 3.153 3.204 3.224
sigma (chain, draw) float32 3.008 2.894 3.055 2.892 ... 3.009 3.049 3.047
Attributes:
created_at: 2023-05-29T23:56:40.651587
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[5.2252526, 5.1241636, 5.1982026, 5.032679 , 5.0717063, 5.295204 ,
5.3268604, 5.2793097, 5.081863 , 5.24317 , 5.1877646, 5.13194 ,
5.097041 , 5.299657 , 5.100597 , 5.3192477, 5.197055 , 5.197055 ,
5.335744 , 5.362243 , 5.265542 , 5.07826 , 5.033498 , 5.158056 ,
5.043956 , 5.178067 , 5.2793436, 5.1069593, 5.216472 , 5.0388417,
5.3190384, 5.118554 , 5.3592353, 5.297184 , 5.054299 , 5.332943 ,
5.1977186, 5.186516 , 5.0346336, 5.0033903, 5.351208 , 5.346388 ,
5.2818646, 5.0247946, 5.3952494, 5.3240824, 5.038796 , 5.351271 ,
5.0122066, 5.106121 , 5.239272 , 5.0579095, 5.178805 , 5.0543776,
4.985325 , 5.383977 , 5.3296695, 5.2475777, 5.2292995, 5.3917494,
5.176481 , 5.0985913, 5.0745983, 5.3064804, 5.3190703, 5.3065405,
5.073111 , 5.127397 , 5.1585197, 5.004515 , 5.0588617, 5.050305 ,
4.9809313, 5.024628 , 5.301579 , 5.290991 , 5.326011 , 5.0710893,
5.229229 , 5.140679 , 5.171478 , 5.1049523, 5.4097652, 4.9618893,
5.0410748, 5.192658 , 5.006359 , 5.201808 , 5.360339 , 5.293335 ,
5.180471 , 5.2222195, 5.106078 , 5.0473633, 5.445147 , 5.3690085,
5.356301 , 5.314369 , 5.1836066, 5.214131 , 5.137178 , 5.019918 ,
5.1467023, 5.120248 , 5.177939 , 5.1632943, 5.2115273, 5.020947 ,
5.232629 , 5.138412 , 5.414697 , 5.2546616, 5.315323 , 5.3316145,
5.029754 , 5.271732 , 5.2580037, 4.9675856, 4.854825 , 5.1653094,
...
5.1836634, 5.1777854, 5.0916348, 5.1517425, 4.94229 , 5.338954 ,
5.3075275, 4.9510856, 5.355672 , 5.050706 , 5.1464505, 5.1667733,
5.102213 , 5.142809 , 5.2244234, 5.0835414, 5.2575097, 5.3360806,
5.255036 , 5.1094394, 5.2511826, 5.383278 , 5.104622 , 5.010276 ,
5.374145 , 5.1573734, 4.9785376, 5.130616 , 5.086193 , 5.1101737,
5.196637 , 5.2557135, 5.2129374, 5.216792 , 5.3181276, 5.152473 ,
5.1690083, 5.2550898, 5.1514096, 5.111661 , 5.12744 , 5.225885 ,
5.371412 , 5.057382 , 5.313378 , 5.198703 , 5.218926 , 5.1541963,
5.0360336, 5.1063595, 5.252289 , 5.109053 , 5.1100416, 5.109926 ,
5.4005275, 5.206244 , 5.423744 , 5.055256 , 5.262973 , 5.369212 ,
5.2393475, 5.167504 , 5.057998 , 5.1727924, 5.123668 , 5.077471 ,
5.277476 , 5.152246 , 5.3775187, 5.2343764, 5.320696 , 5.0608945,
5.297638 , 5.2998834, 5.0541434, 5.2953877, 5.295298 , 5.37311 ,
5.287508 , 5.177804 , 5.1395016, 4.9199777, 4.9277735, 4.9639325,
5.01235 , 5.1524143, 5.2529345, 5.2919297, 5.009624 , 5.0876145,
5.0876145, 5.080562 , 5.2489963, 5.2815833, 5.2815833, 5.132526 ,
5.18512 , 5.1307178, 5.247394 , 5.1012545, 5.0817285, 5.4392815,
5.058194 , 5.067048 , 5.154724 , 5.154724 , 5.00218 , 5.0527177,
5.2726903, 5.21625 , 5.3625455, 5.200683 , 5.1538286, 5.11227 ]],
dtype=float32)array([[3.0302858, 2.8105295, 3.2014136, 2.7935388, 2.9985032, 2.5012183,
3.0470521, 2.9146385, 3.8911355, 2.3872564, 2.428279 , 2.8282342,
3.0597894, 2.883507 , 3.410241 , 3.499199 , 2.655398 , 2.655398 ,
2.6596022, 3.048019 , 2.955171 , 2.7809849, 2.9156413, 2.7100785,
2.7034278, 2.8511982, 2.8510652, 2.8014731, 2.724396 , 3.0761416,
3.1426032, 2.7679362, 3.0328808, 3.1469035, 3.0298078, 2.9144955,
2.9299479, 2.8942003, 2.8006005, 3.2382188, 3.0903177, 3.0062823,
2.855648 , 3.045023 , 3.5830653, 3.472515 , 3.0960622, 3.2011704,
3.1031928, 3.4019434, 3.2195787, 3.0908947, 3.0280414, 3.4605806,
3.024026 , 3.2612574, 3.1635637, 2.7193754, 3.5790129, 3.5917504,
2.7907567, 3.0398004, 3.322532 , 3.58795 , 3.3344655, 3.4647434,
3.0550628, 3.6886106, 2.9566228, 2.8622136, 2.967669 , 3.297165 ,
3.111523 , 2.8584785, 2.6858768, 3.014183 , 2.8630064, 2.8034956,
3.193019 , 2.812081 , 2.8540955, 2.8850079, 3.1799855, 3.414366 ,
3.5967433, 2.7563825, 2.9630618, 3.2624266, 3.1123424, 3.0870428,
2.725316 , 2.961867 , 3.3157516, 3.561547 , 2.7511356, 2.7754855,
3.1627483, 2.84977 , 3.4220443, 2.3197072, 2.8218226, 2.4043255,
2.6734753, 3.0119743, 3.0013568, 2.9657485, 3.18316 , 3.5373964,
3.4768598, 3.2778487, 3.138904 , 3.0696943, 2.8126042, 2.8059587,
2.7557168, 2.707277 , 2.7965493, 2.5954356, 2.7712111, 2.5360422,
...
3.2974737, 2.608031 , 3.3340845, 3.1168242, 3.371426 , 2.484195 ,
2.5683467, 3.5349627, 3.216056 , 2.9181938, 2.941831 , 3.331728 ,
3.2328591, 3.0172265, 2.7155166, 2.9464223, 2.7724924, 2.6816273,
2.5984507, 3.5753417, 3.6234405, 3.1972866, 3.0058103, 2.4069138,
3.2024572, 3.1628978, 2.934284 , 2.6520753, 2.5028887, 2.4886835,
2.3843923, 2.5833275, 2.5340605, 2.590158 , 2.5927703, 3.110703 ,
3.1948829, 3.28651 , 3.098055 , 3.193087 , 3.2967393, 3.358853 ,
3.7556908, 3.1511366, 2.7770283, 3.003881 , 2.752591 , 3.2983627,
3.2164586, 3.1108088, 3.0520477, 2.815367 , 2.796043 , 2.3711905,
2.977821 , 2.7997148, 2.4689636, 4.148687 , 3.509022 , 3.2277522,
3.6996043, 4.011406 , 3.2496595, 2.5972798, 2.8907168, 3.082364 ,
3.1959908, 3.0188456, 3.1784215, 2.9925566, 3.1192105, 2.6216607,
2.9329195, 2.8352685, 2.9127254, 2.5918915, 3.7815177, 3.487691 ,
3.2226639, 3.4023337, 2.5276363, 2.8847451, 2.830494 , 2.9687147,
3.082112 , 2.9967947, 3.1591215, 2.67426 , 2.7939107, 2.9403849,
2.9403849, 2.820949 , 3.3997335, 2.928617 , 2.928617 , 2.49297 ,
2.991461 , 3.1410506, 3.071447 , 3.406098 , 2.8856087, 2.9846396,
3.5282505, 2.6596055, 2.9256043, 2.9256043, 2.6848874, 2.655949 ,
2.6168833, 3.115498 , 3.1219566, 3.1531057, 3.203574 , 3.224131 ]],
dtype=float32)array([[3.0077553, 2.893796 , 3.0548732, 2.891678 , 2.8712099, 2.8375256,
2.826606 , 2.9720817, 3.263655 , 2.7567494, 2.7421374, 2.9102821,
2.9675138, 2.9290013, 3.1114519, 3.1912806, 2.7879858, 2.7879858,
2.864312 , 2.8857582, 2.9045377, 2.9562974, 2.9119508, 2.9810064,
3.0166419, 2.9744427, 2.8893104, 2.95488 , 2.932642 , 2.9176862,
2.9649498, 2.8088963, 2.9177237, 2.9527833, 2.8886452, 2.962949 ,
2.9381344, 2.8418946, 2.9687228, 3.1053722, 3.0414934, 2.9902358,
2.8479242, 3.091427 , 3.006934 , 3.0294032, 2.9957023, 3.0173352,
3.171182 , 3.0915844, 2.998457 , 3.0699952, 3.1018872, 2.930065 ,
3.0457103, 2.9077508, 3.0073476, 2.8968403, 3.14788 , 3.1162343,
2.9465456, 3.0800798, 2.8778477, 3.0915356, 3.1263587, 2.9847069,
3.1584203, 3.0032854, 3.094975 , 2.8946183, 2.9876933, 3.0804462,
2.974532 , 2.8442578, 2.8904667, 3.00443 , 2.925379 , 2.9279077,
3.0053437, 2.8640773, 2.9765718, 2.9506266, 3.0112603, 3.2222466,
3.1745212, 3.0268316, 2.8382347, 3.2115846, 2.96255 , 2.880831 ,
2.9258776, 2.9168687, 3.1384408, 3.0122447, 2.756393 , 3.0167165,
2.9385285, 2.7939687, 3.0074434, 2.661758 , 2.695194 , 2.7966495,
2.890351 , 3.021478 , 2.91477 , 2.9491024, 3.0953932, 3.014392 ,
3.0871012, 3.0097551, 3.297201 , 3.0935102, 2.9835587, 2.9820387,
2.8816643, 2.9155366, 2.946278 , 2.7360942, 2.7621183, 2.9129941,
...
3.1272457, 3.000786 , 2.9547353, 3.1654298, 3.1500812, 2.7814696,
2.8203695, 3.0340738, 3.0124326, 3.1742008, 3.188828 , 3.1395504,
2.9985125, 3.0134718, 3.045996 , 2.8193517, 2.902401 , 2.9608424,
2.8618124, 3.096644 , 2.9716787, 3.0962532, 2.98246 , 2.7797184,
3.1126304, 2.9470394, 2.9663665, 2.9348457, 2.707904 , 2.7372358,
2.8489664, 2.7766206, 2.8870816, 2.7899947, 2.8266532, 3.0544372,
3.0760381, 3.0813096, 3.097099 , 3.1129563, 3.0446541, 3.1500077,
3.1358485, 3.0074933, 3.1147547, 2.818752 , 2.9542136, 3.0311131,
2.9141474, 3.1281784, 2.9774513, 2.740467 , 3.0301168, 2.88417 ,
2.7082813, 2.867558 , 2.7481692, 3.1149592, 3.3530736, 3.4157538,
3.227895 , 3.1730561, 3.153692 , 2.89388 , 2.861404 , 3.0043843,
3.116252 , 3.0452118, 2.9471655, 2.9493282, 2.8316453, 2.9597483,
2.870535 , 2.9105754, 2.8837337, 2.9166882, 3.1200943, 3.0545652,
3.1874716, 2.8387008, 2.9814684, 3.0079186, 3.067449 , 2.9708407,
2.917632 , 2.831051 , 2.8786335, 2.9744313, 3.0196028, 2.9318008,
2.9318008, 3.0168672, 2.8341033, 3.0753546, 3.0753546, 2.8534124,
2.9582543, 2.9235747, 3.1306293, 3.08414 , 2.9503288, 3.0093284,
3.1156042, 2.906394 , 2.8710785, 2.8710785, 2.8466375, 2.7183504,
2.9024906, 2.939236 , 3.0976253, 3.0085454, 3.0485835, 3.0470424]],
dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions: (chain: 1, draw: 750, y_dim_0: 1000)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
* y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999
Data variables:
y (chain, draw, y_dim_0) float32 -4.239 -2.245 ... -2.194 -3.763
Attributes:
created_at: 2023-05-29T23:56:40.806417
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([ 0, 1, 2, ..., 997, 998, 999])
array([[[-4.23901 , -2.2447944, -3.7616072, ..., -3.9045508,
-2.1726134, -3.819713 ],
[-4.248008 , -2.2480104, -3.7681386, ..., -3.9891958,
-2.1634555, -3.8268566],
[-4.2147064, -2.2553887, -3.7335567, ..., -3.8977628,
-2.1846452, -3.7918887],
...,
[-4.237927 , -2.245505 , -3.753389 , ..., -3.9170911,
-2.172602 , -3.8122478],
[-4.2017717, -2.262422 , -3.7191315, ..., -3.9176188,
-2.189119 , -3.7776227],
[-4.1883435, -2.2696419, -3.70393 , ..., -3.9347727,
-2.1942322, -3.7625935]]], dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
dtype='int64', name='draw', length=750))PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
990, 991, 992, 993, 994, 995, 996, 997, 998, 999],
dtype='int64', name='y_dim_0', length=1000))<xarray.Dataset>
Dimensions: (chain: 1, draw: 750)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749
Data variables:
diverging (chain, draw) bool False False False False ... False False False
Attributes:
created_at: 2023-05-29T23:56:40.653699
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
...
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False]])PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions: (y_dim_0: 1000)
Coordinates:
* y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999
Data variables:
y (y_dim_0) float32 -1.97 6.648 -0.6971 5.801 ... 11.52 6.219 -0.8494
Attributes:
created_at: 2023-05-29T23:56:40.807396
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([ 0, 1, 2, ..., 997, 998, 999])
array([-1.97044837e+00, 6.64772511e+00, -6.97052896e-01, 5.80093575e+00,
4.54183054e+00, 1.12101612e+01, -4.44931060e-01, 5.82078266e+00,
5.82591629e+00, 3.17236757e+00, 8.37355042e+00, 6.53492546e+00,
4.20800734e+00, 1.01482553e+01, 1.41755648e+01, 1.24651060e+01,
-2.17906322e+01, 4.92048836e+00, 4.85010052e+00, 1.02006369e+01,
2.13908386e+00, 2.05172256e-01, 1.30528660e+01, -2.62980032e+00,
7.23637486e+00, 4.05275726e+00, 3.06575513e+00, 7.08320904e+00,
5.54312420e+00, 1.49230599e+00, 1.01288261e+01, 8.57089424e+00,
1.14836349e+01, 5.61097431e+00, 3.13621688e+00, 9.58228970e+00,
8.92821312e+00, 6.76997042e+00, 1.93540821e+01, 4.11432600e+00,
1.19134083e+01, 9.42990589e+00, 5.60357380e+00, -1.71314931e+00,
4.70048040e-01, 3.88558745e+00, 9.75277519e+00, 5.94671679e+00,
6.14265108e+00, -3.56223047e-01, 1.22409928e+00, 8.90643311e+00,
1.25334752e+00, 7.24736023e+00, 4.03608942e+00, 6.78641033e+00,
-2.69396567e+00, 5.06562614e+00, -1.64162469e+00, 2.85413051e+00,
2.39020967e+00, 4.40503359e+00, 8.13211632e+00, 5.34186506e+00,
6.01666784e+00, 1.21046562e+01, 2.13109040e+00, 2.75819302e-01,
2.81802750e+00, 7.41340733e+00, 6.28000736e+00, 6.85987091e+00,
1.15999956e+01, 4.33733034e+00, 7.85834742e+00, 4.79054356e+00,
-1.75435328e+00, 2.27244854e+00, 1.29107180e+01, 1.78298130e+01,
...
7.67104053e+00, 7.03784704e+00, 2.43009663e+00, 6.72932577e+00,
1.44165316e+01, 5.52429914e+00, 6.58890247e+00, 9.42949772e+00,
1.42814171e+00, 2.61827564e+00, 5.46275520e+00, 5.43681955e+00,
4.64565945e+00, -1.88100077e-02, 1.12393579e+01, 6.41488600e+00,
5.92880297e+00, 7.56008530e+00, 5.77487803e+00, 5.70153856e+00,
1.06024427e+01, 9.22666454e+00, 1.26327670e+00, 6.25757360e+00,
6.05027056e+00, 4.46348190e+00, 5.07479429e+00, 3.75031829e+00,
1.08985357e+01, 8.36709881e+00, 3.84162974e+00, 8.93783290e-03,
6.21623099e-01, 8.29607487e+00, 6.74988127e+00, 4.13313913e+00,
3.56779456e-01, 6.27172899e+00, 1.39344180e+00, 8.97291946e+00,
1.39694996e+01, 7.24138546e+00, 9.02207565e+00, 6.50661051e-01,
-2.29273844e+00, 6.50683498e+00, 2.56033993e+00, 6.01500273e+00,
1.85625410e+00, 7.55919504e+00, 5.12628460e+00, 2.19768524e+00,
6.08607101e+00, 4.16822290e+00, -2.76709747e+00, 4.77808809e+00,
1.25680656e+01, 1.24284792e+00, 6.16621590e+00, 4.93042040e+00,
4.13558197e+00, 1.31871214e+01, 2.85820341e+00, 6.72703362e+00,
7.92491627e+00, 6.70488834e+00, 9.36034393e+00, -7.42135286e-01,
5.76647472e+00, 2.60789084e+00, 1.30372515e+01, 5.43683338e+00,
7.20951414e+00, 1.15234461e+01, 6.21906805e+00, -8.49427700e-01],
dtype=float32)PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
990, 991, 992, 993, 994, 995, 996, 997, 998, 999],
dtype='int64', name='y_dim_0', length=1000))fig = plots.plot_st_2(mcmc, y)  # plot the one-group normal fit against the simulated data `y`
# Refit with the robust one-group model on the `Score` column of
# `smart_group_data`, using the same NUTS settings (250 warmup / 750 samples)
# and the same PRNG seed (0) as the earlier normal-model fit.
mcmc_key = random.PRNGKey(0)
kernel = NUTS(glm_metric.one_group_robust)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc.run(mcmc_key, jnp.array(smart_group_data.Score.values))
# Prints posterior summaries for `mean`, `sigma` and `nu` (`nu` is presumably
# the Student-t degrees-of-freedom / normality parameter -- confirm in
# numpyro_glm.metric.models).
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
mean 107.38 2.96 107.35 102.54 112.11 499.25 1.00
nu 9.63 13.31 5.73 1.20 19.70 158.39 1.01
sigma 19.83 3.53 19.78 14.20 25.66 235.03 1.00
Number of divergences: 0
# Diagnostic plots of the robust one-group fit's chains for each sampled parameter.
numpyro_glm.plot_diagnostic(mcmc, ['mean', 'sigma', 'nu'])
<xarray.Dataset>
Dimensions: (chain: 1, draw: 750)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
Data variables:
mean (chain, draw) float32 105.6 106.3 108.6 104.4 ... 108.3 106.8 104.6
nu (chain, draw) float32 6.38 6.304 6.77 2.442 ... 9.179 8.01 13.55
sigma (chain, draw) float32 22.7 23.93 23.56 16.29 ... 22.12 23.89 20.27
Attributes:
created_at: 2023-05-29T23:56:45.027781
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[105.55507 , 106.34972 , 108.5634 , 104.371185, 103.78115 ,
108.39524 , 104.8494 , 105.39043 , 110.073586, 107.43894 ,
108.44629 , 106.93347 , 102.90011 , 106.07971 , 106.80335 ,
110.15082 , 109.17264 , 107.16441 , 111.50457 , 113.58138 ,
102.10726 , 108.93787 , 109.68986 , 110.641335, 105.38782 ,
108.96525 , 109.787346, 106.19191 , 107.32411 , 103.05021 ,
109.8304 , 106.86794 , 110.42675 , 104.287094, 103.54164 ,
106.96897 , 107.24161 , 103.478745, 102.1156 , 100.63876 ,
108.324066, 108.720985, 108.63 , 111.166565, 109.553505,
108.90091 , 108.67291 , 105.97729 , 110.96842 , 105.5543 ,
104.999115, 106.73197 , 107.25278 , 105.5463 , 108.59721 ,
112.095764, 111.39011 , 110.41567 , 101.798965, 107.15456 ,
110.26754 , 108.96618 , 105.98647 , 105.55314 , 106.5947 ,
109.240746, 105.711555, 105.50618 , 106.02377 , 111.31043 ,
103.732605, 102.67975 , 107.562675, 106.68475 , 106.863625,
109.01302 , 104.004974, 107.06012 , 107.67128 , 106.929245,
107.940384, 106.64516 , 112.43416 , 113.13488 , 112.6087 ,
108.9748 , 104.34741 , 105.08496 , 111.78247 , 105.56174 ,
106.40983 , 108.041275, 108.06947 , 105.43695 , 108.262535,
111.17323 , 104.075134, 104.57462 , 104.88279 , 105.00093 ,
...
103.982544, 106.368774, 111.052895, 109.16271 , 110.65621 ,
110.26432 , 107.34818 , 112.2538 , 104.68615 , 103.45305 ,
107.17469 , 107.23535 , 108.20257 , 107.94324 , 107.99706 ,
105.0354 , 108.69497 , 107.565155, 105.54674 , 108.33376 ,
106.40116 , 106.42444 , 105.807465, 102.839745, 109.73883 ,
111.335785, 109.93152 , 110.02944 , 109.36531 , 111.29636 ,
110.44721 , 104.324524, 109.9587 , 112.545685, 111.32425 ,
113.93877 , 111.03316 , 106.69128 , 105.634125, 107.761185,
103.34582 , 104.21555 , 104.75715 , 105.540276, 106.058044,
110.38669 , 110.98076 , 104.838936, 103.348526, 111.428215,
111.38477 , 111.552666, 107.4357 , 105.75314 , 106.130035,
105.13054 , 108.275925, 113.34639 , 103.81799 , 103.676956,
105.02578 , 106.404274, 101.28366 , 105.969345, 105.57954 ,
105.5759 , 107.31593 , 108.7727 , 109.47214 , 105.593376,
107.311165, 104.327995, 104.20603 , 107.11131 , 109.11189 ,
109.00499 , 107.64095 , 106.795364, 107.13538 , 110.425575,
107.195335, 102.86533 , 111.92812 , 102.704185, 102.664154,
104.04894 , 109.90849 , 104.54301 , 104.901215, 106.85347 ,
105.591606, 108.840034, 108.32749 , 106.838234, 104.625496]],
dtype=float32)array([[ 6.380366 , 6.3039827, 6.770382 , 2.4415083, 4.205345 ,
2.3898754, 2.4554567, 4.751895 , 30.20579 , 6.2195134,
3.5312512, 4.76675 , 5.0081363, 3.1247911, 21.360893 ,
23.807055 , 10.058445 , 7.779872 , 11.477498 , 21.622583 ,
36.68576 , 5.816049 , 7.4071264, 4.7910275, 3.411479 ,
5.6175413, 3.6065092, 3.8919663, 2.406946 , 6.614321 ,
17.233143 , 2.8523195, 9.354761 , 5.276261 , 4.406109 ,
3.6853158, 3.4503999, 6.083024 , 3.162562 , 7.787067 ,
18.096825 , 6.6824446, 11.993019 , 4.9092307, 12.329353 ,
10.28212 , 12.334496 , 21.102882 , 68.73028 , 28.93483 ,
24.624996 , 22.925674 , 18.960793 , 17.155924 , 4.483561 ,
9.103407 , 8.090643 , 3.4731753, 21.862223 , 3.5392842,
6.277596 , 4.3358054, 10.864571 , 1.2639813, 1.8327861,
5.8132896, 2.6647193, 10.151241 , 4.558977 , 11.674158 ,
13.655367 , 16.772926 , 25.298958 , 1.9189749, 5.538587 ,
3.6070976, 9.118575 , 4.755516 , 4.6680737, 3.8454957,
3.030088 , 3.3910158, 6.3802996, 3.3749137, 4.504877 ,
10.809336 , 6.2159004, 15.667531 , 4.7564974, 8.262956 ,
9.560621 , 9.199344 , 2.9713812, 3.7156253, 5.5498013,
11.150637 , 5.7294626, 4.002439 , 3.909546 , 2.221348 ,
...
10.571391 , 13.065151 , 6.8497195, 3.7558386, 3.5340605,
3.298055 , 2.0803087, 4.189632 , 7.7399063, 3.6138294,
3.9523766, 6.6232204, 6.2300615, 2.3853402, 2.850403 ,
5.229487 , 3.8615527, 3.9326599, 3.8529232, 8.671819 ,
6.301786 , 7.7922463, 5.2313375, 2.3083024, 65.67436 ,
100.1539 , 48.767944 , 33.916534 , 32.376404 , 9.096143 ,
9.358721 , 7.4668074, 17.171707 , 28.13654 , 131.20631 ,
99.750786 , 82.21945 , 150.44986 , 4.315606 , 15.028996 ,
19.251741 , 16.819944 , 54.600857 , 10.102056 , 13.521818 ,
12.604932 , 16.998135 , 30.99062 , 8.554714 , 9.850973 ,
9.87082 , 7.277955 , 5.83263 , 5.397276 , 4.8431597,
7.5538855, 3.3288703, 14.987564 , 7.951935 , 6.340927 ,
11.855715 , 1.8256805, 5.054306 , 7.583917 , 7.169403 ,
7.4230094, 6.0162477, 7.335036 , 3.146756 , 2.879428 ,
7.331243 , 9.373186 , 6.303137 , 11.966533 , 7.999712 ,
4.57536 , 4.0942593, 23.454334 , 4.570996 , 7.025208 ,
3.1406336, 4.4695325, 8.906318 , 28.438923 , 6.1750894,
6.5125046, 3.4833727, 3.2517338, 2.7749972, 2.746849 ,
9.603409 , 9.797531 , 9.178833 , 8.009731 , 13.549246 ]],
dtype=float32)array([[22.70183 , 23.928846 , 23.562433 , 16.289011 , 15.22005 ,
15.044127 , 16.297873 , 17.127993 , 28.735075 , 16.449953 ,
21.467855 , 19.445892 , 21.499855 , 22.8823 , 21.110508 ,
23.649979 , 26.60918 , 30.828 , 25.248333 , 22.17687 ,
21.898846 , 22.394032 , 19.247723 , 20.5953 , 19.776379 ,
18.489138 , 18.343906 , 16.611282 , 17.876183 , 18.076878 ,
20.283577 , 20.365068 , 17.932829 , 17.479145 , 16.480396 ,
16.9544 , 17.29114 , 17.868351 , 20.70608 , 23.331121 ,
21.984573 , 21.473406 , 25.425516 , 20.035864 , 24.10131 ,
22.57912 , 24.053577 , 25.006847 , 25.977291 , 23.044413 ,
23.399036 , 22.522333 , 22.15956 , 26.54244 , 22.261833 ,
18.223854 , 20.073006 , 18.710468 , 23.895777 , 18.726654 ,
17.243109 , 23.96953 , 17.760885 , 16.729784 , 16.708885 ,
14.98461 , 19.943176 , 18.423456 , 20.433556 , 21.729021 ,
21.142508 , 24.806545 , 26.408274 , 16.245544 , 15.434979 ,
20.393446 , 19.199934 , 20.03922 , 18.835415 , 17.204237 ,
18.52296 , 18.065552 , 20.298271 , 17.421255 , 17.687933 ,
19.639505 , 19.606153 , 24.838633 , 18.426218 , 21.44172 ,
22.273306 , 19.690224 , 16.428911 , 21.521833 , 21.86512 ,
20.675076 , 21.54007 , 17.91468 , 17.8993 , 12.195638 ,
...
23.518085 , 20.518326 , 24.155817 , 21.604746 , 13.404511 ,
13.920298 , 14.833247 , 15.817091 , 24.094957 , 17.81413 ,
17.429472 , 16.82248 , 16.594954 , 19.392002 , 17.35203 ,
17.853443 , 16.643085 , 19.361786 , 16.21468 , 19.727438 ,
23.818949 , 22.251196 , 18.389307 , 17.08416 , 28.378626 ,
25.069221 , 23.163519 , 25.550022 , 23.057487 , 24.160658 ,
19.997173 , 22.610897 , 22.570648 , 20.691324 , 24.283613 ,
22.753681 , 25.524794 , 26.249626 , 18.095999 , 22.020086 ,
21.696339 , 26.190704 , 22.211027 , 25.576822 , 24.824036 ,
22.590075 , 21.377657 , 25.288351 , 21.2128 , 19.17415 ,
18.732403 , 18.372541 , 23.47746 , 18.848772 , 19.336374 ,
19.807137 , 19.745995 , 25.764769 , 22.528057 , 24.722696 ,
17.32529 , 20.22306 , 21.694635 , 18.270115 , 17.884037 ,
17.52701 , 16.287584 , 17.014729 , 18.675406 , 17.37952 ,
19.083565 , 17.537128 , 19.86175 , 17.674118 , 21.235544 ,
24.079872 , 14.4767065, 21.373323 , 18.028996 , 22.986351 ,
16.297935 , 18.35883 , 20.795593 , 22.46605 , 20.826462 ,
19.824308 , 18.015593 , 15.6124115, 13.76765 , 14.794092 ,
21.357248 , 22.982603 , 22.121471 , 23.893038 , 20.269114 ]],
dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions: (chain: 1, draw: 750, y_dim_0: 63)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
* y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62
Data variables:
y (chain, draw, y_dim_0) float32 -4.095 -4.083 ... -5.734 -4.017
Attributes:
created_at: 2023-05-29T23:56:45.160203
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62])array([[[-4.094567 , -4.082753 , -4.281054 , ..., -4.11488 ,
-5.4502263, -4.1267333],
[-4.1526117, -4.1339474, -4.336127 , ..., -4.174086 ,
-5.3432612, -4.1656785],
[-4.1596594, -4.117912 , -4.389084 , ..., -4.1904535,
-5.235262 , -4.1275744],
...,
[-4.0878344, -4.0446672, -4.3361053, ..., -4.1206446,
-5.3121624, -4.05793 ],
[-4.146661 , -4.1236835, -4.3355103, ..., -4.1694927,
-5.312559 , -4.149831 ],
[-3.9554722, -3.953834 , -4.1518583, ..., -3.974376 ,
-5.7338276, -4.017195 ]]], dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
dtype='int64', name='draw', length=750))PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62],
dtype='int64', name='y_dim_0'))<xarray.Dataset>
Dimensions: (chain: 1, draw: 750)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749
Data variables:
diverging (chain, draw) bool False False False False ... False False False
Attributes:
created_at: 2023-05-29T23:56:45.029825
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
...
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False]])PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions: (y_dim_0: 63)
Coordinates:
* y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62
Data variables:
y (y_dim_0) int32 102 107 92 101 110 68 119 ... 97 139 72 100 144 112
Attributes:
created_at: 2023-05-29T23:56:45.161050
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62])array([102, 107, 92, 101, 110, 68, 119, 106, 99, 103, 90, 93, 79,
89, 137, 119, 126, 110, 71, 114, 100, 95, 91, 99, 97, 106,
106, 129, 115, 124, 137, 73, 69, 95, 102, 116, 111, 134, 102,
110, 139, 112, 122, 84, 129, 112, 127, 106, 113, 109, 208, 114,
107, 50, 169, 133, 50, 97, 139, 72, 100, 144, 112], dtype=int32)PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62],
dtype='int64', name='y_dim_0'))fig = plots.plot_st_2(
mcmc, smart_group_data.Score.values,
# Reference values the posteriors are compared against: mean 100 and
# sigma 15 (assumed to be the IQ population norms -- confirm), and a
# null effect size of 0.
mean_comp_val=100,
sigma_comp_val=15,
effsize_comp_val=0)
# Pairwise scatter of the posterior draws for mean, sigma and nu, e.g. to
# spot dependencies between the parameters.
fig = numpyro_glm.plot_pairwise_scatter(mcmc, ['mean', 'sigma', 'nu'])
# Fit the robust multi-group model to `iq_data`. Per the printed summary it
# yields one `mean[i]` and `sigma[i]` per group and a single shared `nu`.
mcmc_key = random.PRNGKey(0)
kernel = NUTS(glm_metric.multi_groups_robust)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc.run(
mcmc_key,
jnp.array(iq_data['Score'].values),
# Integer group index per row, taken from the categorical `Group` column.
jnp.array(iq_data['Group'].cat.codes.values),
# Number of distinct groups (categories of `Group`).
len(iq_data['Group'].cat.categories),
)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
mean[0] 99.29 1.76 99.20 96.70 102.33 806.66 1.00
mean[1] 107.16 2.62 107.14 102.90 111.36 910.67 1.00
nu 3.86 1.68 3.45 1.80 5.73 403.33 1.01
sigma[0] 11.30 1.74 11.10 8.37 13.97 655.02 1.00
sigma[1] 17.91 2.72 17.71 13.80 22.56 534.41 1.00
Number of divergences: 0
# Diagnostic plots of the multi-group fit's chains (per-group mean/sigma and shared nu).
numpyro_glm.plot_diagnostic(mcmc, ['mean', 'sigma', 'nu'])
<xarray.Dataset>
Dimensions: (chain: 1, draw: 750, mean_dim_0: 2, sigma_dim_0: 2)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 ... 743 744 745 746 747 748 749
* mean_dim_0 (mean_dim_0) int64 0 1
* sigma_dim_0 (sigma_dim_0) int64 0 1
Data variables:
mean (chain, draw, mean_dim_0) float32 98.21 108.0 ... 94.3 103.5
nu (chain, draw) float32 4.904 4.618 2.711 ... 2.909 2.738 2.962
sigma (chain, draw, sigma_dim_0) float32 11.06 19.12 ... 10.23 17.3
Attributes:
created_at: 2023-05-29T23:56:52.320366
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([0, 1])
array([0, 1])
array([[[ 98.21455 , 107.99677 ],
[ 98.949814, 109.36058 ],
[ 98.70745 , 102.67616 ],
...,
[103.06653 , 106.81406 ],
[ 95.39924 , 107.23023 ],
[ 94.29891 , 103.466194]]], dtype=float32)array([[ 4.90352 , 4.617958 , 2.7109103, 2.9921021, 3.342796 ,
6.3293333, 3.3233557, 4.0153794, 3.9350548, 2.6246867,
7.292013 , 3.497366 , 2.326993 , 2.8841724, 2.2890396,
2.4631345, 4.2260084, 4.054314 , 4.1366997, 4.4986525,
2.7777195, 5.1717806, 2.021302 , 5.061619 , 5.94571 ,
3.14477 , 2.4800138, 2.869721 , 2.4497128, 2.564416 ,
2.9521792, 2.1315536, 2.0498023, 3.8954706, 2.0831642,
2.2922893, 3.3183112, 3.122521 , 3.0120344, 4.1892695,
2.5077238, 2.5695941, 2.2404354, 5.06931 , 3.70294 ,
3.3386614, 2.5039854, 2.3118012, 3.1710594, 2.3974266,
2.5402591, 1.959218 , 3.2319102, 3.2446227, 5.831307 ,
3.9336195, 3.8203804, 2.7046685, 3.9451866, 3.540849 ,
3.7571568, 4.5666227, 3.1252599, 2.6628523, 5.4824295,
3.590702 , 3.329475 , 4.470755 , 3.7153351, 3.0417259,
7.0245333, 2.8254995, 4.068351 , 2.6363196, 6.1589856,
3.6985042, 4.0642147, 2.3869684, 3.4503007, 2.638822 ,
2.5350208, 2.5853083, 2.5492935, 2.7262902, 5.2686357,
4.664855 , 2.334038 , 3.1042252, 2.6422722, 2.2850242,
2.8939147, 2.500326 , 4.2968397, 2.880906 , 3.437034 ,
3.8430223, 6.010608 , 2.493542 , 3.2565322, 3.7440398,
...
4.7892776, 4.238168 , 1.6762153, 3.2976294, 4.928863 ,
2.3080652, 2.6339862, 4.841153 , 2.2997952, 1.9764836,
5.1290674, 2.9647245, 3.793073 , 3.8011038, 4.1641083,
4.1243477, 6.333407 , 2.8298962, 3.2586808, 3.278459 ,
3.8523297, 3.7573292, 3.8873568, 3.3295789, 3.7441294,
3.8092866, 3.505789 , 3.3353996, 3.1752791, 2.874391 ,
3.7779677, 2.5379913, 6.0346184, 3.2346685, 2.308152 ,
3.2602334, 9.1525135, 8.748209 , 4.2054396, 6.5660415,
4.016606 , 7.6482983, 4.4520617, 4.8964925, 3.7809463,
1.9816743, 3.64076 , 5.9388037, 2.9451606, 3.6500442,
4.868866 , 4.1009865, 5.2459726, 2.457327 , 2.44139 ,
7.370051 , 2.4880009, 3.4568632, 3.994024 , 4.073059 ,
5.5238442, 4.672189 , 6.579589 , 8.251197 , 3.1003256,
7.0111647, 2.9936523, 3.945934 , 2.7069554, 7.7330813,
3.0613446, 3.1943877, 5.15743 , 2.1961014, 2.0084407,
4.7910404, 4.3722544, 4.7984753, 4.1131425, 2.9987164,
3.3225734, 2.3282595, 2.9577506, 4.8833666, 5.5551987,
7.0451884, 2.3839269, 2.004416 , 2.6664572, 3.5382771,
2.3100057, 2.0453377, 2.9086428, 2.7375379, 2.9617047]],
dtype=float32)array([[[11.057621 , 19.119974 ],
[12.7135935, 21.039143 ],
[10.781828 , 14.102758 ],
...,
[10.174658 , 16.26035 ],
[11.224569 , 15.996319 ],
[10.232676 , 17.299318 ]]], dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
dtype='int64', name='draw', length=750))PandasIndex(Int64Index([0, 1], dtype='int64', name='mean_dim_0'))
PandasIndex(Int64Index([0, 1], dtype='int64', name='sigma_dim_0'))
<xarray.Dataset>
Dimensions: (chain: 1, draw: 750, y_dim_0: 120)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
* y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 112 113 114 115 116 117 118 119
Data variables:
y (chain, draw, y_dim_0) float32 -3.979 -3.922 ... -3.338 -5.223
Attributes:
created_at: 2023-05-29T23:56:52.493138
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119])array([[[-3.9789479, -3.9219544, -4.3142 , ..., -3.5036287,
-3.5036287, -5.6269355],
[-4.092538 , -4.0267043, -4.4053917, ..., -3.6455019,
-3.6455019, -5.424239 ],
[-3.6572642, -3.718939 , -4.011526 , ..., -3.5696988,
-3.5696988, -5.58177 ],
...,
[-3.8501332, -3.7921975, -4.282725 , ..., -3.8901894,
-3.8901894, -6.1303153],
[-3.8524227, -3.7809741, -4.3153687, ..., -3.457511 ,
-3.457511 , -5.203177 ],
[-3.8573813, -3.8802965, -4.126558 , ..., -3.3382494,
-3.3382494, -5.222832 ]]], dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
dtype='int64', name='draw', length=750))PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
110, 111, 112, 113, 114, 115, 116, 117, 118, 119],
dtype='int64', name='y_dim_0', length=120))<xarray.Dataset>
Dimensions: (chain: 1, draw: 750)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749
Data variables:
diverging (chain, draw) bool False False False False ... False False False
Attributes:
created_at: 2023-05-29T23:56:52.322597
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
...
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False]])PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions: (y_dim_0: 120)
Coordinates:
* y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 112 113 114 115 116 117 118 119
Data variables:
y (y_dim_0) int32 102 107 92 101 110 68 119 ... 82 138 99 93 93 72
Attributes:
created_at: 2023-05-29T23:56:52.493967
arviz_version: 0.12.1
inference_library: numpyro
inference_library_version: 0.9.2array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119])array([102, 107, 92, 101, 110, 68, 119, 106, 99, 103, 90, 93, 79,
89, 137, 119, 126, 110, 71, 114, 100, 95, 91, 99, 97, 106,
106, 129, 115, 124, 137, 73, 69, 95, 102, 116, 111, 134, 102,
110, 139, 112, 122, 84, 129, 112, 127, 106, 113, 109, 208, 114,
107, 50, 169, 133, 50, 97, 139, 72, 100, 144, 112, 109, 98,
106, 101, 100, 111, 117, 104, 106, 89, 84, 88, 94, 78, 108,
102, 95, 99, 90, 116, 97, 107, 102, 91, 94, 95, 86, 108,
115, 108, 88, 102, 102, 120, 112, 100, 105, 105, 88, 82, 111,
96, 92, 109, 91, 92, 123, 61, 59, 105, 184, 82, 138, 99,
93, 93, 72], dtype=int32)PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
...
110, 111, 112, 113, 114, 115, 116, 117, 118, 119],
dtype='int64', name='y_dim_0', length=120))Plot the resulting posteriors.
# Posterior plots for the Placebo group (index 0 along the group dimension).
placebo_scores = iq_data.loc[iq_data['Group'] == 'Placebo', 'Score'].values
fig = plots.plot_st_2(
    mcmc,
    placebo_scores,
    mean_coords={'mean_dim_0': 0},
    sigma_coords={'sigma_dim_0': 0},
    figtitle='Placebo Posteriors',
)
# Posterior plots for the Smart Drug group (index 1 along the group dimension).
smart_drug_scores = iq_data.loc[iq_data['Group'] == 'Smart Drug', 'Score'].values
fig = plots.plot_st_2(
    mcmc,
    smart_drug_scores,
    mean_coords={'mean_dim_0': 1},
    sigma_coords={'sigma_dim_0': 1},
    figtitle='Smart Drug Posteriors',
)
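Next, we plot the posterior distributions of the differences between the Smart Drug and Placebo groups: the difference of means, the difference of scales, and the effect size.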
idata = az.from_numpyro(mcmc)
posteriors = idata.posterior


def _plot_difference(ax, values, rope, title, xlabel):
    """Draw the histogram posterior of *values* on *ax*.

    The plot shows a 95% HDI, the posterior mode, a reference line at 0 and
    the given ROPE interval, then sets the axis title and x label.
    """
    az.plot_posterior(
        values,
        hdi_prob=0.95,
        ref_val=0,
        rope=rope,
        point_estimate='mode',
        kind='hist',
        ax=ax)
    ax.set_title(title)
    ax.set_xlabel(xlabel)


fig, axes = plt.subplots(ncols=3, figsize=(18, 6))
fig.suptitle('Difference between Smart Drug and Placebo')

# Plot mean difference (index 1 = Smart Drug, index 0 = Placebo).
ax = axes[0]
mean_difference = (posteriors['mean'].sel(dict(mean_dim_0=1))
                   - posteriors['mean'].sel(dict(mean_dim_0=0)))
_plot_difference(ax, mean_difference, rope=(-0.5, 0.5),
                 title='Difference of Means',
                 xlabel=r'$\mu[1] - \mu[0]$')

# Plot sigma difference.
ax = axes[1]
sigma_difference = (posteriors['sigma'].sel(dict(sigma_dim_0=1))
                    - posteriors['sigma'].sel(dict(sigma_dim_0=0)))
_plot_difference(ax, sigma_difference, rope=(-0.5, 0.5),
                 title='Difference of Scales',
                 xlabel=r'$\sigma[1] - \sigma[0]$')

# Plot effect size. Following Kruschke (DBDA2, Sec. 16.3), the mean
# difference is divided by the pooled scale sqrt((sigma1^2 + sigma0^2) / 2).
# Omitting the "/ 2" would shrink every effect size by a factor of sqrt(2).
ax = axes[2]
sigmas_squared = (posteriors['sigma'].sel(dict(sigma_dim_0=1))**2
                  + posteriors['sigma'].sel(dict(sigma_dim_0=0))**2)
effect_size = mean_difference / np.sqrt(sigmas_squared / 2)
_plot_difference(ax, effect_size, rope=(-0.1, 0.1),
                 title='Effect Size',
                 xlabel=r'$(\mu[1] - \mu[0]) / \sqrt{(\sigma[1]^2 + \sigma[0]^2) / 2}$')

fig.tight_layout()