%cd ..
%load_ext autoreload
%autoreload 2
/home/runner/work/numpyro-doing-bayesian/numpyro-doing-bayesian
import arviz as az
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
import numpy as np
import numpyro
from numpyro.infer import MCMC, NUTS
import numpyro_glm
import numpyro_glm.metric.models as glm_metric
import numpyro_glm.metric.plots as plots
import pandas as pd
from scipy.stats import norm, t
# Render the graphical structure of the one-group model; a small dummy
# observation vector is enough for numpyro to trace the model.
numpyro.render_model(
    glm_metric.one_group,
    model_args=(jnp.ones(5),),
    render_params=True,
)
# Ground-truth parameters of the synthetic normal data set.
MEAN = 5
STD = 3
N = 100
SEED = 0  # fixed so that rerunning the notebook reproduces the same data

# Draw N samples from Normal(MEAN, STD). A seeded Generator (instead of the
# global np.random state) keeps the data, and hence the posterior summaries
# reported below, reproducible across runs.
rng = np.random.default_rng(SEED)
y = rng.normal(MEAN, STD, N)

# Plot the histogram of the data together with the true normal PDF.
fig, ax = plt.subplots()
ax.hist(y, density=True, label='Histogram of $y$')
xmin, xmax = ax.get_xlim()
p = np.linspace(xmin, xmax, 1000)
ax.plot(p, norm.pdf(p, loc=MEAN, scale=STD), label='Normal PDF')
ax.legend()
fig.tight_layout()
Now we apply the metric model to these data to see whether it recovers the true parameters.
# Fit the one-group (normal) metric model to the synthetic data with NUTS.
# The fixed PRNG key makes the sampler itself deterministic.
kernel = NUTS(glm_metric.one_group)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc_key = random.PRNGKey(0)
observed = jnp.array(y)
mcmc.run(mcmc_key, observed)
mcmc.print_summary()
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      mean      4.96      0.31      4.97      4.53      5.51    517.16      1.00
       std      3.05      0.22      3.03      2.73      3.45    746.19      1.00
Number of divergences: 0
Plot the diagnostics to check whether the MCMC chains are well-behaved.
# Diagnostic plots for the posterior draws of `mean` and `std`, used to check
# that the MCMC chain from the fit above is well-behaved.
numpyro_glm.plot_diagnostic(mcmc, ['mean', 'std'])
<xarray.Dataset>
Dimensions:  (chain: 1, draw: 750)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
Data variables:
    mean     (chain, draw) float32 5.209 4.529 5.347 5.227 ... 4.773 4.991 4.968
    std      (chain, draw) float32 2.917 3.103 2.868 2.874 ... 2.672 3.41 3.339
Attributes:
    created_at:                 2023-05-29T23:56:31.752566
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[5.209106 , 4.528548 , 5.3471146, 5.227207 , 4.2945375, 4.9182925,
        4.7712283, 5.2968564, 5.132386 , 4.844204 , 5.586605 , 4.487925 ,
        5.3550844, 5.4572015, 4.448584 , 4.9535537, 5.02507  , 4.8676105,
        5.031705 , 5.547575 , 5.0685067, 4.779368 , 5.002617 , 5.0431023,
        5.4613724, 4.8222647, 5.0249243, 4.903239 , 5.109538 , 4.791892 ,
        5.222782 , 5.2325306, 4.6903687, 4.693336 , 4.9970613, 4.7783775,
        4.8024116, 5.4343495, 5.04768  , 4.604811 , 4.6258793, 4.529574 ,
        4.5864644, 4.955139 , 4.8493094, 5.0941453, 4.96266  , 4.893078 ,
        5.227385 , 4.9269013, 5.0402994, 5.046525 , 5.125822 , 4.6672425,
        4.203383 , 4.907324 , 4.990067 , 4.1970205, 4.35923  , 4.966628 ,
        4.88882  , 5.092228 , 5.1102777, 4.8794503, 5.1425686, 4.9433007,
        4.9666333, 5.1734595, 4.7041483, 5.017935 , 4.9279866, 5.1911583,
        5.1911583, 5.011348 , 5.0969815, 4.9910264, 4.695913 , 4.8254447,
        4.7390966, 5.0573435, 4.6840773, 5.1652255, 5.024248 , 4.6640615,
        4.5730023, 5.0042996, 4.979213 , 4.7948694, 4.718663 , 4.583088 ,
        4.7011843, 5.0055323, 4.8422976, 4.8485436, 4.965094 , 4.7230763,
        5.069906 , 4.9947624, 4.9111443, 4.2580767, 5.2471714, 4.5457106,
        4.775864 , 5.025947 , 5.061445 , 4.8522716, 5.1862683, 5.0539246,
        5.0416985, 5.168807 , 4.7731266, 4.703271 , 4.8041005, 4.589089 ,
        4.500424 , 4.4929605, 4.4101706, 5.3536725, 5.31576  , 4.583191 ,
...
        4.989824 , 4.7870226, 5.1676106, 4.773499 , 4.433352 , 5.527904 ,
        5.376922 , 5.3109035, 4.531634 , 5.121908 , 5.2329516, 4.5748124,
        4.9590473, 4.995907 , 4.777923 , 4.907775 , 5.110473 , 5.484889 ,
        4.984773 , 4.9303713, 4.9617133, 4.9597883, 4.201158 , 4.108816 ,
        5.5899467, 4.4740453, 5.1457696, 5.0359325, 5.559925 , 5.5884337,
        5.177686 , 5.4082894, 4.67077  , 4.835922 , 5.263118 , 5.2465863,
        5.2224298, 5.4455686, 5.6108904, 4.6170073, 4.7014694, 4.8181043,
        5.2965465, 4.779739 , 4.9903846, 4.6685143, 5.001776 , 5.2450037,
        4.6075706, 5.258145 , 5.197925 , 4.683494 , 5.085534 , 5.0441504,
        5.057175 , 4.7345133, 5.0199018, 5.068657 , 4.5667806, 4.5667806,
        5.1042504, 5.0103126, 5.3162837, 5.049756 , 5.0858436, 4.8877597,
        4.871423 , 4.964435 , 4.952388 , 5.177528 , 4.7585597, 4.937545 ,
        5.2067914, 5.404242 , 5.3792343, 5.4727798, 5.503424 , 3.9227598,
        4.616737 , 5.349469 , 4.94926  , 5.276943 , 4.5606093, 4.350913 ,
        4.8356533, 5.3012824, 4.9218187, 5.049454 , 5.0492573, 5.3923044,
        5.154832 , 5.172335 , 4.881768 , 4.9017844, 4.2971644, 5.148047 ,
        4.9634027, 4.969659 , 5.075037 , 4.83324  , 4.8330393, 5.2102723,
        4.702924 , 4.691301 , 5.379582 , 4.30379  , 5.499221 , 4.7738557,
        5.0708055, 4.6337047, 4.5571523, 4.7732024, 4.9907207, 4.967629 ]],
      dtype=float32)array([[2.916654 , 3.1028943, 2.8679268, 2.8739176, 3.1552985, 3.1833546,
        2.927499 , 2.9361997, 3.2647386, 2.7347465, 3.6316378, 3.186601 ,
        2.9091659, 2.878644 , 3.0915058, 3.2480638, 2.8444803, 2.9026027,
        3.0260732, 3.271102 , 3.2191224, 3.219358 , 2.9903677, 3.2239099,
        3.1981618, 3.1861975, 2.8435102, 3.0665843, 2.7998087, 3.2023396,
        3.0847368, 3.129275 , 2.832983 , 2.7954605, 3.4137855, 2.7664967,
        2.963932 , 3.4414675, 2.8450003, 3.208023 , 3.1765056, 3.1218975,
        3.1018298, 2.775923 , 3.06603  , 2.757947 , 3.0582092, 3.1054397,
        2.8913946, 2.94138  , 3.5591953, 3.0489569, 3.0945673, 3.0281289,
        2.956422 , 3.021427 , 3.1175191, 2.9662404, 2.9738827, 3.0137942,
        3.0707273, 2.9698124, 3.1119697, 2.8752227, 3.112528 , 3.043435 ,
        2.8680277, 3.205892 , 3.4510348, 2.9880588, 2.9658875, 2.6942782,
        2.6942782, 2.6509972, 2.7107518, 3.024229 , 2.7974968, 2.8184597,
        3.330964 , 3.1583602, 2.9045746, 2.8659198, 3.016166 , 3.070869 ,
        3.4217649, 2.9372633, 3.2278461, 2.8106005, 2.9665637, 2.9308615,
        2.7869065, 3.2926528, 2.7854984, 3.6578674, 3.8099601, 2.665423 ,
        3.0543578, 2.9724834, 3.0155618, 3.133643 , 2.9342916, 3.0967343,
        3.5509822, 3.3757555, 3.491986 , 3.4743314, 2.8963757, 3.2922297,
        2.6979878, 3.2163372, 2.8820283, 2.8636553, 3.1171484, 3.261943 ,
        3.330661 , 3.3311145, 3.3987849, 2.7803333, 2.9204037, 3.2823446,
...
        2.9119968, 3.0894525, 3.6251438, 3.1715524, 3.0566335, 2.939902 ,
        2.8716073, 2.9746158, 2.9994235, 3.2255704, 3.2961411, 2.9679601,
        3.4174898, 3.3049793, 2.9214787, 3.075012 , 2.8656034, 3.1087933,
        2.8138204, 3.1475034, 3.1343572, 3.1616294, 3.1369445, 3.1915271,
        2.875889 , 3.1248076, 3.087011 , 2.9718072, 3.2590137, 3.2520576,
        3.185291 , 2.859958 , 2.821579 , 2.781268 , 2.8592463, 2.9062448,
        3.534457 , 2.9518278, 3.3549776, 2.7666702, 2.7543926, 3.254572 ,
        2.9607222, 3.1747704, 3.569678 , 3.118485 , 2.9555767, 2.847376 ,
        2.9998915, 3.1964002, 3.1241221, 3.0805674, 3.368725 , 2.869935 ,
        2.7692757, 2.6986244, 3.442483 , 3.3160274, 2.842026 , 2.842026 ,
        3.2840214, 3.4086637, 2.5889783, 2.8690841, 3.1804018, 3.1136734,
        3.1699562, 2.8367765, 2.8882794, 3.1074224, 2.8305895, 2.9813702,
        3.1293077, 2.8376162, 2.9135425, 3.3758152, 3.674394 , 2.779959 ,
        3.299925 , 3.1810057, 3.2408876, 2.957203 , 3.4878654, 2.9024978,
        3.3169794, 3.256874 , 3.084282 , 2.9575238, 2.7949817, 3.0721962,
        3.303253 , 3.2028015, 2.9452498, 3.3317125, 3.3685954, 2.8577948,
        3.2133477, 2.846444 , 3.1764314, 2.807072 , 2.778171 , 2.8153605,
        3.2772346, 2.8095682, 3.1858308, 2.9680984, 2.7983553, 3.3925478,
        2.7187293, 3.3872933, 3.0770774, 2.6717114, 3.4103987, 3.339314 ]],
      dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
           dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions:  (chain: 1, draw: 750, y_dim_0: 100)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
  * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99
Data variables:
    y        (chain, draw, y_dim_0) float32 -2.681 -1.99 ... -2.229 -2.194
Attributes:
    created_at:                 2023-05-29T23:56:31.843696
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
       36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
       54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
       72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
       90, 91, 92, 93, 94, 95, 96, 97, 98, 99])array([[[-2.6805243, -1.9896044, -2.146286 , ..., -2.2449687,
         -2.0858808, -2.048276 ],
        [-2.4436066, -2.0799394, -2.3294597, ..., -2.4485612,
         -2.2511694, -2.198129 ],
        [-2.746058 , -1.9728756, -2.108558 , ..., -2.2030478,
         -2.051998 , -2.0178077],
        ...,
        [-2.5292451, -1.9190506, -2.2017465, ..., -2.3469207,
         -2.1082296, -2.0462952],
        [-2.5889425, -2.1491568, -2.2932622, ..., -2.3739157,
         -2.2424622, -2.2096944],
        [-2.5803223, -2.1288443, -2.2824044, ..., -2.3674629,
         -2.2286887, -2.1939304]]], dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
           dtype='int64', name='draw', length=750))PandasIndex(Int64Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
            17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
            34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
            51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
            68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
            85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99],
           dtype='int64', name='y_dim_0'))<xarray.Dataset>
Dimensions:    (chain: 1, draw: 750)
Coordinates:
  * chain      (chain) int64 0
  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749
Data variables:
    diverging  (chain, draw) bool False False False False ... False False False
Attributes:
    created_at:                 2023-05-29T23:56:31.764392
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
...
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False]])PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
           dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions:  (y_dim_0: 100)
Coordinates:
  * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99
Data variables:
    y        (y_dim_0) float32 1.78 5.272 6.843 6.101 ... 6.247 7.294 6.49 6.21
Attributes:
    created_at:                 2023-05-29T23:56:31.844516
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
       36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
       54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
       72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
       90, 91, 92, 93, 94, 95, 96, 97, 98, 99])array([ 1.7799622 ,  5.2715006 ,  6.843007  ,  6.101341  ,  1.3784775 ,
        3.5009139 ,  7.0576353 ,  7.4335885 ,  8.715986  ,  1.1754153 ,
        4.9838333 ,  1.7403774 ,  5.123005  ,  1.4687153 ,  6.9200506 ,
        6.5501447 ,  7.0576353 ,  4.713186  ,  4.6818757 ,  6.7534842 ,
        1.7409545 ,  3.115359  , 10.992229  ,  3.3846278 ,  2.3424428 ,
        6.3324614 ,  0.99764264,  1.5388296 , 10.105146  , 11.020432  ,
        4.220744  ,  2.1013582 ,  6.6121535 ,  5.773824  ,  0.9355608 ,
        7.78474   ,  4.040793  ,  2.8486586 ,  2.0916374 ,  4.882524  ,
        5.1691275 ,  3.7109454 ,  6.0718546 ,  4.425353  ,  1.1271195 ,
        3.8772562 , 10.373571  ,  3.238866  ,  7.975573  , 10.97725   ,
        1.273873  ,  1.2262683 ,  2.7989383 ,  4.9672947 ,  0.7470391 ,
        3.253747  ,  3.8842838 , -0.34159485,  2.9436347 , -2.2525995 ,
        9.075224  ,  8.399537  ,  3.4938996 , -1.7792448 ,  1.692837  ,
       10.415877  ,  3.1210582 ,  6.8077483 ,  9.530612  ,  4.725291  ,
        9.805911  ,  5.4888196 ,  2.1056437 ,  6.0538917 ,  3.962473  ,
        5.519883  ,  4.078588  ,  7.26097   ,  4.400875  ,  0.30348757,
        3.18701   ,  4.809211  ,  8.965049  ,  4.0829873 ,  5.228887  ,
        9.951787  ,  5.6544847 ,  0.64701927,  9.465937  ,  7.890076  ,
        4.776916  ,  4.893428  , 10.160401  ,  2.3604243 ,  7.6586046 ,
        7.2983284 ,  6.2470894 ,  7.294434  ,  6.4904785 ,  6.210163  ],
      dtype=float32)PandasIndex(Int64Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
            17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
            34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
            51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
            68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
            85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99],
           dtype='int64', name='y_dim_0'))fig = plots.plot_st(mcmc, y)  # plot the fitted posterior against the observed data `y` (see plots.plot_st)
# Load the two-group IQ data set and keep only the "Smart Drug" group.
iq_data = pd.read_csv('datasets/TwoGroupIQ.csv')
iq_data['Group'] = iq_data['Group'].astype('category')
is_smart = iq_data['Group'] == 'Smart Drug'
smart_group_data = iq_data[is_smart]
# Last expression of the cell, so the notebook displays the summary table.
smart_group_data.describe()
| | Score |
|---|---|
| count | 63.000000 | 
| mean | 107.841270 | 
| std | 25.445201 | 
| min | 50.000000 | 
| 25% | 96.000000 | 
| 50% | 107.000000 | 
| 75% | 119.000000 | 
| max | 208.000000 | 
Next, we apply the one-group model to these data and plot the results.
# Fit the same one-group (normal) model to the "Smart Drug" IQ scores.
kernel = NUTS(glm_metric.one_group)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc_key = random.PRNGKey(0)
smart_scores = jnp.array(smart_group_data['Score'].values)
mcmc.run(mcmc_key, smart_scores)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      mean    107.80      3.24    107.89    102.57    113.19    539.80      1.00
       std     26.08      2.35     25.88     22.73     30.35    642.03      1.00
Number of divergences: 0
# Diagnostic plots for `mean` and `std` of the IQ fit, to check that the
# MCMC chain is well-behaved before interpreting the posterior.
numpyro_glm.plot_diagnostic(mcmc, ['mean', 'std'])
<xarray.Dataset>
Dimensions:  (chain: 1, draw: 750)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
Data variables:
    mean     (chain, draw) float32 110.0 105.3 110.2 109.4 ... 106.8 108.7 108.2
    std      (chain, draw) float32 24.79 26.16 24.03 23.97 ... 22.13 29.33 28.75
Attributes:
    created_at:                 2023-05-29T23:56:35.414985
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[109.97527 , 105.29083 , 110.214516, 109.411125, 100.93915 ,
        108.557785, 109.7286  , 107.717476, 110.33823 , 105.4407  ,
        115.047325, 110.74487 , 109.90371 , 111.92637 , 103.67735 ,
        109.43993 , 106.924866, 105.82306 , 109.33133 , 113.90438 ,
        108.25385 , 105.84257 , 108.72531 , 109.02865 , 113.91567 ,
        113.41177 , 103.41682 , 108.057   , 104.65977 , 110.4926  ,
        113.583885, 102.19425 , 113.01051 , 111.27577 , 106.85726 ,
        102.25837 , 107.078285, 113.12787 , 107.06807 , 104.69756 ,
        104.80107 , 103.58743 , 104.322945, 108.6351  , 106.419586,
        109.62201 , 107.34422 , 107.35962 , 110.8534  , 106.74869 ,
        109.08335 , 109.06311 , 108.5331  , 104.78836 , 116.33449 ,
        114.67767 , 112.495895, 104.07471 , 106.779434, 107.259605,
        108.09704 , 108.10233 , 106.08486 , 109.5852  , 106.001076,
        105.14359 , 109.282585, 110.47549 , 104.854195, 108.91195 ,
        106.613556, 110.542786, 110.542786, 108.08629 , 109.27255 ,
        107.772766, 104.430824, 103.29096 , 110.69535 , 112.892365,
        106.9933  , 108.17603 , 108.825   , 105.24313 , 104.06096 ,
        109.17792 , 107.244606, 108.34675 , 104.49413 , 106.83144 ,
        104.965004, 110.35247 , 105.32731 , 109.82167 , 105.720985,
        107.46112 , 110.59209 , 109.11344 , 106.73796 ,  99.25905 ,
...
        106.25765 , 107.159256, 107.43473 , 114.022736, 107.4383  ,
        109.01808 , 103.7021  , 111.57398 , 110.96743 , 113.67769 ,
        114.12441 , 109.39441 , 104.46808 , 105.55915 , 105.599365,
        111.64049 , 110.684875, 109.979706, 113.008514, 113.94847 ,
        103.412254, 104.3738  , 108.75197 , 111.8318  , 104.77148 ,
        107.812775, 104.56448 , 108.718765, 105.21872 , 105.55679 ,
        104.245285, 111.70955 , 104.24466 , 111.4618  , 110.510574,
        107.00754 , 110.02091 , 104.797165, 106.90024 , 112.80083 ,
        109.23521 , 105.043015, 105.97836 , 103.32183 , 108.15405 ,
        107.78267 , 107.18314 , 106.99867 , 108.12765 , 107.89384 ,
        109.307655, 112.93956 , 112.01034 , 110.8072  , 112.965576,
        112.18863 , 112.52543 , 108.96955 , 102.037285, 107.1725  ,
        104.57445 , 109.20819 , 111.1646  , 102.88633 , 108.055016,
        104.86002 , 111.12093 , 106.58487 , 109.66695 , 109.28379 ,
        113.10869 , 109.12012 , 109.58398 , 106.80441 , 107.53587 ,
        100.24453 , 113.68477 , 106.12561 , 106.85108 , 108.66731 ,
        106.49102 , 106.621   , 110.45593 , 104.96861 , 113.12082 ,
        111.35625 , 102.31912 , 106.25812 , 110.87694 , 109.251045,
        107.50152 , 105.6816  , 106.77369 , 108.67319 , 108.24811 ]],
      dtype=float32)array([[24.786228, 26.155521, 24.029436, 23.965836, 27.113821, 27.424044,
        25.806293, 26.273508, 27.535105, 22.97756 , 32.289055, 24.15046 ,
        24.51352 , 24.034502, 26.246254, 27.881187, 23.841017, 24.366045,
        25.547419, 28.463753, 27.826414, 27.78451 , 25.26547 , 27.71857 ,
        27.465443, 26.184599, 24.367947, 28.29067 , 29.094795, 22.382898,
        23.262806, 28.737497, 22.244843, 25.077694, 29.279648, 27.409065,
        24.064144, 30.407253, 23.98416 , 27.507038, 27.295832, 26.6867  ,
        26.46073 , 22.920782, 25.979725, 22.771954, 25.901804, 26.450947,
        24.121893, 24.683409, 31.45096 , 26.017336, 26.34012 , 25.636229,
        26.706406, 27.788372, 28.212412, 26.795696, 26.954515, 24.471197,
        27.04578 , 24.21856 , 25.055971, 25.211184, 25.342659, 25.057606,
        23.631056, 30.170628, 30.582378, 25.426617, 24.681257, 22.228506,
        22.228506, 21.75323 , 22.333263, 25.517614, 23.16428 , 23.095354,
        29.873755, 27.750286, 24.93584 , 23.626965, 25.444895, 26.3856  ,
        30.052254, 24.859749, 22.761942, 28.185104, 23.586464, 25.712185,
        22.451385, 29.302467, 22.616674, 22.743664, 23.170364, 23.876362,
        27.080822, 25.806063, 24.978598, 26.404415, 27.514174, 26.977419,
        31.662498, 29.875046, 31.077866, 25.236607, 24.985004, 28.296259,
        22.199097, 27.613825, 24.109753, 23.913204, 26.575354, 28.296556,
        29.036882, 29.066261, 29.816078, 23.089298, 24.582733, 26.92522 ,
...
        24.014801, 27.789244, 32.343906, 27.632095, 26.07815 , 24.650728,
        23.914785, 25.040188, 25.317043, 27.565416, 28.408096, 25.107414,
        22.729567, 23.612532, 23.126312, 26.601177, 23.807423, 24.42951 ,
        23.662622, 26.85912 , 26.753407, 27.049782, 25.479877, 26.435293,
        23.96122 , 26.330206, 26.018835, 25.607841, 28.446865, 28.278666,
        27.410109, 28.209177, 30.121782, 30.168917, 22.123705, 23.426945,
        31.40853 , 25.059225, 29.351217, 27.96904 , 25.75235 , 31.656654,
        25.066679, 27.139214, 31.495024, 26.959902, 25.11425 , 26.711693,
        24.692608, 23.715153, 27.323866, 26.20297 , 24.107151, 24.007935,
        24.023855, 26.436314, 25.4301  , 25.759718, 24.9387  , 31.076105,
        22.513962, 25.891182, 29.503733, 23.962814, 27.266918, 26.573658,
        27.158567, 23.648088, 24.148748, 28.249935, 30.239103, 29.819923,
        26.691223, 23.647682, 24.47609 , 29.486397, 24.319372, 30.137695,
        32.97399 , 23.286343, 28.740875, 24.984743, 30.649132, 34.445366,
        27.562742, 27.380943, 25.81166 , 25.387672, 23.473827, 26.251698,
        28.780266, 27.655264, 24.728743, 28.889326, 29.331318, 24.23243 ,
        27.918295, 23.9464  , 27.325426, 23.267876, 22.971722, 23.432446,
        28.220026, 27.862444, 27.050863, 25.492939, 29.31498 , 23.50533 ,
        24.160336, 27.551594, 24.947176, 22.131338, 29.331263, 28.75127 ]],
      dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
           dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions:  (chain: 1, draw: 750, y_dim_0: 63)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
  * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62
Data variables:
    y        (chain, draw, y_dim_0) float32 -4.181 -4.136 ... -5.051 -4.286
Attributes:
    created_at:                 2023-05-29T23:56:35.499221
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
       36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
       54, 55, 56, 57, 58, 59, 60, 61, 62])array([[[-4.180992 , -4.136431 , -4.392193 , ..., -4.210211 ,
         -5.0714164, -4.132563 ],
        [-4.1909137, -4.185134 , -4.312105 , ..., -4.203458 ,
         -5.278142 , -4.2158976],
        [-4.1566496, -4.107166 , -4.3855066, ..., -4.1885657,
         -5.086642 , -4.1009784],
        ...,
        [-4.039196 , -4.0159855, -4.238742 , ..., -4.062772 ,
         -5.430601 , -4.0438166],
        [-4.3234735, -4.2992196, -4.459157 , ..., -4.341311 ,
         -5.022892 , -4.3040247],
        [-4.3012333, -4.2785625, -4.4373045, ..., -4.31877  ,
         -5.050753 , -4.2861347]]], dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
           dtype='int64', name='draw', length=750))PandasIndex(Int64Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
            17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
            34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
            51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62],
           dtype='int64', name='y_dim_0'))<xarray.Dataset>
Dimensions:    (chain: 1, draw: 750)
Coordinates:
  * chain      (chain) int64 0
  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749
Data variables:
    diverging  (chain, draw) bool False False False False ... False False False
Attributes:
    created_at:                 2023-05-29T23:56:35.416760
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
...
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False]])PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
           dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions:  (y_dim_0: 63)
Coordinates:
  * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62
Data variables:
    y        (y_dim_0) int32 102 107 92 101 110 68 119 ... 97 139 72 100 144 112
Attributes:
    created_at:                 2023-05-29T23:56:35.500056
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
       36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
       54, 55, 56, 57, 58, 59, 60, 61, 62])array([102, 107,  92, 101, 110,  68, 119, 106,  99, 103,  90,  93,  79,
        89, 137, 119, 126, 110,  71, 114, 100,  95,  91,  99,  97, 106,
       106, 129, 115, 124, 137,  73,  69,  95, 102, 116, 111, 134, 102,
       110, 139, 112, 122,  84, 129, 112, 127, 106, 113, 109, 208, 114,
       107,  50, 169, 133,  50,  97, 139,  72, 100, 144, 112], dtype=int32)PandasIndex(Int64Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
            17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
            34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
            51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62],
           dtype='int64', name='y_dim_0'))fig = plots.plot_st(
    mcmc, smart_group_data.Score.values,
    mean_comp_val=100,  # presumably the reference (general-population) IQ mean; confirm in plots.plot_st
    std_comp_val=15,  # presumably the reference IQ standard deviation; confirm in plots.plot_st
    effsize_comp_val=0,  # comparison value for the effect size (presumably "no effect")
)
# Render the graphical structure of the robust (Student-t) one-group model;
# a small dummy observation vector is enough for numpyro to trace it.
numpyro.render_model(
    glm_metric.one_group_robust,
    model_args=(jnp.ones(5),),
    render_params=True,
)
# Ground-truth parameters of the synthetic Student-t data set.
MEAN = 5
SIGMA = 3
NORMALITY = 3  # degrees of freedom (nu) of the Student-t distribution
N = 1000
SEED = 0  # fixed so that rerunning the notebook reproduces the same data

# Location-scale Student-t samples: y = MEAN + SIGMA * T with T ~ t(NORMALITY).
# A seeded Generator keeps the data (and the posterior summaries) reproducible.
rng = np.random.default_rng(SEED)
y = rng.standard_t(NORMALITY, size=N) * SIGMA + MEAN

# Plot the histogram with the true Student-t PDF and, for comparison, a normal
# PDF built from the sample mean and standard deviation. The comparison shows
# how poorly a normal describes heavy-tailed data.
fig, ax = plt.subplots()
ax.hist(y, density=True, bins=100, label='Histogram of $y$')
xmin, xmax = ax.get_xlim()
p = np.linspace(xmin, xmax, 1000)
ax.plot(p, t.pdf(p, loc=MEAN, scale=SIGMA, df=NORMALITY), label='Student-$t$ PDF')
ax.plot(p, norm.pdf(p, loc=y.mean(), scale=y.std()),
        label='Normal PDF using\ndata mean and std')
ax.legend()
fig.tight_layout()
Now we fit the robust metric model to our synthetic data to see whether it can recover the original parameters.
# Fit the robust one-group model to the synthetic sample with NUTS:
# 250 warm-up iterations, then 750 retained draws, with a fixed seed.
kernel = NUTS(glm_metric.one_group_robust)
mcmc = MCMC(
    kernel,
    num_warmup=250,
    num_samples=750,
)
mcmc_key = random.PRNGKey(0)
mcmc.run(mcmc_key, jnp.array(y))
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      mean      5.19      0.12      5.19      5.01      5.38    528.65      1.01
        nu      3.02      0.32      3.00      2.48      3.54    389.25      1.00
     sigma      2.97      0.12      2.97      2.78      3.17    351.15      1.00
Number of divergences: 0
# Trace/diagnostic plots for the three robust-model parameters. This is
# kept as the cell's final expression so its return value is displayed.
numpyro_glm.plot_diagnostic(
    mcmc,
    ["mean", "sigma", "nu"],
)
<xarray.Dataset>
Dimensions:  (chain: 1, draw: 750)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
Data variables:
    mean     (chain, draw) float32 5.225 5.124 5.198 5.033 ... 5.201 5.154 5.112
    nu       (chain, draw) float32 3.03 2.811 3.201 2.794 ... 3.153 3.204 3.224
    sigma    (chain, draw) float32 3.008 2.894 3.055 2.892 ... 3.009 3.049 3.047
Attributes:
    created_at:                 2023-05-29T23:56:40.651587
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[5.2252526, 5.1241636, 5.1982026, 5.032679 , 5.0717063, 5.295204 ,
        5.3268604, 5.2793097, 5.081863 , 5.24317  , 5.1877646, 5.13194  ,
        5.097041 , 5.299657 , 5.100597 , 5.3192477, 5.197055 , 5.197055 ,
        5.335744 , 5.362243 , 5.265542 , 5.07826  , 5.033498 , 5.158056 ,
        5.043956 , 5.178067 , 5.2793436, 5.1069593, 5.216472 , 5.0388417,
        5.3190384, 5.118554 , 5.3592353, 5.297184 , 5.054299 , 5.332943 ,
        5.1977186, 5.186516 , 5.0346336, 5.0033903, 5.351208 , 5.346388 ,
        5.2818646, 5.0247946, 5.3952494, 5.3240824, 5.038796 , 5.351271 ,
        5.0122066, 5.106121 , 5.239272 , 5.0579095, 5.178805 , 5.0543776,
        4.985325 , 5.383977 , 5.3296695, 5.2475777, 5.2292995, 5.3917494,
        5.176481 , 5.0985913, 5.0745983, 5.3064804, 5.3190703, 5.3065405,
        5.073111 , 5.127397 , 5.1585197, 5.004515 , 5.0588617, 5.050305 ,
        4.9809313, 5.024628 , 5.301579 , 5.290991 , 5.326011 , 5.0710893,
        5.229229 , 5.140679 , 5.171478 , 5.1049523, 5.4097652, 4.9618893,
        5.0410748, 5.192658 , 5.006359 , 5.201808 , 5.360339 , 5.293335 ,
        5.180471 , 5.2222195, 5.106078 , 5.0473633, 5.445147 , 5.3690085,
        5.356301 , 5.314369 , 5.1836066, 5.214131 , 5.137178 , 5.019918 ,
        5.1467023, 5.120248 , 5.177939 , 5.1632943, 5.2115273, 5.020947 ,
        5.232629 , 5.138412 , 5.414697 , 5.2546616, 5.315323 , 5.3316145,
        5.029754 , 5.271732 , 5.2580037, 4.9675856, 4.854825 , 5.1653094,
...
        5.1836634, 5.1777854, 5.0916348, 5.1517425, 4.94229  , 5.338954 ,
        5.3075275, 4.9510856, 5.355672 , 5.050706 , 5.1464505, 5.1667733,
        5.102213 , 5.142809 , 5.2244234, 5.0835414, 5.2575097, 5.3360806,
        5.255036 , 5.1094394, 5.2511826, 5.383278 , 5.104622 , 5.010276 ,
        5.374145 , 5.1573734, 4.9785376, 5.130616 , 5.086193 , 5.1101737,
        5.196637 , 5.2557135, 5.2129374, 5.216792 , 5.3181276, 5.152473 ,
        5.1690083, 5.2550898, 5.1514096, 5.111661 , 5.12744  , 5.225885 ,
        5.371412 , 5.057382 , 5.313378 , 5.198703 , 5.218926 , 5.1541963,
        5.0360336, 5.1063595, 5.252289 , 5.109053 , 5.1100416, 5.109926 ,
        5.4005275, 5.206244 , 5.423744 , 5.055256 , 5.262973 , 5.369212 ,
        5.2393475, 5.167504 , 5.057998 , 5.1727924, 5.123668 , 5.077471 ,
        5.277476 , 5.152246 , 5.3775187, 5.2343764, 5.320696 , 5.0608945,
        5.297638 , 5.2998834, 5.0541434, 5.2953877, 5.295298 , 5.37311  ,
        5.287508 , 5.177804 , 5.1395016, 4.9199777, 4.9277735, 4.9639325,
        5.01235  , 5.1524143, 5.2529345, 5.2919297, 5.009624 , 5.0876145,
        5.0876145, 5.080562 , 5.2489963, 5.2815833, 5.2815833, 5.132526 ,
        5.18512  , 5.1307178, 5.247394 , 5.1012545, 5.0817285, 5.4392815,
        5.058194 , 5.067048 , 5.154724 , 5.154724 , 5.00218  , 5.0527177,
        5.2726903, 5.21625  , 5.3625455, 5.200683 , 5.1538286, 5.11227  ]],
      dtype=float32)array([[3.0302858, 2.8105295, 3.2014136, 2.7935388, 2.9985032, 2.5012183,
        3.0470521, 2.9146385, 3.8911355, 2.3872564, 2.428279 , 2.8282342,
        3.0597894, 2.883507 , 3.410241 , 3.499199 , 2.655398 , 2.655398 ,
        2.6596022, 3.048019 , 2.955171 , 2.7809849, 2.9156413, 2.7100785,
        2.7034278, 2.8511982, 2.8510652, 2.8014731, 2.724396 , 3.0761416,
        3.1426032, 2.7679362, 3.0328808, 3.1469035, 3.0298078, 2.9144955,
        2.9299479, 2.8942003, 2.8006005, 3.2382188, 3.0903177, 3.0062823,
        2.855648 , 3.045023 , 3.5830653, 3.472515 , 3.0960622, 3.2011704,
        3.1031928, 3.4019434, 3.2195787, 3.0908947, 3.0280414, 3.4605806,
        3.024026 , 3.2612574, 3.1635637, 2.7193754, 3.5790129, 3.5917504,
        2.7907567, 3.0398004, 3.322532 , 3.58795  , 3.3344655, 3.4647434,
        3.0550628, 3.6886106, 2.9566228, 2.8622136, 2.967669 , 3.297165 ,
        3.111523 , 2.8584785, 2.6858768, 3.014183 , 2.8630064, 2.8034956,
        3.193019 , 2.812081 , 2.8540955, 2.8850079, 3.1799855, 3.414366 ,
        3.5967433, 2.7563825, 2.9630618, 3.2624266, 3.1123424, 3.0870428,
        2.725316 , 2.961867 , 3.3157516, 3.561547 , 2.7511356, 2.7754855,
        3.1627483, 2.84977  , 3.4220443, 2.3197072, 2.8218226, 2.4043255,
        2.6734753, 3.0119743, 3.0013568, 2.9657485, 3.18316  , 3.5373964,
        3.4768598, 3.2778487, 3.138904 , 3.0696943, 2.8126042, 2.8059587,
        2.7557168, 2.707277 , 2.7965493, 2.5954356, 2.7712111, 2.5360422,
...
        3.2974737, 2.608031 , 3.3340845, 3.1168242, 3.371426 , 2.484195 ,
        2.5683467, 3.5349627, 3.216056 , 2.9181938, 2.941831 , 3.331728 ,
        3.2328591, 3.0172265, 2.7155166, 2.9464223, 2.7724924, 2.6816273,
        2.5984507, 3.5753417, 3.6234405, 3.1972866, 3.0058103, 2.4069138,
        3.2024572, 3.1628978, 2.934284 , 2.6520753, 2.5028887, 2.4886835,
        2.3843923, 2.5833275, 2.5340605, 2.590158 , 2.5927703, 3.110703 ,
        3.1948829, 3.28651  , 3.098055 , 3.193087 , 3.2967393, 3.358853 ,
        3.7556908, 3.1511366, 2.7770283, 3.003881 , 2.752591 , 3.2983627,
        3.2164586, 3.1108088, 3.0520477, 2.815367 , 2.796043 , 2.3711905,
        2.977821 , 2.7997148, 2.4689636, 4.148687 , 3.509022 , 3.2277522,
        3.6996043, 4.011406 , 3.2496595, 2.5972798, 2.8907168, 3.082364 ,
        3.1959908, 3.0188456, 3.1784215, 2.9925566, 3.1192105, 2.6216607,
        2.9329195, 2.8352685, 2.9127254, 2.5918915, 3.7815177, 3.487691 ,
        3.2226639, 3.4023337, 2.5276363, 2.8847451, 2.830494 , 2.9687147,
        3.082112 , 2.9967947, 3.1591215, 2.67426  , 2.7939107, 2.9403849,
        2.9403849, 2.820949 , 3.3997335, 2.928617 , 2.928617 , 2.49297  ,
        2.991461 , 3.1410506, 3.071447 , 3.406098 , 2.8856087, 2.9846396,
        3.5282505, 2.6596055, 2.9256043, 2.9256043, 2.6848874, 2.655949 ,
        2.6168833, 3.115498 , 3.1219566, 3.1531057, 3.203574 , 3.224131 ]],
      dtype=float32)array([[3.0077553, 2.893796 , 3.0548732, 2.891678 , 2.8712099, 2.8375256,
        2.826606 , 2.9720817, 3.263655 , 2.7567494, 2.7421374, 2.9102821,
        2.9675138, 2.9290013, 3.1114519, 3.1912806, 2.7879858, 2.7879858,
        2.864312 , 2.8857582, 2.9045377, 2.9562974, 2.9119508, 2.9810064,
        3.0166419, 2.9744427, 2.8893104, 2.95488  , 2.932642 , 2.9176862,
        2.9649498, 2.8088963, 2.9177237, 2.9527833, 2.8886452, 2.962949 ,
        2.9381344, 2.8418946, 2.9687228, 3.1053722, 3.0414934, 2.9902358,
        2.8479242, 3.091427 , 3.006934 , 3.0294032, 2.9957023, 3.0173352,
        3.171182 , 3.0915844, 2.998457 , 3.0699952, 3.1018872, 2.930065 ,
        3.0457103, 2.9077508, 3.0073476, 2.8968403, 3.14788  , 3.1162343,
        2.9465456, 3.0800798, 2.8778477, 3.0915356, 3.1263587, 2.9847069,
        3.1584203, 3.0032854, 3.094975 , 2.8946183, 2.9876933, 3.0804462,
        2.974532 , 2.8442578, 2.8904667, 3.00443  , 2.925379 , 2.9279077,
        3.0053437, 2.8640773, 2.9765718, 2.9506266, 3.0112603, 3.2222466,
        3.1745212, 3.0268316, 2.8382347, 3.2115846, 2.96255  , 2.880831 ,
        2.9258776, 2.9168687, 3.1384408, 3.0122447, 2.756393 , 3.0167165,
        2.9385285, 2.7939687, 3.0074434, 2.661758 , 2.695194 , 2.7966495,
        2.890351 , 3.021478 , 2.91477  , 2.9491024, 3.0953932, 3.014392 ,
        3.0871012, 3.0097551, 3.297201 , 3.0935102, 2.9835587, 2.9820387,
        2.8816643, 2.9155366, 2.946278 , 2.7360942, 2.7621183, 2.9129941,
...
        3.1272457, 3.000786 , 2.9547353, 3.1654298, 3.1500812, 2.7814696,
        2.8203695, 3.0340738, 3.0124326, 3.1742008, 3.188828 , 3.1395504,
        2.9985125, 3.0134718, 3.045996 , 2.8193517, 2.902401 , 2.9608424,
        2.8618124, 3.096644 , 2.9716787, 3.0962532, 2.98246  , 2.7797184,
        3.1126304, 2.9470394, 2.9663665, 2.9348457, 2.707904 , 2.7372358,
        2.8489664, 2.7766206, 2.8870816, 2.7899947, 2.8266532, 3.0544372,
        3.0760381, 3.0813096, 3.097099 , 3.1129563, 3.0446541, 3.1500077,
        3.1358485, 3.0074933, 3.1147547, 2.818752 , 2.9542136, 3.0311131,
        2.9141474, 3.1281784, 2.9774513, 2.740467 , 3.0301168, 2.88417  ,
        2.7082813, 2.867558 , 2.7481692, 3.1149592, 3.3530736, 3.4157538,
        3.227895 , 3.1730561, 3.153692 , 2.89388  , 2.861404 , 3.0043843,
        3.116252 , 3.0452118, 2.9471655, 2.9493282, 2.8316453, 2.9597483,
        2.870535 , 2.9105754, 2.8837337, 2.9166882, 3.1200943, 3.0545652,
        3.1874716, 2.8387008, 2.9814684, 3.0079186, 3.067449 , 2.9708407,
        2.917632 , 2.831051 , 2.8786335, 2.9744313, 3.0196028, 2.9318008,
        2.9318008, 3.0168672, 2.8341033, 3.0753546, 3.0753546, 2.8534124,
        2.9582543, 2.9235747, 3.1306293, 3.08414  , 2.9503288, 3.0093284,
        3.1156042, 2.906394 , 2.8710785, 2.8710785, 2.8466375, 2.7183504,
        2.9024906, 2.939236 , 3.0976253, 3.0085454, 3.0485835, 3.0470424]],
      dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
           dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions:  (chain: 1, draw: 750, y_dim_0: 1000)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
  * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999
Data variables:
    y        (chain, draw, y_dim_0) float32 -4.239 -2.245 ... -2.194 -3.763
Attributes:
    created_at:                 2023-05-29T23:56:40.806417
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([ 0, 1, 2, ..., 997, 998, 999])
array([[[-4.23901  , -2.2447944, -3.7616072, ..., -3.9045508,
         -2.1726134, -3.819713 ],
        [-4.248008 , -2.2480104, -3.7681386, ..., -3.9891958,
         -2.1634555, -3.8268566],
        [-4.2147064, -2.2553887, -3.7335567, ..., -3.8977628,
         -2.1846452, -3.7918887],
        ...,
        [-4.237927 , -2.245505 , -3.753389 , ..., -3.9170911,
         -2.172602 , -3.8122478],
        [-4.2017717, -2.262422 , -3.7191315, ..., -3.9176188,
         -2.189119 , -3.7776227],
        [-4.1883435, -2.2696419, -3.70393  , ..., -3.9347727,
         -2.1942322, -3.7625935]]], dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
           dtype='int64', name='draw', length=750))PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            990, 991, 992, 993, 994, 995, 996, 997, 998, 999],
           dtype='int64', name='y_dim_0', length=1000))<xarray.Dataset>
Dimensions:    (chain: 1, draw: 750)
Coordinates:
  * chain      (chain) int64 0
  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749
Data variables:
    diverging  (chain, draw) bool False False False False ... False False False
Attributes:
    created_at:                 2023-05-29T23:56:40.653699
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
...
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False]])PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
           dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions:  (y_dim_0: 1000)
Coordinates:
  * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999
Data variables:
    y        (y_dim_0) float32 -1.97 6.648 -0.6971 5.801 ... 11.52 6.219 -0.8494
Attributes:
    created_at:                 2023-05-29T23:56:40.807396
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([ 0, 1, 2, ..., 997, 998, 999])
array([-1.97044837e+00,  6.64772511e+00, -6.97052896e-01,  5.80093575e+00,
        4.54183054e+00,  1.12101612e+01, -4.44931060e-01,  5.82078266e+00,
        5.82591629e+00,  3.17236757e+00,  8.37355042e+00,  6.53492546e+00,
        4.20800734e+00,  1.01482553e+01,  1.41755648e+01,  1.24651060e+01,
       -2.17906322e+01,  4.92048836e+00,  4.85010052e+00,  1.02006369e+01,
        2.13908386e+00,  2.05172256e-01,  1.30528660e+01, -2.62980032e+00,
        7.23637486e+00,  4.05275726e+00,  3.06575513e+00,  7.08320904e+00,
        5.54312420e+00,  1.49230599e+00,  1.01288261e+01,  8.57089424e+00,
        1.14836349e+01,  5.61097431e+00,  3.13621688e+00,  9.58228970e+00,
        8.92821312e+00,  6.76997042e+00,  1.93540821e+01,  4.11432600e+00,
        1.19134083e+01,  9.42990589e+00,  5.60357380e+00, -1.71314931e+00,
        4.70048040e-01,  3.88558745e+00,  9.75277519e+00,  5.94671679e+00,
        6.14265108e+00, -3.56223047e-01,  1.22409928e+00,  8.90643311e+00,
        1.25334752e+00,  7.24736023e+00,  4.03608942e+00,  6.78641033e+00,
       -2.69396567e+00,  5.06562614e+00, -1.64162469e+00,  2.85413051e+00,
        2.39020967e+00,  4.40503359e+00,  8.13211632e+00,  5.34186506e+00,
        6.01666784e+00,  1.21046562e+01,  2.13109040e+00,  2.75819302e-01,
        2.81802750e+00,  7.41340733e+00,  6.28000736e+00,  6.85987091e+00,
        1.15999956e+01,  4.33733034e+00,  7.85834742e+00,  4.79054356e+00,
       -1.75435328e+00,  2.27244854e+00,  1.29107180e+01,  1.78298130e+01,
...
        7.67104053e+00,  7.03784704e+00,  2.43009663e+00,  6.72932577e+00,
        1.44165316e+01,  5.52429914e+00,  6.58890247e+00,  9.42949772e+00,
        1.42814171e+00,  2.61827564e+00,  5.46275520e+00,  5.43681955e+00,
        4.64565945e+00, -1.88100077e-02,  1.12393579e+01,  6.41488600e+00,
        5.92880297e+00,  7.56008530e+00,  5.77487803e+00,  5.70153856e+00,
        1.06024427e+01,  9.22666454e+00,  1.26327670e+00,  6.25757360e+00,
        6.05027056e+00,  4.46348190e+00,  5.07479429e+00,  3.75031829e+00,
        1.08985357e+01,  8.36709881e+00,  3.84162974e+00,  8.93783290e-03,
        6.21623099e-01,  8.29607487e+00,  6.74988127e+00,  4.13313913e+00,
        3.56779456e-01,  6.27172899e+00,  1.39344180e+00,  8.97291946e+00,
        1.39694996e+01,  7.24138546e+00,  9.02207565e+00,  6.50661051e-01,
       -2.29273844e+00,  6.50683498e+00,  2.56033993e+00,  6.01500273e+00,
        1.85625410e+00,  7.55919504e+00,  5.12628460e+00,  2.19768524e+00,
        6.08607101e+00,  4.16822290e+00, -2.76709747e+00,  4.77808809e+00,
        1.25680656e+01,  1.24284792e+00,  6.16621590e+00,  4.93042040e+00,
        4.13558197e+00,  1.31871214e+01,  2.85820341e+00,  6.72703362e+00,
        7.92491627e+00,  6.70488834e+00,  9.36034393e+00, -7.42135286e-01,
        5.76647472e+00,  2.60789084e+00,  1.30372515e+01,  5.43683338e+00,
        7.20951414e+00,  1.15234461e+01,  6.21906805e+00, -8.49427700e-01],
      dtype=float32)PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            990, 991, 992, 993, 994, 995, 996, 997, 998, 999],
           dtype='int64', name='y_dim_0', length=1000))fig = plots.plot_st_2(mcmc, y)  # plot results of the robust fit (`mcmc`) alongside the synthetic sample `y`
# Fit the same robust one-group model to the smart-group scores, using
# the same NUTS settings and seed as the synthetic-data fit.
kernel = NUTS(glm_metric.one_group_robust)
mcmc = MCMC(
    kernel,
    num_warmup=250,
    num_samples=750,
)
mcmc_key = random.PRNGKey(0)
mcmc.run(mcmc_key, jnp.array(smart_group_data.Score.values))
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      mean    107.38      2.96    107.35    102.54    112.11    499.25      1.00
        nu      9.63     13.31      5.73      1.20     19.70    158.39      1.01
     sigma     19.83      3.53     19.78     14.20     25.66    235.03      1.00
Number of divergences: 0
# Diagnostic plots for the smart-group robust fit. This is kept as the
# cell's final expression so the notebook renders its return value.
numpyro_glm.plot_diagnostic(
    mcmc,
    ["mean", "sigma", "nu"],
)
<xarray.Dataset>
Dimensions:  (chain: 1, draw: 750)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
Data variables:
    mean     (chain, draw) float32 105.6 106.3 108.6 104.4 ... 108.3 106.8 104.6
    nu       (chain, draw) float32 6.38 6.304 6.77 2.442 ... 9.179 8.01 13.55
    sigma    (chain, draw) float32 22.7 23.93 23.56 16.29 ... 22.12 23.89 20.27
Attributes:
    created_at:                 2023-05-29T23:56:45.027781
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[105.55507 , 106.34972 , 108.5634  , 104.371185, 103.78115 ,
        108.39524 , 104.8494  , 105.39043 , 110.073586, 107.43894 ,
        108.44629 , 106.93347 , 102.90011 , 106.07971 , 106.80335 ,
        110.15082 , 109.17264 , 107.16441 , 111.50457 , 113.58138 ,
        102.10726 , 108.93787 , 109.68986 , 110.641335, 105.38782 ,
        108.96525 , 109.787346, 106.19191 , 107.32411 , 103.05021 ,
        109.8304  , 106.86794 , 110.42675 , 104.287094, 103.54164 ,
        106.96897 , 107.24161 , 103.478745, 102.1156  , 100.63876 ,
        108.324066, 108.720985, 108.63    , 111.166565, 109.553505,
        108.90091 , 108.67291 , 105.97729 , 110.96842 , 105.5543  ,
        104.999115, 106.73197 , 107.25278 , 105.5463  , 108.59721 ,
        112.095764, 111.39011 , 110.41567 , 101.798965, 107.15456 ,
        110.26754 , 108.96618 , 105.98647 , 105.55314 , 106.5947  ,
        109.240746, 105.711555, 105.50618 , 106.02377 , 111.31043 ,
        103.732605, 102.67975 , 107.562675, 106.68475 , 106.863625,
        109.01302 , 104.004974, 107.06012 , 107.67128 , 106.929245,
        107.940384, 106.64516 , 112.43416 , 113.13488 , 112.6087  ,
        108.9748  , 104.34741 , 105.08496 , 111.78247 , 105.56174 ,
        106.40983 , 108.041275, 108.06947 , 105.43695 , 108.262535,
        111.17323 , 104.075134, 104.57462 , 104.88279 , 105.00093 ,
...
        103.982544, 106.368774, 111.052895, 109.16271 , 110.65621 ,
        110.26432 , 107.34818 , 112.2538  , 104.68615 , 103.45305 ,
        107.17469 , 107.23535 , 108.20257 , 107.94324 , 107.99706 ,
        105.0354  , 108.69497 , 107.565155, 105.54674 , 108.33376 ,
        106.40116 , 106.42444 , 105.807465, 102.839745, 109.73883 ,
        111.335785, 109.93152 , 110.02944 , 109.36531 , 111.29636 ,
        110.44721 , 104.324524, 109.9587  , 112.545685, 111.32425 ,
        113.93877 , 111.03316 , 106.69128 , 105.634125, 107.761185,
        103.34582 , 104.21555 , 104.75715 , 105.540276, 106.058044,
        110.38669 , 110.98076 , 104.838936, 103.348526, 111.428215,
        111.38477 , 111.552666, 107.4357  , 105.75314 , 106.130035,
        105.13054 , 108.275925, 113.34639 , 103.81799 , 103.676956,
        105.02578 , 106.404274, 101.28366 , 105.969345, 105.57954 ,
        105.5759  , 107.31593 , 108.7727  , 109.47214 , 105.593376,
        107.311165, 104.327995, 104.20603 , 107.11131 , 109.11189 ,
        109.00499 , 107.64095 , 106.795364, 107.13538 , 110.425575,
        107.195335, 102.86533 , 111.92812 , 102.704185, 102.664154,
        104.04894 , 109.90849 , 104.54301 , 104.901215, 106.85347 ,
        105.591606, 108.840034, 108.32749 , 106.838234, 104.625496]],
      dtype=float32)array([[  6.380366 ,   6.3039827,   6.770382 ,   2.4415083,   4.205345 ,
          2.3898754,   2.4554567,   4.751895 ,  30.20579  ,   6.2195134,
          3.5312512,   4.76675  ,   5.0081363,   3.1247911,  21.360893 ,
         23.807055 ,  10.058445 ,   7.779872 ,  11.477498 ,  21.622583 ,
         36.68576  ,   5.816049 ,   7.4071264,   4.7910275,   3.411479 ,
          5.6175413,   3.6065092,   3.8919663,   2.406946 ,   6.614321 ,
         17.233143 ,   2.8523195,   9.354761 ,   5.276261 ,   4.406109 ,
          3.6853158,   3.4503999,   6.083024 ,   3.162562 ,   7.787067 ,
         18.096825 ,   6.6824446,  11.993019 ,   4.9092307,  12.329353 ,
         10.28212  ,  12.334496 ,  21.102882 ,  68.73028  ,  28.93483  ,
         24.624996 ,  22.925674 ,  18.960793 ,  17.155924 ,   4.483561 ,
          9.103407 ,   8.090643 ,   3.4731753,  21.862223 ,   3.5392842,
          6.277596 ,   4.3358054,  10.864571 ,   1.2639813,   1.8327861,
          5.8132896,   2.6647193,  10.151241 ,   4.558977 ,  11.674158 ,
         13.655367 ,  16.772926 ,  25.298958 ,   1.9189749,   5.538587 ,
          3.6070976,   9.118575 ,   4.755516 ,   4.6680737,   3.8454957,
          3.030088 ,   3.3910158,   6.3802996,   3.3749137,   4.504877 ,
         10.809336 ,   6.2159004,  15.667531 ,   4.7564974,   8.262956 ,
          9.560621 ,   9.199344 ,   2.9713812,   3.7156253,   5.5498013,
         11.150637 ,   5.7294626,   4.002439 ,   3.909546 ,   2.221348 ,
...
         10.571391 ,  13.065151 ,   6.8497195,   3.7558386,   3.5340605,
          3.298055 ,   2.0803087,   4.189632 ,   7.7399063,   3.6138294,
          3.9523766,   6.6232204,   6.2300615,   2.3853402,   2.850403 ,
          5.229487 ,   3.8615527,   3.9326599,   3.8529232,   8.671819 ,
          6.301786 ,   7.7922463,   5.2313375,   2.3083024,  65.67436  ,
        100.1539   ,  48.767944 ,  33.916534 ,  32.376404 ,   9.096143 ,
          9.358721 ,   7.4668074,  17.171707 ,  28.13654  , 131.20631  ,
         99.750786 ,  82.21945  , 150.44986  ,   4.315606 ,  15.028996 ,
         19.251741 ,  16.819944 ,  54.600857 ,  10.102056 ,  13.521818 ,
         12.604932 ,  16.998135 ,  30.99062  ,   8.554714 ,   9.850973 ,
          9.87082  ,   7.277955 ,   5.83263  ,   5.397276 ,   4.8431597,
          7.5538855,   3.3288703,  14.987564 ,   7.951935 ,   6.340927 ,
         11.855715 ,   1.8256805,   5.054306 ,   7.583917 ,   7.169403 ,
          7.4230094,   6.0162477,   7.335036 ,   3.146756 ,   2.879428 ,
          7.331243 ,   9.373186 ,   6.303137 ,  11.966533 ,   7.999712 ,
          4.57536  ,   4.0942593,  23.454334 ,   4.570996 ,   7.025208 ,
          3.1406336,   4.4695325,   8.906318 ,  28.438923 ,   6.1750894,
          6.5125046,   3.4833727,   3.2517338,   2.7749972,   2.746849 ,
          9.603409 ,   9.797531 ,   9.178833 ,   8.009731 ,  13.549246 ]],
      dtype=float32)array([[22.70183  , 23.928846 , 23.562433 , 16.289011 , 15.22005  ,
        15.044127 , 16.297873 , 17.127993 , 28.735075 , 16.449953 ,
        21.467855 , 19.445892 , 21.499855 , 22.8823   , 21.110508 ,
        23.649979 , 26.60918  , 30.828    , 25.248333 , 22.17687  ,
        21.898846 , 22.394032 , 19.247723 , 20.5953   , 19.776379 ,
        18.489138 , 18.343906 , 16.611282 , 17.876183 , 18.076878 ,
        20.283577 , 20.365068 , 17.932829 , 17.479145 , 16.480396 ,
        16.9544   , 17.29114  , 17.868351 , 20.70608  , 23.331121 ,
        21.984573 , 21.473406 , 25.425516 , 20.035864 , 24.10131  ,
        22.57912  , 24.053577 , 25.006847 , 25.977291 , 23.044413 ,
        23.399036 , 22.522333 , 22.15956  , 26.54244  , 22.261833 ,
        18.223854 , 20.073006 , 18.710468 , 23.895777 , 18.726654 ,
        17.243109 , 23.96953  , 17.760885 , 16.729784 , 16.708885 ,
        14.98461  , 19.943176 , 18.423456 , 20.433556 , 21.729021 ,
        21.142508 , 24.806545 , 26.408274 , 16.245544 , 15.434979 ,
        20.393446 , 19.199934 , 20.03922  , 18.835415 , 17.204237 ,
        18.52296  , 18.065552 , 20.298271 , 17.421255 , 17.687933 ,
        19.639505 , 19.606153 , 24.838633 , 18.426218 , 21.44172  ,
        22.273306 , 19.690224 , 16.428911 , 21.521833 , 21.86512  ,
        20.675076 , 21.54007  , 17.91468  , 17.8993   , 12.195638 ,
...
        23.518085 , 20.518326 , 24.155817 , 21.604746 , 13.404511 ,
        13.920298 , 14.833247 , 15.817091 , 24.094957 , 17.81413  ,
        17.429472 , 16.82248  , 16.594954 , 19.392002 , 17.35203  ,
        17.853443 , 16.643085 , 19.361786 , 16.21468  , 19.727438 ,
        23.818949 , 22.251196 , 18.389307 , 17.08416  , 28.378626 ,
        25.069221 , 23.163519 , 25.550022 , 23.057487 , 24.160658 ,
        19.997173 , 22.610897 , 22.570648 , 20.691324 , 24.283613 ,
        22.753681 , 25.524794 , 26.249626 , 18.095999 , 22.020086 ,
        21.696339 , 26.190704 , 22.211027 , 25.576822 , 24.824036 ,
        22.590075 , 21.377657 , 25.288351 , 21.2128   , 19.17415  ,
        18.732403 , 18.372541 , 23.47746  , 18.848772 , 19.336374 ,
        19.807137 , 19.745995 , 25.764769 , 22.528057 , 24.722696 ,
        17.32529  , 20.22306  , 21.694635 , 18.270115 , 17.884037 ,
        17.52701  , 16.287584 , 17.014729 , 18.675406 , 17.37952  ,
        19.083565 , 17.537128 , 19.86175  , 17.674118 , 21.235544 ,
        24.079872 , 14.4767065, 21.373323 , 18.028996 , 22.986351 ,
        16.297935 , 18.35883  , 20.795593 , 22.46605  , 20.826462 ,
        19.824308 , 18.015593 , 15.6124115, 13.76765  , 14.794092 ,
        21.357248 , 22.982603 , 22.121471 , 23.893038 , 20.269114 ]],
      dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
           dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions:  (chain: 1, draw: 750, y_dim_0: 63)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
  * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62
Data variables:
    y        (chain, draw, y_dim_0) float32 -4.095 -4.083 ... -5.734 -4.017
Attributes:
    created_at:                 2023-05-29T23:56:45.160203
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
       36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
       54, 55, 56, 57, 58, 59, 60, 61, 62])array([[[-4.094567 , -4.082753 , -4.281054 , ..., -4.11488  ,
         -5.4502263, -4.1267333],
        [-4.1526117, -4.1339474, -4.336127 , ..., -4.174086 ,
         -5.3432612, -4.1656785],
        [-4.1596594, -4.117912 , -4.389084 , ..., -4.1904535,
         -5.235262 , -4.1275744],
        ...,
        [-4.0878344, -4.0446672, -4.3361053, ..., -4.1206446,
         -5.3121624, -4.05793  ],
        [-4.146661 , -4.1236835, -4.3355103, ..., -4.1694927,
         -5.312559 , -4.149831 ],
        [-3.9554722, -3.953834 , -4.1518583, ..., -3.974376 ,
         -5.7338276, -4.017195 ]]], dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
           dtype='int64', name='draw', length=750))PandasIndex(Int64Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
            17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
            34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
            51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62],
           dtype='int64', name='y_dim_0'))<xarray.Dataset>
Dimensions:    (chain: 1, draw: 750)
Coordinates:
  * chain      (chain) int64 0
  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749
Data variables:
    diverging  (chain, draw) bool False False False False ... False False False
Attributes:
    created_at:                 2023-05-29T23:56:45.029825
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
...
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False]])PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
           dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions:  (y_dim_0: 63)
Coordinates:
  * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62
Data variables:
    y        (y_dim_0) int32 102 107 92 101 110 68 119 ... 97 139 72 100 144 112
Attributes:
    created_at:                 2023-05-29T23:56:45.161050
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
       36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
       54, 55, 56, 57, 58, 59, 60, 61, 62])array([102, 107,  92, 101, 110,  68, 119, 106,  99, 103,  90,  93,  79,
        89, 137, 119, 126, 110,  71, 114, 100,  95,  91,  99,  97, 106,
       106, 129, 115, 124, 137,  73,  69,  95, 102, 116, 111, 134, 102,
       110, 139, 112, 122,  84, 129, 112, 127, 106, 113, 109, 208, 114,
       107,  50, 169, 133,  50,  97, 139,  72, 100, 144, 112], dtype=int32)PandasIndex(Int64Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
            17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
            34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
            51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62],
            dtype='int64', name='y_dim_0'))fig = plots.plot_st_2(
    mcmc, smart_group_data.Score.values,
    # NOTE(review): the *_comp_val arguments look like comparison/reference
    # values for the mean, sigma and effect-size posteriors (100 and 15 match
    # the conventional IQ scale) — confirm against plots.plot_st_2.
    mean_comp_val=100,
    sigma_comp_val=15,
    effsize_comp_val=0)
# Pairwise scatter plots of the joint posterior draws of the current fit.
scatter_vars = ['mean', 'sigma', 'nu']
fig = numpyro_glm.plot_pairwise_scatter(mcmc, scatter_vars)
# Fit the robust multi-group metric model to the IQ scores. Groups are passed
# as the integer codes of the categorical 'Group' column together with the
# number of categories.
group_cat = iq_data['Group'].cat
iq_scores = jnp.array(iq_data['Score'].values)
iq_group_codes = jnp.array(group_cat.codes.values)
n_iq_groups = len(group_cat.categories)

mcmc_key = random.PRNGKey(0)
kernel = NUTS(glm_metric.multi_groups_robust)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc.run(mcmc_key, iq_scores, iq_group_codes, n_iq_groups)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
   mean[0]     99.29      1.76     99.20     96.70    102.33    806.66      1.00
   mean[1]    107.16      2.62    107.14    102.90    111.36    910.67      1.00
        nu      3.86      1.68      3.45      1.80      5.73    403.33      1.01
  sigma[0]     11.30      1.74     11.10      8.37     13.97    655.02      1.00
  sigma[1]     17.91      2.72     17.71     13.80     22.56    534.41      1.00
Number of divergences: 0
# Diagnostic plots for the parameters of the multi-group robust fit.
diagnostic_vars = ['mean', 'sigma', 'nu']
numpyro_glm.plot_diagnostic(mcmc, diagnostic_vars)
<xarray.Dataset>
Dimensions:      (chain: 1, draw: 750, mean_dim_0: 2, sigma_dim_0: 2)
Coordinates:
  * chain        (chain) int64 0
  * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 743 744 745 746 747 748 749
  * mean_dim_0   (mean_dim_0) int64 0 1
  * sigma_dim_0  (sigma_dim_0) int64 0 1
Data variables:
    mean         (chain, draw, mean_dim_0) float32 98.21 108.0 ... 94.3 103.5
    nu           (chain, draw) float32 4.904 4.618 2.711 ... 2.909 2.738 2.962
    sigma        (chain, draw, sigma_dim_0) float32 11.06 19.12 ... 10.23 17.3
Attributes:
    created_at:                 2023-05-29T23:56:52.320366
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([0, 1])
array([0, 1])
array([[[ 98.21455 , 107.99677 ],
        [ 98.949814, 109.36058 ],
        [ 98.70745 , 102.67616 ],
        ...,
        [103.06653 , 106.81406 ],
        [ 95.39924 , 107.23023 ],
        [ 94.29891 , 103.466194]]], dtype=float32)array([[ 4.90352  ,  4.617958 ,  2.7109103,  2.9921021,  3.342796 ,
         6.3293333,  3.3233557,  4.0153794,  3.9350548,  2.6246867,
         7.292013 ,  3.497366 ,  2.326993 ,  2.8841724,  2.2890396,
         2.4631345,  4.2260084,  4.054314 ,  4.1366997,  4.4986525,
         2.7777195,  5.1717806,  2.021302 ,  5.061619 ,  5.94571  ,
         3.14477  ,  2.4800138,  2.869721 ,  2.4497128,  2.564416 ,
         2.9521792,  2.1315536,  2.0498023,  3.8954706,  2.0831642,
         2.2922893,  3.3183112,  3.122521 ,  3.0120344,  4.1892695,
         2.5077238,  2.5695941,  2.2404354,  5.06931  ,  3.70294  ,
         3.3386614,  2.5039854,  2.3118012,  3.1710594,  2.3974266,
         2.5402591,  1.959218 ,  3.2319102,  3.2446227,  5.831307 ,
         3.9336195,  3.8203804,  2.7046685,  3.9451866,  3.540849 ,
         3.7571568,  4.5666227,  3.1252599,  2.6628523,  5.4824295,
         3.590702 ,  3.329475 ,  4.470755 ,  3.7153351,  3.0417259,
         7.0245333,  2.8254995,  4.068351 ,  2.6363196,  6.1589856,
         3.6985042,  4.0642147,  2.3869684,  3.4503007,  2.638822 ,
         2.5350208,  2.5853083,  2.5492935,  2.7262902,  5.2686357,
         4.664855 ,  2.334038 ,  3.1042252,  2.6422722,  2.2850242,
         2.8939147,  2.500326 ,  4.2968397,  2.880906 ,  3.437034 ,
         3.8430223,  6.010608 ,  2.493542 ,  3.2565322,  3.7440398,
...
         4.7892776,  4.238168 ,  1.6762153,  3.2976294,  4.928863 ,
         2.3080652,  2.6339862,  4.841153 ,  2.2997952,  1.9764836,
         5.1290674,  2.9647245,  3.793073 ,  3.8011038,  4.1641083,
         4.1243477,  6.333407 ,  2.8298962,  3.2586808,  3.278459 ,
         3.8523297,  3.7573292,  3.8873568,  3.3295789,  3.7441294,
         3.8092866,  3.505789 ,  3.3353996,  3.1752791,  2.874391 ,
         3.7779677,  2.5379913,  6.0346184,  3.2346685,  2.308152 ,
         3.2602334,  9.1525135,  8.748209 ,  4.2054396,  6.5660415,
         4.016606 ,  7.6482983,  4.4520617,  4.8964925,  3.7809463,
         1.9816743,  3.64076  ,  5.9388037,  2.9451606,  3.6500442,
         4.868866 ,  4.1009865,  5.2459726,  2.457327 ,  2.44139  ,
         7.370051 ,  2.4880009,  3.4568632,  3.994024 ,  4.073059 ,
         5.5238442,  4.672189 ,  6.579589 ,  8.251197 ,  3.1003256,
         7.0111647,  2.9936523,  3.945934 ,  2.7069554,  7.7330813,
         3.0613446,  3.1943877,  5.15743  ,  2.1961014,  2.0084407,
         4.7910404,  4.3722544,  4.7984753,  4.1131425,  2.9987164,
         3.3225734,  2.3282595,  2.9577506,  4.8833666,  5.5551987,
         7.0451884,  2.3839269,  2.004416 ,  2.6664572,  3.5382771,
         2.3100057,  2.0453377,  2.9086428,  2.7375379,  2.9617047]],
      dtype=float32)array([[[11.057621 , 19.119974 ],
        [12.7135935, 21.039143 ],
        [10.781828 , 14.102758 ],
        ...,
        [10.174658 , 16.26035  ],
        [11.224569 , 15.996319 ],
        [10.232676 , 17.299318 ]]], dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
           dtype='int64', name='draw', length=750))PandasIndex(Int64Index([0, 1], dtype='int64', name='mean_dim_0'))
PandasIndex(Int64Index([0, 1], dtype='int64', name='sigma_dim_0'))
<xarray.Dataset>
Dimensions:  (chain: 1, draw: 750, y_dim_0: 120)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
  * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 112 113 114 115 116 117 118 119
Data variables:
    y        (chain, draw, y_dim_0) float32 -3.979 -3.922 ... -3.338 -5.223
Attributes:
    created_at:                 2023-05-29T23:56:52.493138
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
        28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
        42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
        56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
        70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
        84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
        98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
       112, 113, 114, 115, 116, 117, 118, 119])array([[[-3.9789479, -3.9219544, -4.3142   , ..., -3.5036287,
         -3.5036287, -5.6269355],
        [-4.092538 , -4.0267043, -4.4053917, ..., -3.6455019,
         -3.6455019, -5.424239 ],
        [-3.6572642, -3.718939 , -4.011526 , ..., -3.5696988,
         -3.5696988, -5.58177  ],
        ...,
        [-3.8501332, -3.7921975, -4.282725 , ..., -3.8901894,
         -3.8901894, -6.1303153],
        [-3.8524227, -3.7809741, -4.3153687, ..., -3.457511 ,
         -3.457511 , -5.203177 ],
        [-3.8573813, -3.8802965, -4.126558 , ..., -3.3382494,
         -3.3382494, -5.222832 ]]], dtype=float32)PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
           dtype='int64', name='draw', length=750))PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            110, 111, 112, 113, 114, 115, 116, 117, 118, 119],
           dtype='int64', name='y_dim_0', length=120))<xarray.Dataset>
Dimensions:    (chain: 1, draw: 750)
Coordinates:
  * chain      (chain) int64 0
  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749
Data variables:
    diverging  (chain, draw) bool False False False False ... False False False
Attributes:
    created_at:                 2023-05-29T23:56:52.322597
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
...
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False]])PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            740, 741, 742, 743, 744, 745, 746, 747, 748, 749],
           dtype='int64', name='draw', length=750))<xarray.Dataset>
Dimensions:  (y_dim_0: 120)
Coordinates:
  * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 112 113 114 115 116 117 118 119
Data variables:
    y        (y_dim_0) int32 102 107 92 101 110 68 119 ... 82 138 99 93 93 72
Attributes:
    created_at:                 2023-05-29T23:56:52.493967
    arviz_version:              0.12.1
    inference_library:          numpyro
    inference_library_version:  0.9.2array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
        28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
        42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
        56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
        70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
        84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
        98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
       112, 113, 114, 115, 116, 117, 118, 119])array([102, 107,  92, 101, 110,  68, 119, 106,  99, 103,  90,  93,  79,
        89, 137, 119, 126, 110,  71, 114, 100,  95,  91,  99,  97, 106,
       106, 129, 115, 124, 137,  73,  69,  95, 102, 116, 111, 134, 102,
       110, 139, 112, 122,  84, 129, 112, 127, 106, 113, 109, 208, 114,
       107,  50, 169, 133,  50,  97, 139,  72, 100, 144, 112, 109,  98,
       106, 101, 100, 111, 117, 104, 106,  89,  84,  88,  94,  78, 108,
       102,  95,  99,  90, 116,  97, 107, 102,  91,  94,  95,  86, 108,
       115, 108,  88, 102, 102, 120, 112, 100, 105, 105,  88,  82, 111,
        96,  92, 109,  91,  92, 123,  61,  59, 105, 184,  82, 138,  99,
        93,  93,  72], dtype=int32)PandasIndex(Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            110, 111, 112, 113, 114, 115, 116, 117, 118, 119],
           dtype='int64', name='y_dim_0', length=120))Plot the resulting posteriors.
# Plot each group's posteriors next to that group's observed scores.
# The model was fitted on the categorical codes of iq_data['Group'] (and the
# number of categories), so parameter index i is taken to belong to category
# i. Deriving the index from the categories, rather than hard-coding 0/1,
# keeps each plot's data and parameter index in sync if the category order
# changes.
for group_idx, group_name in enumerate(iq_data['Group'].cat.categories):
    fig = plots.plot_st_2(
        mcmc,
        iq_data[iq_data['Group'] == group_name]['Score'].values,
        mean_coords=dict(mean_dim_0=group_idx),
        sigma_coords=dict(sigma_dim_0=group_idx),
        figtitle=f'{group_name} Posteriors')
Next, we plot the posterior differences between the two groups: the difference of means, the difference of scales, and the effect size.
idata = az.from_numpyro(mcmc)
posteriors = idata.posterior

# Per-group posterior draws: index 1 is Smart Drug and index 0 is Placebo, so
# every contrast below is "Smart Drug minus Placebo" (matching the suptitle).
mean_difference = (posteriors['mean'].sel(mean_dim_0=1)
                   - posteriors['mean'].sel(mean_dim_0=0))
sigma_difference = (posteriors['sigma'].sel(sigma_dim_0=1)
                    - posteriors['sigma'].sel(sigma_dim_0=0))
sigmas_squared = (posteriors['sigma'].sel(sigma_dim_0=1)**2
                  + posteriors['sigma'].sel(sigma_dim_0=0)**2)
# Kruschke's two-group effect size standardizes the mean difference by the
# root *mean* square of the two scales, sqrt((sigma[1]^2 + sigma[0]^2) / 2).
# Omitting the "/ 2" would understate the effect size by a factor of sqrt(2).
effect_size = mean_difference / np.sqrt(sigmas_squared / 2)

fig, axes = plt.subplots(ncols=3, figsize=(18, 6))
fig.suptitle('Difference between Smart Drug and Placebo')

# One panel per contrast: (posterior draws, ROPE, title, x-axis label).
# Labels are raw strings so the LaTeX backslashes are not parsed as (invalid)
# Python escape sequences.
panels = [
    (mean_difference, (-0.5, 0.5), 'Difference of Means',
     r'$\mu[1] - \mu[0]$'),
    (sigma_difference, (-0.5, 0.5), 'Difference of Scales',
     r'$\sigma[1] - \sigma[0]$'),
    (effect_size, (-0.1, 0.1), 'Effect Size',
     r'$(\mu[1] - \mu[0]) / \sqrt{(\sigma[1]^2 + \sigma[0]^2) / 2}$'),
]
for ax, (values, rope, title, xlabel) in zip(axes, panels):
    # 95% HDI histogram with the mode as point estimate, a reference value of
    # 0 (no difference) and the panel's region of practical equivalence.
    az.plot_posterior(
        values,
        hdi_prob=0.95,
        ref_val=0,
        rope=rope,
        point_estimate='mode',
        kind='hist',
        ax=ax)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
fig.tight_layout()