# Move up one directory (the repository root, per the printed path) so that
# relative paths used later, such as 'datasets/TwoGroupIQ.csv', resolve.
%cd ..
# Enable IPython's autoreload extension so edits to local modules
# (e.g. numpyro_glm) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
/home/runner/work/numpyro-doing-bayesian/numpyro-doing-bayesian
import arviz as az
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
import numpy as np
import numpyro
from numpyro.infer import MCMC, NUTS
import numpyro_glm
import numpyro_glm.metric.models as glm_metric
import numpyro_glm.metric.plots as plots
import pandas as pd
from scipy.stats import norm, t
# Draw the graphical structure of the one-group model. jnp.ones(5) is only a
# dummy observation vector used to trace the model, not real data.
numpyro.render_model(
    glm_metric.one_group,
    model_args=(jnp.ones(5),),
    render_params=True,
)
# Synthetic data from a normal distribution with known parameters, used to
# check whether the one-group model can recover them.
MEAN = 5
STD = 3
N = 100
# Seed the generator so the data (and therefore the fitted summary) is
# reproducible; the MCMC run already uses a fixed PRNGKey.
rng = np.random.default_rng(0)
y = rng.normal(MEAN, STD, N)

# Plot the histogram of the data together with the true normal PDF.
fig, ax = plt.subplots()
ax.hist(y, density=True, label='Histogram of $y$')
xmin, xmax = ax.get_xlim()
p = np.linspace(xmin, xmax, 1000)
ax.plot(p, norm.pdf(p, loc=MEAN, scale=STD), label='Normal PDF')
ax.legend()
fig.tight_layout()
Now we apply the one-group metric model to this data to see whether it recovers the true parameters (`MEAN` and `STD`).
# Fit the one-group model to the synthetic data with NUTS:
# 250 warm-up iterations, then 750 retained draws.
kernel = NUTS(glm_metric.one_group)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc_key = random.PRNGKey(0)
mcmc.run(mcmc_key, jnp.asarray(y))
mcmc.print_summary()
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
mean std median 5.0% 95.0% n_eff r_hat mean 4.96 0.31 4.97 4.53 5.51 517.16 1.00 std 3.05 0.22 3.03 2.73 3.45 746.19 1.00 Number of divergences: 0
Plot diagnostics to check whether the MCMC chains are well-behaved.
# Diagnostic plots for the `mean` and `std` posteriors. The trailing semicolon
# suppresses the notebook echo of the return value, which otherwise dumps the
# full xarray/ArviZ dataset (every raw sample) into the cell output.
numpyro_glm.plot_diagnostic(mcmc, ['mean', 'std']);
<xarray.Dataset> Dimensions: (chain: 1, draw: 750) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749 Data variables: mean (chain, draw) float32 5.209 4.529 5.347 5.227 ... 4.773 4.991 4.968 std (chain, draw) float32 2.917 3.103 2.868 2.874 ... 2.672 3.41 3.339 Attributes: created_at: 2023-05-29T23:56:31.752566 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[5.209106 , 4.528548 , 5.3471146, 5.227207 , 4.2945375, 4.9182925, 4.7712283, 5.2968564, 5.132386 , 4.844204 , 5.586605 , 4.487925 , 5.3550844, 5.4572015, 4.448584 , 4.9535537, 5.02507 , 4.8676105, 5.031705 , 5.547575 , 5.0685067, 4.779368 , 5.002617 , 5.0431023, 5.4613724, 4.8222647, 5.0249243, 4.903239 , 5.109538 , 4.791892 , 5.222782 , 5.2325306, 4.6903687, 4.693336 , 4.9970613, 4.7783775, 4.8024116, 5.4343495, 5.04768 , 4.604811 , 4.6258793, 4.529574 , 4.5864644, 4.955139 , 4.8493094, 5.0941453, 4.96266 , 4.893078 , 5.227385 , 4.9269013, 5.0402994, 5.046525 , 5.125822 , 4.6672425, 4.203383 , 4.907324 , 4.990067 , 4.1970205, 4.35923 , 4.966628 , 4.88882 , 5.092228 , 5.1102777, 4.8794503, 5.1425686, 4.9433007, 4.9666333, 5.1734595, 4.7041483, 5.017935 , 4.9279866, 5.1911583, 5.1911583, 5.011348 , 5.0969815, 4.9910264, 4.695913 , 4.8254447, 4.7390966, 5.0573435, 4.6840773, 5.1652255, 5.024248 , 4.6640615, 4.5730023, 5.0042996, 4.979213 , 4.7948694, 4.718663 , 4.583088 , 4.7011843, 5.0055323, 4.8422976, 4.8485436, 4.965094 , 4.7230763, 5.069906 , 4.9947624, 4.9111443, 4.2580767, 5.2471714, 4.5457106, 4.775864 , 5.025947 , 5.061445 , 4.8522716, 5.1862683, 5.0539246, 5.0416985, 5.168807 , 4.7731266, 4.703271 , 4.8041005, 4.589089 , 4.500424 , 4.4929605, 4.4101706, 5.3536725, 5.31576 , 4.583191 , ... 
4.989824 , 4.7870226, 5.1676106, 4.773499 , 4.433352 , 5.527904 , 5.376922 , 5.3109035, 4.531634 , 5.121908 , 5.2329516, 4.5748124, 4.9590473, 4.995907 , 4.777923 , 4.907775 , 5.110473 , 5.484889 , 4.984773 , 4.9303713, 4.9617133, 4.9597883, 4.201158 , 4.108816 , 5.5899467, 4.4740453, 5.1457696, 5.0359325, 5.559925 , 5.5884337, 5.177686 , 5.4082894, 4.67077 , 4.835922 , 5.263118 , 5.2465863, 5.2224298, 5.4455686, 5.6108904, 4.6170073, 4.7014694, 4.8181043, 5.2965465, 4.779739 , 4.9903846, 4.6685143, 5.001776 , 5.2450037, 4.6075706, 5.258145 , 5.197925 , 4.683494 , 5.085534 , 5.0441504, 5.057175 , 4.7345133, 5.0199018, 5.068657 , 4.5667806, 4.5667806, 5.1042504, 5.0103126, 5.3162837, 5.049756 , 5.0858436, 4.8877597, 4.871423 , 4.964435 , 4.952388 , 5.177528 , 4.7585597, 4.937545 , 5.2067914, 5.404242 , 5.3792343, 5.4727798, 5.503424 , 3.9227598, 4.616737 , 5.349469 , 4.94926 , 5.276943 , 4.5606093, 4.350913 , 4.8356533, 5.3012824, 4.9218187, 5.049454 , 5.0492573, 5.3923044, 5.154832 , 5.172335 , 4.881768 , 4.9017844, 4.2971644, 5.148047 , 4.9634027, 4.969659 , 5.075037 , 4.83324 , 4.8330393, 5.2102723, 4.702924 , 4.691301 , 5.379582 , 4.30379 , 5.499221 , 4.7738557, 5.0708055, 4.6337047, 4.5571523, 4.7732024, 4.9907207, 4.967629 ]], dtype=float32)
array([[2.916654 , 3.1028943, 2.8679268, 2.8739176, 3.1552985, 3.1833546, 2.927499 , 2.9361997, 3.2647386, 2.7347465, 3.6316378, 3.186601 , 2.9091659, 2.878644 , 3.0915058, 3.2480638, 2.8444803, 2.9026027, 3.0260732, 3.271102 , 3.2191224, 3.219358 , 2.9903677, 3.2239099, 3.1981618, 3.1861975, 2.8435102, 3.0665843, 2.7998087, 3.2023396, 3.0847368, 3.129275 , 2.832983 , 2.7954605, 3.4137855, 2.7664967, 2.963932 , 3.4414675, 2.8450003, 3.208023 , 3.1765056, 3.1218975, 3.1018298, 2.775923 , 3.06603 , 2.757947 , 3.0582092, 3.1054397, 2.8913946, 2.94138 , 3.5591953, 3.0489569, 3.0945673, 3.0281289, 2.956422 , 3.021427 , 3.1175191, 2.9662404, 2.9738827, 3.0137942, 3.0707273, 2.9698124, 3.1119697, 2.8752227, 3.112528 , 3.043435 , 2.8680277, 3.205892 , 3.4510348, 2.9880588, 2.9658875, 2.6942782, 2.6942782, 2.6509972, 2.7107518, 3.024229 , 2.7974968, 2.8184597, 3.330964 , 3.1583602, 2.9045746, 2.8659198, 3.016166 , 3.070869 , 3.4217649, 2.9372633, 3.2278461, 2.8106005, 2.9665637, 2.9308615, 2.7869065, 3.2926528, 2.7854984, 3.6578674, 3.8099601, 2.665423 , 3.0543578, 2.9724834, 3.0155618, 3.133643 , 2.9342916, 3.0967343, 3.5509822, 3.3757555, 3.491986 , 3.4743314, 2.8963757, 3.2922297, 2.6979878, 3.2163372, 2.8820283, 2.8636553, 3.1171484, 3.261943 , 3.330661 , 3.3311145, 3.3987849, 2.7803333, 2.9204037, 3.2823446, ... 
2.9119968, 3.0894525, 3.6251438, 3.1715524, 3.0566335, 2.939902 , 2.8716073, 2.9746158, 2.9994235, 3.2255704, 3.2961411, 2.9679601, 3.4174898, 3.3049793, 2.9214787, 3.075012 , 2.8656034, 3.1087933, 2.8138204, 3.1475034, 3.1343572, 3.1616294, 3.1369445, 3.1915271, 2.875889 , 3.1248076, 3.087011 , 2.9718072, 3.2590137, 3.2520576, 3.185291 , 2.859958 , 2.821579 , 2.781268 , 2.8592463, 2.9062448, 3.534457 , 2.9518278, 3.3549776, 2.7666702, 2.7543926, 3.254572 , 2.9607222, 3.1747704, 3.569678 , 3.118485 , 2.9555767, 2.847376 , 2.9998915, 3.1964002, 3.1241221, 3.0805674, 3.368725 , 2.869935 , 2.7692757, 2.6986244, 3.442483 , 3.3160274, 2.842026 , 2.842026 , 3.2840214, 3.4086637, 2.5889783, 2.8690841, 3.1804018, 3.1136734, 3.1699562, 2.8367765, 2.8882794, 3.1074224, 2.8305895, 2.9813702, 3.1293077, 2.8376162, 2.9135425, 3.3758152, 3.674394 , 2.779959 , 3.299925 , 3.1810057, 3.2408876, 2.957203 , 3.4878654, 2.9024978, 3.3169794, 3.256874 , 3.084282 , 2.9575238, 2.7949817, 3.0721962, 3.303253 , 3.2028015, 2.9452498, 3.3317125, 3.3685954, 2.8577948, 3.2133477, 2.846444 , 3.1764314, 2.807072 , 2.778171 , 2.8153605, 3.2772346, 2.8095682, 3.1858308, 2.9680984, 2.7983553, 3.3925478, 2.7187293, 3.3872933, 3.0770774, 2.6717114, 3.4103987, 3.339314 ]], dtype=float32)
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 740, 741, 742, 743, 744, 745, 746, 747, 748, 749], dtype='int64', name='draw', length=750))
<xarray.Dataset> Dimensions: (chain: 1, draw: 750, y_dim_0: 100) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749 * y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99 Data variables: y (chain, draw, y_dim_0) float32 -2.681 -1.99 ... -2.229 -2.194 Attributes: created_at: 2023-05-29T23:56:31.843696 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99])
array([[[-2.6805243, -1.9896044, -2.146286 , ..., -2.2449687, -2.0858808, -2.048276 ], [-2.4436066, -2.0799394, -2.3294597, ..., -2.4485612, -2.2511694, -2.198129 ], [-2.746058 , -1.9728756, -2.108558 , ..., -2.2030478, -2.051998 , -2.0178077], ..., [-2.5292451, -1.9190506, -2.2017465, ..., -2.3469207, -2.1082296, -2.0462952], [-2.5889425, -2.1491568, -2.2932622, ..., -2.3739157, -2.2424622, -2.2096944], [-2.5803223, -2.1288443, -2.2824044, ..., -2.3674629, -2.2286887, -2.1939304]]], dtype=float32)
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 740, 741, 742, 743, 744, 745, 746, 747, 748, 749], dtype='int64', name='draw', length=750))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], dtype='int64', name='y_dim_0'))
<xarray.Dataset> Dimensions: (chain: 1, draw: 750) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749 Data variables: diverging (chain, draw) bool False False False False ... False False False Attributes: created_at: 2023-05-29T23:56:31.764392 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, ... 
False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]])
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 740, 741, 742, 743, 744, 745, 746, 747, 748, 749], dtype='int64', name='draw', length=750))
<xarray.Dataset> Dimensions: (y_dim_0: 100) Coordinates: * y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99 Data variables: y (y_dim_0) float32 1.78 5.272 6.843 6.101 ... 6.247 7.294 6.49 6.21 Attributes: created_at: 2023-05-29T23:56:31.844516 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99])
array([ 1.7799622 , 5.2715006 , 6.843007 , 6.101341 , 1.3784775 , 3.5009139 , 7.0576353 , 7.4335885 , 8.715986 , 1.1754153 , 4.9838333 , 1.7403774 , 5.123005 , 1.4687153 , 6.9200506 , 6.5501447 , 7.0576353 , 4.713186 , 4.6818757 , 6.7534842 , 1.7409545 , 3.115359 , 10.992229 , 3.3846278 , 2.3424428 , 6.3324614 , 0.99764264, 1.5388296 , 10.105146 , 11.020432 , 4.220744 , 2.1013582 , 6.6121535 , 5.773824 , 0.9355608 , 7.78474 , 4.040793 , 2.8486586 , 2.0916374 , 4.882524 , 5.1691275 , 3.7109454 , 6.0718546 , 4.425353 , 1.1271195 , 3.8772562 , 10.373571 , 3.238866 , 7.975573 , 10.97725 , 1.273873 , 1.2262683 , 2.7989383 , 4.9672947 , 0.7470391 , 3.253747 , 3.8842838 , -0.34159485, 2.9436347 , -2.2525995 , 9.075224 , 8.399537 , 3.4938996 , -1.7792448 , 1.692837 , 10.415877 , 3.1210582 , 6.8077483 , 9.530612 , 4.725291 , 9.805911 , 5.4888196 , 2.1056437 , 6.0538917 , 3.962473 , 5.519883 , 4.078588 , 7.26097 , 4.400875 , 0.30348757, 3.18701 , 4.809211 , 8.965049 , 4.0829873 , 5.228887 , 9.951787 , 5.6544847 , 0.64701927, 9.465937 , 7.890076 , 4.776916 , 4.893428 , 10.160401 , 2.3604243 , 7.6586046 , 7.2983284 , 6.2470894 , 7.294434 , 6.4904785 , 6.210163 ], dtype=float32)
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], dtype='int64', name='y_dim_0'))
# Plot the fitted posterior against the synthetic data `y` using the project
# helper plots.plot_st.
fig = plots.plot_st(
    mcmc,
    y,
)
# Load the two-group IQ dataset, make `Group` categorical, and keep only the
# rows belonging to the "Smart Drug" group.
iq_data = pd.read_csv('datasets/TwoGroupIQ.csv')
iq_data = iq_data.astype({'Group': 'category'})
is_smart = iq_data['Group'] == 'Smart Drug'
smart_group_data = iq_data.loc[is_smart]
smart_group_data.describe()
Score | |
---|---|
count | 63.000000 |
mean | 107.841270 |
std | 25.445201 |
min | 50.000000 |
25% | 96.000000 |
50% | 107.000000 |
75% | 119.000000 |
max | 208.000000 |
Next, we apply the one-group model to the Smart Drug group's scores and plot the results.
# Fit the one-group model to the Smart Drug group's IQ scores
# (NUTS: 250 warm-up iterations, 750 retained draws).
scores = jnp.asarray(smart_group_data['Score'].to_numpy())
kernel = NUTS(glm_metric.one_group)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc_key = random.PRNGKey(0)
mcmc.run(mcmc_key, scores)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat mean 107.80 3.24 107.89 102.57 113.19 539.80 1.00 std 26.08 2.35 25.88 22.73 30.35 642.03 1.00 Number of divergences: 0
# Diagnostic plots for the `mean` and `std` posteriors. The trailing semicolon
# suppresses the notebook echo of the return value, which otherwise dumps the
# full xarray/ArviZ dataset (every raw sample) into the cell output.
numpyro_glm.plot_diagnostic(mcmc, ['mean', 'std']);
<xarray.Dataset> Dimensions: (chain: 1, draw: 750) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749 Data variables: mean (chain, draw) float32 110.0 105.3 110.2 109.4 ... 106.8 108.7 108.2 std (chain, draw) float32 24.79 26.16 24.03 23.97 ... 22.13 29.33 28.75 Attributes: created_at: 2023-05-29T23:56:35.414985 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[109.97527 , 105.29083 , 110.214516, 109.411125, 100.93915 , 108.557785, 109.7286 , 107.717476, 110.33823 , 105.4407 , 115.047325, 110.74487 , 109.90371 , 111.92637 , 103.67735 , 109.43993 , 106.924866, 105.82306 , 109.33133 , 113.90438 , 108.25385 , 105.84257 , 108.72531 , 109.02865 , 113.91567 , 113.41177 , 103.41682 , 108.057 , 104.65977 , 110.4926 , 113.583885, 102.19425 , 113.01051 , 111.27577 , 106.85726 , 102.25837 , 107.078285, 113.12787 , 107.06807 , 104.69756 , 104.80107 , 103.58743 , 104.322945, 108.6351 , 106.419586, 109.62201 , 107.34422 , 107.35962 , 110.8534 , 106.74869 , 109.08335 , 109.06311 , 108.5331 , 104.78836 , 116.33449 , 114.67767 , 112.495895, 104.07471 , 106.779434, 107.259605, 108.09704 , 108.10233 , 106.08486 , 109.5852 , 106.001076, 105.14359 , 109.282585, 110.47549 , 104.854195, 108.91195 , 106.613556, 110.542786, 110.542786, 108.08629 , 109.27255 , 107.772766, 104.430824, 103.29096 , 110.69535 , 112.892365, 106.9933 , 108.17603 , 108.825 , 105.24313 , 104.06096 , 109.17792 , 107.244606, 108.34675 , 104.49413 , 106.83144 , 104.965004, 110.35247 , 105.32731 , 109.82167 , 105.720985, 107.46112 , 110.59209 , 109.11344 , 106.73796 , 99.25905 , ... 
106.25765 , 107.159256, 107.43473 , 114.022736, 107.4383 , 109.01808 , 103.7021 , 111.57398 , 110.96743 , 113.67769 , 114.12441 , 109.39441 , 104.46808 , 105.55915 , 105.599365, 111.64049 , 110.684875, 109.979706, 113.008514, 113.94847 , 103.412254, 104.3738 , 108.75197 , 111.8318 , 104.77148 , 107.812775, 104.56448 , 108.718765, 105.21872 , 105.55679 , 104.245285, 111.70955 , 104.24466 , 111.4618 , 110.510574, 107.00754 , 110.02091 , 104.797165, 106.90024 , 112.80083 , 109.23521 , 105.043015, 105.97836 , 103.32183 , 108.15405 , 107.78267 , 107.18314 , 106.99867 , 108.12765 , 107.89384 , 109.307655, 112.93956 , 112.01034 , 110.8072 , 112.965576, 112.18863 , 112.52543 , 108.96955 , 102.037285, 107.1725 , 104.57445 , 109.20819 , 111.1646 , 102.88633 , 108.055016, 104.86002 , 111.12093 , 106.58487 , 109.66695 , 109.28379 , 113.10869 , 109.12012 , 109.58398 , 106.80441 , 107.53587 , 100.24453 , 113.68477 , 106.12561 , 106.85108 , 108.66731 , 106.49102 , 106.621 , 110.45593 , 104.96861 , 113.12082 , 111.35625 , 102.31912 , 106.25812 , 110.87694 , 109.251045, 107.50152 , 105.6816 , 106.77369 , 108.67319 , 108.24811 ]], dtype=float32)
array([[24.786228, 26.155521, 24.029436, 23.965836, 27.113821, 27.424044, 25.806293, 26.273508, 27.535105, 22.97756 , 32.289055, 24.15046 , 24.51352 , 24.034502, 26.246254, 27.881187, 23.841017, 24.366045, 25.547419, 28.463753, 27.826414, 27.78451 , 25.26547 , 27.71857 , 27.465443, 26.184599, 24.367947, 28.29067 , 29.094795, 22.382898, 23.262806, 28.737497, 22.244843, 25.077694, 29.279648, 27.409065, 24.064144, 30.407253, 23.98416 , 27.507038, 27.295832, 26.6867 , 26.46073 , 22.920782, 25.979725, 22.771954, 25.901804, 26.450947, 24.121893, 24.683409, 31.45096 , 26.017336, 26.34012 , 25.636229, 26.706406, 27.788372, 28.212412, 26.795696, 26.954515, 24.471197, 27.04578 , 24.21856 , 25.055971, 25.211184, 25.342659, 25.057606, 23.631056, 30.170628, 30.582378, 25.426617, 24.681257, 22.228506, 22.228506, 21.75323 , 22.333263, 25.517614, 23.16428 , 23.095354, 29.873755, 27.750286, 24.93584 , 23.626965, 25.444895, 26.3856 , 30.052254, 24.859749, 22.761942, 28.185104, 23.586464, 25.712185, 22.451385, 29.302467, 22.616674, 22.743664, 23.170364, 23.876362, 27.080822, 25.806063, 24.978598, 26.404415, 27.514174, 26.977419, 31.662498, 29.875046, 31.077866, 25.236607, 24.985004, 28.296259, 22.199097, 27.613825, 24.109753, 23.913204, 26.575354, 28.296556, 29.036882, 29.066261, 29.816078, 23.089298, 24.582733, 26.92522 , ... 
24.014801, 27.789244, 32.343906, 27.632095, 26.07815 , 24.650728, 23.914785, 25.040188, 25.317043, 27.565416, 28.408096, 25.107414, 22.729567, 23.612532, 23.126312, 26.601177, 23.807423, 24.42951 , 23.662622, 26.85912 , 26.753407, 27.049782, 25.479877, 26.435293, 23.96122 , 26.330206, 26.018835, 25.607841, 28.446865, 28.278666, 27.410109, 28.209177, 30.121782, 30.168917, 22.123705, 23.426945, 31.40853 , 25.059225, 29.351217, 27.96904 , 25.75235 , 31.656654, 25.066679, 27.139214, 31.495024, 26.959902, 25.11425 , 26.711693, 24.692608, 23.715153, 27.323866, 26.20297 , 24.107151, 24.007935, 24.023855, 26.436314, 25.4301 , 25.759718, 24.9387 , 31.076105, 22.513962, 25.891182, 29.503733, 23.962814, 27.266918, 26.573658, 27.158567, 23.648088, 24.148748, 28.249935, 30.239103, 29.819923, 26.691223, 23.647682, 24.47609 , 29.486397, 24.319372, 30.137695, 32.97399 , 23.286343, 28.740875, 24.984743, 30.649132, 34.445366, 27.562742, 27.380943, 25.81166 , 25.387672, 23.473827, 26.251698, 28.780266, 27.655264, 24.728743, 28.889326, 29.331318, 24.23243 , 27.918295, 23.9464 , 27.325426, 23.267876, 22.971722, 23.432446, 28.220026, 27.862444, 27.050863, 25.492939, 29.31498 , 23.50533 , 24.160336, 27.551594, 24.947176, 22.131338, 29.331263, 28.75127 ]], dtype=float32)
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 740, 741, 742, 743, 744, 745, 746, 747, 748, 749], dtype='int64', name='draw', length=750))
<xarray.Dataset> Dimensions: (chain: 1, draw: 750, y_dim_0: 63) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749 * y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 Data variables: y (chain, draw, y_dim_0) float32 -4.181 -4.136 ... -5.051 -4.286 Attributes: created_at: 2023-05-29T23:56:35.499221 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62])
array([[[-4.180992 , -4.136431 , -4.392193 , ..., -4.210211 , -5.0714164, -4.132563 ], [-4.1909137, -4.185134 , -4.312105 , ..., -4.203458 , -5.278142 , -4.2158976], [-4.1566496, -4.107166 , -4.3855066, ..., -4.1885657, -5.086642 , -4.1009784], ..., [-4.039196 , -4.0159855, -4.238742 , ..., -4.062772 , -5.430601 , -4.0438166], [-4.3234735, -4.2992196, -4.459157 , ..., -4.341311 , -5.022892 , -4.3040247], [-4.3012333, -4.2785625, -4.4373045, ..., -4.31877 , -5.050753 , -4.2861347]]], dtype=float32)
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 740, 741, 742, 743, 744, 745, 746, 747, 748, 749], dtype='int64', name='draw', length=750))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62], dtype='int64', name='y_dim_0'))
<xarray.Dataset> Dimensions: (chain: 1, draw: 750) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749 Data variables: diverging (chain, draw) bool False False False False ... False False False Attributes: created_at: 2023-05-29T23:56:35.416760 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, ... 
False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]])
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 740, 741, 742, 743, 744, 745, 746, 747, 748, 749], dtype='int64', name='draw', length=750))
<xarray.Dataset> Dimensions: (y_dim_0: 63) Coordinates: * y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 Data variables: y (y_dim_0) int32 102 107 92 101 110 68 119 ... 97 139 72 100 144 112 Attributes: created_at: 2023-05-29T23:56:35.500056 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62])
array([102, 107, 92, 101, 110, 68, 119, 106, 99, 103, 90, 93, 79, 89, 137, 119, 126, 110, 71, 114, 100, 95, 91, 99, 97, 106, 106, 129, 115, 124, 137, 73, 69, 95, 102, 116, 111, 134, 102, 110, 139, 112, 122, 84, 129, 112, 127, 106, 113, 109, 208, 114, 107, 50, 169, 133, 50, 97, 139, 72, 100, 144, 112], dtype=int32)
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62], dtype='int64', name='y_dim_0'))
# Plot the posterior against reference values: mean 100 and SD 15
# (presumably the conventional IQ population norm) and an effect size of 0.
fig = plots.plot_st(
    mcmc,
    smart_group_data['Score'].to_numpy(),
    effsize_comp_val=0,
    mean_comp_val=100,
    std_comp_val=15,
)
# Draw the graphical structure of the robust one-group model. jnp.ones(5) is
# only a dummy observation vector used to trace the model.
numpyro.render_model(
    glm_metric.one_group_robust,
    model_args=(jnp.ones(5),),
    render_params=True,
)
# Heavy-tailed synthetic data: a location-scale Student-t distribution with
# known parameters, used to check whether the robust model recovers them.
MEAN = 5
SIGMA = 3
NORMALITY = 3  # degrees of freedom of the Student-t (passed as `df` below)
N = 1000
# Seed the generator so the data (and therefore the fitted summary) is
# reproducible; the MCMC run already uses a fixed PRNGKey.
rng = np.random.default_rng(0)
y = rng.standard_t(NORMALITY, size=N) * SIGMA + MEAN

# Plot the histogram against the true Student-t PDF and, for contrast, a
# normal PDF whose parameters are the sample mean and standard deviation.
fig, ax = plt.subplots()
ax.hist(y, density=True, bins=100, label='Histogram of $y$')
xmin, xmax = ax.get_xlim()
p = np.linspace(xmin, xmax, 1000)
ax.plot(p, t.pdf(p, loc=MEAN, scale=SIGMA, df=NORMALITY), label='Student-$t$ PDF')
ax.plot(p, norm.pdf(p, loc=y.mean(), scale=y.std()),
        label='Normal PDF using\ndata mean and std')
ax.legend()
fig.tight_layout()
Next, we apply the robust metric model to our synthetic data to see whether it recovers the original parameters.
# Fit the robust one-group model to the heavy-tailed synthetic data
# (NUTS: 250 warm-up iterations, 750 retained draws).
kernel = NUTS(glm_metric.one_group_robust)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc_key = random.PRNGKey(0)
mcmc.run(mcmc_key, jnp.asarray(y))
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat mean 5.19 0.12 5.19 5.01 5.38 528.65 1.01 nu 3.02 0.32 3.00 2.48 3.54 389.25 1.00 sigma 2.97 0.12 2.97 2.78 3.17 351.15 1.00 Number of divergences: 0
# Diagnostic plots for the `mean`, `sigma` and `nu` posteriors. The trailing
# semicolon suppresses the notebook echo of the return value, which otherwise
# dumps the full xarray/ArviZ dataset (every raw sample) into the cell output.
numpyro_glm.plot_diagnostic(mcmc, ['mean', 'sigma', 'nu']);
<xarray.Dataset> Dimensions: (chain: 1, draw: 750) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749 Data variables: mean (chain, draw) float32 5.225 5.124 5.198 5.033 ... 5.201 5.154 5.112 nu (chain, draw) float32 3.03 2.811 3.201 2.794 ... 3.153 3.204 3.224 sigma (chain, draw) float32 3.008 2.894 3.055 2.892 ... 3.009 3.049 3.047 Attributes: created_at: 2023-05-29T23:56:40.651587 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[5.2252526, 5.1241636, 5.1982026, 5.032679 , 5.0717063, 5.295204 , 5.3268604, 5.2793097, 5.081863 , 5.24317 , 5.1877646, 5.13194 , 5.097041 , 5.299657 , 5.100597 , 5.3192477, 5.197055 , 5.197055 , 5.335744 , 5.362243 , 5.265542 , 5.07826 , 5.033498 , 5.158056 , 5.043956 , 5.178067 , 5.2793436, 5.1069593, 5.216472 , 5.0388417, 5.3190384, 5.118554 , 5.3592353, 5.297184 , 5.054299 , 5.332943 , 5.1977186, 5.186516 , 5.0346336, 5.0033903, 5.351208 , 5.346388 , 5.2818646, 5.0247946, 5.3952494, 5.3240824, 5.038796 , 5.351271 , 5.0122066, 5.106121 , 5.239272 , 5.0579095, 5.178805 , 5.0543776, 4.985325 , 5.383977 , 5.3296695, 5.2475777, 5.2292995, 5.3917494, 5.176481 , 5.0985913, 5.0745983, 5.3064804, 5.3190703, 5.3065405, 5.073111 , 5.127397 , 5.1585197, 5.004515 , 5.0588617, 5.050305 , 4.9809313, 5.024628 , 5.301579 , 5.290991 , 5.326011 , 5.0710893, 5.229229 , 5.140679 , 5.171478 , 5.1049523, 5.4097652, 4.9618893, 5.0410748, 5.192658 , 5.006359 , 5.201808 , 5.360339 , 5.293335 , 5.180471 , 5.2222195, 5.106078 , 5.0473633, 5.445147 , 5.3690085, 5.356301 , 5.314369 , 5.1836066, 5.214131 , 5.137178 , 5.019918 , 5.1467023, 5.120248 , 5.177939 , 5.1632943, 5.2115273, 5.020947 , 5.232629 , 5.138412 , 5.414697 , 5.2546616, 5.315323 , 5.3316145, 5.029754 , 5.271732 , 5.2580037, 4.9675856, 4.854825 , 5.1653094, ... 
5.1836634, 5.1777854, 5.0916348, 5.1517425, 4.94229 , 5.338954 , 5.3075275, 4.9510856, 5.355672 , 5.050706 , 5.1464505, 5.1667733, 5.102213 , 5.142809 , 5.2244234, 5.0835414, 5.2575097, 5.3360806, 5.255036 , 5.1094394, 5.2511826, 5.383278 , 5.104622 , 5.010276 , 5.374145 , 5.1573734, 4.9785376, 5.130616 , 5.086193 , 5.1101737, 5.196637 , 5.2557135, 5.2129374, 5.216792 , 5.3181276, 5.152473 , 5.1690083, 5.2550898, 5.1514096, 5.111661 , 5.12744 , 5.225885 , 5.371412 , 5.057382 , 5.313378 , 5.198703 , 5.218926 , 5.1541963, 5.0360336, 5.1063595, 5.252289 , 5.109053 , 5.1100416, 5.109926 , 5.4005275, 5.206244 , 5.423744 , 5.055256 , 5.262973 , 5.369212 , 5.2393475, 5.167504 , 5.057998 , 5.1727924, 5.123668 , 5.077471 , 5.277476 , 5.152246 , 5.3775187, 5.2343764, 5.320696 , 5.0608945, 5.297638 , 5.2998834, 5.0541434, 5.2953877, 5.295298 , 5.37311 , 5.287508 , 5.177804 , 5.1395016, 4.9199777, 4.9277735, 4.9639325, 5.01235 , 5.1524143, 5.2529345, 5.2919297, 5.009624 , 5.0876145, 5.0876145, 5.080562 , 5.2489963, 5.2815833, 5.2815833, 5.132526 , 5.18512 , 5.1307178, 5.247394 , 5.1012545, 5.0817285, 5.4392815, 5.058194 , 5.067048 , 5.154724 , 5.154724 , 5.00218 , 5.0527177, 5.2726903, 5.21625 , 5.3625455, 5.200683 , 5.1538286, 5.11227 ]], dtype=float32)
array([[3.0302858, 2.8105295, 3.2014136, 2.7935388, 2.9985032, 2.5012183, 3.0470521, 2.9146385, 3.8911355, 2.3872564, 2.428279 , 2.8282342, 3.0597894, 2.883507 , 3.410241 , 3.499199 , 2.655398 , 2.655398 , 2.6596022, 3.048019 , 2.955171 , 2.7809849, 2.9156413, 2.7100785, 2.7034278, 2.8511982, 2.8510652, 2.8014731, 2.724396 , 3.0761416, 3.1426032, 2.7679362, 3.0328808, 3.1469035, 3.0298078, 2.9144955, 2.9299479, 2.8942003, 2.8006005, 3.2382188, 3.0903177, 3.0062823, 2.855648 , 3.045023 , 3.5830653, 3.472515 , 3.0960622, 3.2011704, 3.1031928, 3.4019434, 3.2195787, 3.0908947, 3.0280414, 3.4605806, 3.024026 , 3.2612574, 3.1635637, 2.7193754, 3.5790129, 3.5917504, 2.7907567, 3.0398004, 3.322532 , 3.58795 , 3.3344655, 3.4647434, 3.0550628, 3.6886106, 2.9566228, 2.8622136, 2.967669 , 3.297165 , 3.111523 , 2.8584785, 2.6858768, 3.014183 , 2.8630064, 2.8034956, 3.193019 , 2.812081 , 2.8540955, 2.8850079, 3.1799855, 3.414366 , 3.5967433, 2.7563825, 2.9630618, 3.2624266, 3.1123424, 3.0870428, 2.725316 , 2.961867 , 3.3157516, 3.561547 , 2.7511356, 2.7754855, 3.1627483, 2.84977 , 3.4220443, 2.3197072, 2.8218226, 2.4043255, 2.6734753, 3.0119743, 3.0013568, 2.9657485, 3.18316 , 3.5373964, 3.4768598, 3.2778487, 3.138904 , 3.0696943, 2.8126042, 2.8059587, 2.7557168, 2.707277 , 2.7965493, 2.5954356, 2.7712111, 2.5360422, ... 
3.2974737, 2.608031 , 3.3340845, 3.1168242, 3.371426 , 2.484195 , 2.5683467, 3.5349627, 3.216056 , 2.9181938, 2.941831 , 3.331728 , 3.2328591, 3.0172265, 2.7155166, 2.9464223, 2.7724924, 2.6816273, 2.5984507, 3.5753417, 3.6234405, 3.1972866, 3.0058103, 2.4069138, 3.2024572, 3.1628978, 2.934284 , 2.6520753, 2.5028887, 2.4886835, 2.3843923, 2.5833275, 2.5340605, 2.590158 , 2.5927703, 3.110703 , 3.1948829, 3.28651 , 3.098055 , 3.193087 , 3.2967393, 3.358853 , 3.7556908, 3.1511366, 2.7770283, 3.003881 , 2.752591 , 3.2983627, 3.2164586, 3.1108088, 3.0520477, 2.815367 , 2.796043 , 2.3711905, 2.977821 , 2.7997148, 2.4689636, 4.148687 , 3.509022 , 3.2277522, 3.6996043, 4.011406 , 3.2496595, 2.5972798, 2.8907168, 3.082364 , 3.1959908, 3.0188456, 3.1784215, 2.9925566, 3.1192105, 2.6216607, 2.9329195, 2.8352685, 2.9127254, 2.5918915, 3.7815177, 3.487691 , 3.2226639, 3.4023337, 2.5276363, 2.8847451, 2.830494 , 2.9687147, 3.082112 , 2.9967947, 3.1591215, 2.67426 , 2.7939107, 2.9403849, 2.9403849, 2.820949 , 3.3997335, 2.928617 , 2.928617 , 2.49297 , 2.991461 , 3.1410506, 3.071447 , 3.406098 , 2.8856087, 2.9846396, 3.5282505, 2.6596055, 2.9256043, 2.9256043, 2.6848874, 2.655949 , 2.6168833, 3.115498 , 3.1219566, 3.1531057, 3.203574 , 3.224131 ]], dtype=float32)
array([[3.0077553, 2.893796 , 3.0548732, 2.891678 , 2.8712099, 2.8375256, 2.826606 , 2.9720817, 3.263655 , 2.7567494, 2.7421374, 2.9102821, 2.9675138, 2.9290013, 3.1114519, 3.1912806, 2.7879858, 2.7879858, 2.864312 , 2.8857582, 2.9045377, 2.9562974, 2.9119508, 2.9810064, 3.0166419, 2.9744427, 2.8893104, 2.95488 , 2.932642 , 2.9176862, 2.9649498, 2.8088963, 2.9177237, 2.9527833, 2.8886452, 2.962949 , 2.9381344, 2.8418946, 2.9687228, 3.1053722, 3.0414934, 2.9902358, 2.8479242, 3.091427 , 3.006934 , 3.0294032, 2.9957023, 3.0173352, 3.171182 , 3.0915844, 2.998457 , 3.0699952, 3.1018872, 2.930065 , 3.0457103, 2.9077508, 3.0073476, 2.8968403, 3.14788 , 3.1162343, 2.9465456, 3.0800798, 2.8778477, 3.0915356, 3.1263587, 2.9847069, 3.1584203, 3.0032854, 3.094975 , 2.8946183, 2.9876933, 3.0804462, 2.974532 , 2.8442578, 2.8904667, 3.00443 , 2.925379 , 2.9279077, 3.0053437, 2.8640773, 2.9765718, 2.9506266, 3.0112603, 3.2222466, 3.1745212, 3.0268316, 2.8382347, 3.2115846, 2.96255 , 2.880831 , 2.9258776, 2.9168687, 3.1384408, 3.0122447, 2.756393 , 3.0167165, 2.9385285, 2.7939687, 3.0074434, 2.661758 , 2.695194 , 2.7966495, 2.890351 , 3.021478 , 2.91477 , 2.9491024, 3.0953932, 3.014392 , 3.0871012, 3.0097551, 3.297201 , 3.0935102, 2.9835587, 2.9820387, 2.8816643, 2.9155366, 2.946278 , 2.7360942, 2.7621183, 2.9129941, ... 
3.1272457, 3.000786 , 2.9547353, 3.1654298, 3.1500812, 2.7814696, 2.8203695, 3.0340738, 3.0124326, 3.1742008, 3.188828 , 3.1395504, 2.9985125, 3.0134718, 3.045996 , 2.8193517, 2.902401 , 2.9608424, 2.8618124, 3.096644 , 2.9716787, 3.0962532, 2.98246 , 2.7797184, 3.1126304, 2.9470394, 2.9663665, 2.9348457, 2.707904 , 2.7372358, 2.8489664, 2.7766206, 2.8870816, 2.7899947, 2.8266532, 3.0544372, 3.0760381, 3.0813096, 3.097099 , 3.1129563, 3.0446541, 3.1500077, 3.1358485, 3.0074933, 3.1147547, 2.818752 , 2.9542136, 3.0311131, 2.9141474, 3.1281784, 2.9774513, 2.740467 , 3.0301168, 2.88417 , 2.7082813, 2.867558 , 2.7481692, 3.1149592, 3.3530736, 3.4157538, 3.227895 , 3.1730561, 3.153692 , 2.89388 , 2.861404 , 3.0043843, 3.116252 , 3.0452118, 2.9471655, 2.9493282, 2.8316453, 2.9597483, 2.870535 , 2.9105754, 2.8837337, 2.9166882, 3.1200943, 3.0545652, 3.1874716, 2.8387008, 2.9814684, 3.0079186, 3.067449 , 2.9708407, 2.917632 , 2.831051 , 2.8786335, 2.9744313, 3.0196028, 2.9318008, 2.9318008, 3.0168672, 2.8341033, 3.0753546, 3.0753546, 2.8534124, 2.9582543, 2.9235747, 3.1306293, 3.08414 , 2.9503288, 3.0093284, 3.1156042, 2.906394 , 2.8710785, 2.8710785, 2.8466375, 2.7183504, 2.9024906, 2.939236 , 3.0976253, 3.0085454, 3.0485835, 3.0470424]], dtype=float32)
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 740, 741, 742, 743, 744, 745, 746, 747, 748, 749], dtype='int64', name='draw', length=750))
<xarray.Dataset> Dimensions: (chain: 1, draw: 750, y_dim_0: 1000) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749 * y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999 Data variables: y (chain, draw, y_dim_0) float32 -4.239 -2.245 ... -2.194 -3.763 Attributes: created_at: 2023-05-29T23:56:40.806417 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([ 0, 1, 2, ..., 997, 998, 999])
array([[[-4.23901 , -2.2447944, -3.7616072, ..., -3.9045508, -2.1726134, -3.819713 ], [-4.248008 , -2.2480104, -3.7681386, ..., -3.9891958, -2.1634555, -3.8268566], [-4.2147064, -2.2553887, -3.7335567, ..., -3.8977628, -2.1846452, -3.7918887], ..., [-4.237927 , -2.245505 , -3.753389 , ..., -3.9170911, -2.172602 , -3.8122478], [-4.2017717, -2.262422 , -3.7191315, ..., -3.9176188, -2.189119 , -3.7776227], [-4.1883435, -2.2696419, -3.70393 , ..., -3.9347727, -2.1942322, -3.7625935]]], dtype=float32)
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 740, 741, 742, 743, 744, 745, 746, 747, 748, 749], dtype='int64', name='draw', length=750))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 990, 991, 992, 993, 994, 995, 996, 997, 998, 999], dtype='int64', name='y_dim_0', length=1000))
<xarray.Dataset> Dimensions: (chain: 1, draw: 750) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749 Data variables: diverging (chain, draw) bool False False False False ... False False False Attributes: created_at: 2023-05-29T23:56:40.653699 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, ... 
False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]])
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 740, 741, 742, 743, 744, 745, 746, 747, 748, 749], dtype='int64', name='draw', length=750))
<xarray.Dataset> Dimensions: (y_dim_0: 1000) Coordinates: * y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999 Data variables: y (y_dim_0) float32 -1.97 6.648 -0.6971 5.801 ... 11.52 6.219 -0.8494 Attributes: created_at: 2023-05-29T23:56:40.807396 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([ 0, 1, 2, ..., 997, 998, 999])
array([-1.97044837e+00, 6.64772511e+00, -6.97052896e-01, 5.80093575e+00, 4.54183054e+00, 1.12101612e+01, -4.44931060e-01, 5.82078266e+00, 5.82591629e+00, 3.17236757e+00, 8.37355042e+00, 6.53492546e+00, 4.20800734e+00, 1.01482553e+01, 1.41755648e+01, 1.24651060e+01, -2.17906322e+01, 4.92048836e+00, 4.85010052e+00, 1.02006369e+01, 2.13908386e+00, 2.05172256e-01, 1.30528660e+01, -2.62980032e+00, 7.23637486e+00, 4.05275726e+00, 3.06575513e+00, 7.08320904e+00, 5.54312420e+00, 1.49230599e+00, 1.01288261e+01, 8.57089424e+00, 1.14836349e+01, 5.61097431e+00, 3.13621688e+00, 9.58228970e+00, 8.92821312e+00, 6.76997042e+00, 1.93540821e+01, 4.11432600e+00, 1.19134083e+01, 9.42990589e+00, 5.60357380e+00, -1.71314931e+00, 4.70048040e-01, 3.88558745e+00, 9.75277519e+00, 5.94671679e+00, 6.14265108e+00, -3.56223047e-01, 1.22409928e+00, 8.90643311e+00, 1.25334752e+00, 7.24736023e+00, 4.03608942e+00, 6.78641033e+00, -2.69396567e+00, 5.06562614e+00, -1.64162469e+00, 2.85413051e+00, 2.39020967e+00, 4.40503359e+00, 8.13211632e+00, 5.34186506e+00, 6.01666784e+00, 1.21046562e+01, 2.13109040e+00, 2.75819302e-01, 2.81802750e+00, 7.41340733e+00, 6.28000736e+00, 6.85987091e+00, 1.15999956e+01, 4.33733034e+00, 7.85834742e+00, 4.79054356e+00, -1.75435328e+00, 2.27244854e+00, 1.29107180e+01, 1.78298130e+01, ... 
7.67104053e+00, 7.03784704e+00, 2.43009663e+00, 6.72932577e+00, 1.44165316e+01, 5.52429914e+00, 6.58890247e+00, 9.42949772e+00, 1.42814171e+00, 2.61827564e+00, 5.46275520e+00, 5.43681955e+00, 4.64565945e+00, -1.88100077e-02, 1.12393579e+01, 6.41488600e+00, 5.92880297e+00, 7.56008530e+00, 5.77487803e+00, 5.70153856e+00, 1.06024427e+01, 9.22666454e+00, 1.26327670e+00, 6.25757360e+00, 6.05027056e+00, 4.46348190e+00, 5.07479429e+00, 3.75031829e+00, 1.08985357e+01, 8.36709881e+00, 3.84162974e+00, 8.93783290e-03, 6.21623099e-01, 8.29607487e+00, 6.74988127e+00, 4.13313913e+00, 3.56779456e-01, 6.27172899e+00, 1.39344180e+00, 8.97291946e+00, 1.39694996e+01, 7.24138546e+00, 9.02207565e+00, 6.50661051e-01, -2.29273844e+00, 6.50683498e+00, 2.56033993e+00, 6.01500273e+00, 1.85625410e+00, 7.55919504e+00, 5.12628460e+00, 2.19768524e+00, 6.08607101e+00, 4.16822290e+00, -2.76709747e+00, 4.77808809e+00, 1.25680656e+01, 1.24284792e+00, 6.16621590e+00, 4.93042040e+00, 4.13558197e+00, 1.31871214e+01, 2.85820341e+00, 6.72703362e+00, 7.92491627e+00, 6.70488834e+00, 9.36034393e+00, -7.42135286e-01, 5.76647472e+00, 2.60789084e+00, 1.30372515e+01, 5.43683338e+00, 7.20951414e+00, 1.15234461e+01, 6.21906805e+00, -8.49427700e-01], dtype=float32)
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 990, 991, 992, 993, 994, 995, 996, 997, 998, 999], dtype='int64', name='y_dim_0', length=1000))
fig = plots.plot_st_2(mcmc, y)
mcmc_key = random.PRNGKey(0)
kernel = NUTS(glm_metric.one_group_robust)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc.run(mcmc_key, jnp.array(smart_group_data.Score.values))
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat mean 107.38 2.96 107.35 102.54 112.11 499.25 1.00 nu 9.63 13.31 5.73 1.20 19.70 158.39 1.01 sigma 19.83 3.53 19.78 14.20 25.66 235.03 1.00 Number of divergences: 0
numpyro_glm.plot_diagnostic(mcmc, ['mean', 'sigma', 'nu'])
<xarray.Dataset> Dimensions: (chain: 1, draw: 750) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749 Data variables: mean (chain, draw) float32 105.6 106.3 108.6 104.4 ... 108.3 106.8 104.6 nu (chain, draw) float32 6.38 6.304 6.77 2.442 ... 9.179 8.01 13.55 sigma (chain, draw) float32 22.7 23.93 23.56 16.29 ... 22.12 23.89 20.27 Attributes: created_at: 2023-05-29T23:56:45.027781 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[105.55507 , 106.34972 , 108.5634 , 104.371185, 103.78115 , 108.39524 , 104.8494 , 105.39043 , 110.073586, 107.43894 , 108.44629 , 106.93347 , 102.90011 , 106.07971 , 106.80335 , 110.15082 , 109.17264 , 107.16441 , 111.50457 , 113.58138 , 102.10726 , 108.93787 , 109.68986 , 110.641335, 105.38782 , 108.96525 , 109.787346, 106.19191 , 107.32411 , 103.05021 , 109.8304 , 106.86794 , 110.42675 , 104.287094, 103.54164 , 106.96897 , 107.24161 , 103.478745, 102.1156 , 100.63876 , 108.324066, 108.720985, 108.63 , 111.166565, 109.553505, 108.90091 , 108.67291 , 105.97729 , 110.96842 , 105.5543 , 104.999115, 106.73197 , 107.25278 , 105.5463 , 108.59721 , 112.095764, 111.39011 , 110.41567 , 101.798965, 107.15456 , 110.26754 , 108.96618 , 105.98647 , 105.55314 , 106.5947 , 109.240746, 105.711555, 105.50618 , 106.02377 , 111.31043 , 103.732605, 102.67975 , 107.562675, 106.68475 , 106.863625, 109.01302 , 104.004974, 107.06012 , 107.67128 , 106.929245, 107.940384, 106.64516 , 112.43416 , 113.13488 , 112.6087 , 108.9748 , 104.34741 , 105.08496 , 111.78247 , 105.56174 , 106.40983 , 108.041275, 108.06947 , 105.43695 , 108.262535, 111.17323 , 104.075134, 104.57462 , 104.88279 , 105.00093 , ... 
103.982544, 106.368774, 111.052895, 109.16271 , 110.65621 , 110.26432 , 107.34818 , 112.2538 , 104.68615 , 103.45305 , 107.17469 , 107.23535 , 108.20257 , 107.94324 , 107.99706 , 105.0354 , 108.69497 , 107.565155, 105.54674 , 108.33376 , 106.40116 , 106.42444 , 105.807465, 102.839745, 109.73883 , 111.335785, 109.93152 , 110.02944 , 109.36531 , 111.29636 , 110.44721 , 104.324524, 109.9587 , 112.545685, 111.32425 , 113.93877 , 111.03316 , 106.69128 , 105.634125, 107.761185, 103.34582 , 104.21555 , 104.75715 , 105.540276, 106.058044, 110.38669 , 110.98076 , 104.838936, 103.348526, 111.428215, 111.38477 , 111.552666, 107.4357 , 105.75314 , 106.130035, 105.13054 , 108.275925, 113.34639 , 103.81799 , 103.676956, 105.02578 , 106.404274, 101.28366 , 105.969345, 105.57954 , 105.5759 , 107.31593 , 108.7727 , 109.47214 , 105.593376, 107.311165, 104.327995, 104.20603 , 107.11131 , 109.11189 , 109.00499 , 107.64095 , 106.795364, 107.13538 , 110.425575, 107.195335, 102.86533 , 111.92812 , 102.704185, 102.664154, 104.04894 , 109.90849 , 104.54301 , 104.901215, 106.85347 , 105.591606, 108.840034, 108.32749 , 106.838234, 104.625496]], dtype=float32)
array([[ 6.380366 , 6.3039827, 6.770382 , 2.4415083, 4.205345 , 2.3898754, 2.4554567, 4.751895 , 30.20579 , 6.2195134, 3.5312512, 4.76675 , 5.0081363, 3.1247911, 21.360893 , 23.807055 , 10.058445 , 7.779872 , 11.477498 , 21.622583 , 36.68576 , 5.816049 , 7.4071264, 4.7910275, 3.411479 , 5.6175413, 3.6065092, 3.8919663, 2.406946 , 6.614321 , 17.233143 , 2.8523195, 9.354761 , 5.276261 , 4.406109 , 3.6853158, 3.4503999, 6.083024 , 3.162562 , 7.787067 , 18.096825 , 6.6824446, 11.993019 , 4.9092307, 12.329353 , 10.28212 , 12.334496 , 21.102882 , 68.73028 , 28.93483 , 24.624996 , 22.925674 , 18.960793 , 17.155924 , 4.483561 , 9.103407 , 8.090643 , 3.4731753, 21.862223 , 3.5392842, 6.277596 , 4.3358054, 10.864571 , 1.2639813, 1.8327861, 5.8132896, 2.6647193, 10.151241 , 4.558977 , 11.674158 , 13.655367 , 16.772926 , 25.298958 , 1.9189749, 5.538587 , 3.6070976, 9.118575 , 4.755516 , 4.6680737, 3.8454957, 3.030088 , 3.3910158, 6.3802996, 3.3749137, 4.504877 , 10.809336 , 6.2159004, 15.667531 , 4.7564974, 8.262956 , 9.560621 , 9.199344 , 2.9713812, 3.7156253, 5.5498013, 11.150637 , 5.7294626, 4.002439 , 3.909546 , 2.221348 , ... 
10.571391 , 13.065151 , 6.8497195, 3.7558386, 3.5340605, 3.298055 , 2.0803087, 4.189632 , 7.7399063, 3.6138294, 3.9523766, 6.6232204, 6.2300615, 2.3853402, 2.850403 , 5.229487 , 3.8615527, 3.9326599, 3.8529232, 8.671819 , 6.301786 , 7.7922463, 5.2313375, 2.3083024, 65.67436 , 100.1539 , 48.767944 , 33.916534 , 32.376404 , 9.096143 , 9.358721 , 7.4668074, 17.171707 , 28.13654 , 131.20631 , 99.750786 , 82.21945 , 150.44986 , 4.315606 , 15.028996 , 19.251741 , 16.819944 , 54.600857 , 10.102056 , 13.521818 , 12.604932 , 16.998135 , 30.99062 , 8.554714 , 9.850973 , 9.87082 , 7.277955 , 5.83263 , 5.397276 , 4.8431597, 7.5538855, 3.3288703, 14.987564 , 7.951935 , 6.340927 , 11.855715 , 1.8256805, 5.054306 , 7.583917 , 7.169403 , 7.4230094, 6.0162477, 7.335036 , 3.146756 , 2.879428 , 7.331243 , 9.373186 , 6.303137 , 11.966533 , 7.999712 , 4.57536 , 4.0942593, 23.454334 , 4.570996 , 7.025208 , 3.1406336, 4.4695325, 8.906318 , 28.438923 , 6.1750894, 6.5125046, 3.4833727, 3.2517338, 2.7749972, 2.746849 , 9.603409 , 9.797531 , 9.178833 , 8.009731 , 13.549246 ]], dtype=float32)
array([[22.70183 , 23.928846 , 23.562433 , 16.289011 , 15.22005 , 15.044127 , 16.297873 , 17.127993 , 28.735075 , 16.449953 , 21.467855 , 19.445892 , 21.499855 , 22.8823 , 21.110508 , 23.649979 , 26.60918 , 30.828 , 25.248333 , 22.17687 , 21.898846 , 22.394032 , 19.247723 , 20.5953 , 19.776379 , 18.489138 , 18.343906 , 16.611282 , 17.876183 , 18.076878 , 20.283577 , 20.365068 , 17.932829 , 17.479145 , 16.480396 , 16.9544 , 17.29114 , 17.868351 , 20.70608 , 23.331121 , 21.984573 , 21.473406 , 25.425516 , 20.035864 , 24.10131 , 22.57912 , 24.053577 , 25.006847 , 25.977291 , 23.044413 , 23.399036 , 22.522333 , 22.15956 , 26.54244 , 22.261833 , 18.223854 , 20.073006 , 18.710468 , 23.895777 , 18.726654 , 17.243109 , 23.96953 , 17.760885 , 16.729784 , 16.708885 , 14.98461 , 19.943176 , 18.423456 , 20.433556 , 21.729021 , 21.142508 , 24.806545 , 26.408274 , 16.245544 , 15.434979 , 20.393446 , 19.199934 , 20.03922 , 18.835415 , 17.204237 , 18.52296 , 18.065552 , 20.298271 , 17.421255 , 17.687933 , 19.639505 , 19.606153 , 24.838633 , 18.426218 , 21.44172 , 22.273306 , 19.690224 , 16.428911 , 21.521833 , 21.86512 , 20.675076 , 21.54007 , 17.91468 , 17.8993 , 12.195638 , ... 
23.518085 , 20.518326 , 24.155817 , 21.604746 , 13.404511 , 13.920298 , 14.833247 , 15.817091 , 24.094957 , 17.81413 , 17.429472 , 16.82248 , 16.594954 , 19.392002 , 17.35203 , 17.853443 , 16.643085 , 19.361786 , 16.21468 , 19.727438 , 23.818949 , 22.251196 , 18.389307 , 17.08416 , 28.378626 , 25.069221 , 23.163519 , 25.550022 , 23.057487 , 24.160658 , 19.997173 , 22.610897 , 22.570648 , 20.691324 , 24.283613 , 22.753681 , 25.524794 , 26.249626 , 18.095999 , 22.020086 , 21.696339 , 26.190704 , 22.211027 , 25.576822 , 24.824036 , 22.590075 , 21.377657 , 25.288351 , 21.2128 , 19.17415 , 18.732403 , 18.372541 , 23.47746 , 18.848772 , 19.336374 , 19.807137 , 19.745995 , 25.764769 , 22.528057 , 24.722696 , 17.32529 , 20.22306 , 21.694635 , 18.270115 , 17.884037 , 17.52701 , 16.287584 , 17.014729 , 18.675406 , 17.37952 , 19.083565 , 17.537128 , 19.86175 , 17.674118 , 21.235544 , 24.079872 , 14.4767065, 21.373323 , 18.028996 , 22.986351 , 16.297935 , 18.35883 , 20.795593 , 22.46605 , 20.826462 , 19.824308 , 18.015593 , 15.6124115, 13.76765 , 14.794092 , 21.357248 , 22.982603 , 22.121471 , 23.893038 , 20.269114 ]], dtype=float32)
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 740, 741, 742, 743, 744, 745, 746, 747, 748, 749], dtype='int64', name='draw', length=750))
<xarray.Dataset> Dimensions: (chain: 1, draw: 750, y_dim_0: 63) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749 * y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 Data variables: y (chain, draw, y_dim_0) float32 -4.095 -4.083 ... -5.734 -4.017 Attributes: created_at: 2023-05-29T23:56:45.160203 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62])
array([[[-4.094567 , -4.082753 , -4.281054 , ..., -4.11488 , -5.4502263, -4.1267333], [-4.1526117, -4.1339474, -4.336127 , ..., -4.174086 , -5.3432612, -4.1656785], [-4.1596594, -4.117912 , -4.389084 , ..., -4.1904535, -5.235262 , -4.1275744], ..., [-4.0878344, -4.0446672, -4.3361053, ..., -4.1206446, -5.3121624, -4.05793 ], [-4.146661 , -4.1236835, -4.3355103, ..., -4.1694927, -5.312559 , -4.149831 ], [-3.9554722, -3.953834 , -4.1518583, ..., -3.974376 , -5.7338276, -4.017195 ]]], dtype=float32)
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 740, 741, 742, 743, 744, 745, 746, 747, 748, 749], dtype='int64', name='draw', length=750))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62], dtype='int64', name='y_dim_0'))
<xarray.Dataset> Dimensions: (chain: 1, draw: 750) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749 Data variables: diverging (chain, draw) bool False False False False ... False False False Attributes: created_at: 2023-05-29T23:56:45.029825 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, ... 
False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]])
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 740, 741, 742, 743, 744, 745, 746, 747, 748, 749], dtype='int64', name='draw', length=750))
<xarray.Dataset> Dimensions: (y_dim_0: 63) Coordinates: * y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 Data variables: y (y_dim_0) int32 102 107 92 101 110 68 119 ... 97 139 72 100 144 112 Attributes: created_at: 2023-05-29T23:56:45.161050 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62])
array([102, 107, 92, 101, 110, 68, 119, 106, 99, 103, 90, 93, 79, 89, 137, 119, 126, 110, 71, 114, 100, 95, 91, 99, 97, 106, 106, 129, 115, 124, 137, 73, 69, 95, 102, 116, 111, 134, 102, 110, 139, 112, 122, 84, 129, 112, 127, 106, 113, 109, 208, 114, 107, 50, 169, 133, 50, 97, 139, 72, 100, 144, 112], dtype=int32)
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62], dtype='int64', name='y_dim_0'))
fig = plots.plot_st_2(
mcmc, smart_group_data.Score.values,
mean_comp_val=100,
sigma_comp_val=15,
effsize_comp_val=0)
fig = numpyro_glm.plot_pairwise_scatter(mcmc, ['mean', 'sigma', 'nu'])
mcmc_key = random.PRNGKey(0)
kernel = NUTS(glm_metric.multi_groups_robust)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc.run(
mcmc_key,
jnp.array(iq_data['Score'].values),
jnp.array(iq_data['Group'].cat.codes.values),
len(iq_data['Group'].cat.categories),
)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat mean[0] 99.29 1.76 99.20 96.70 102.33 806.66 1.00 mean[1] 107.16 2.62 107.14 102.90 111.36 910.67 1.00 nu 3.86 1.68 3.45 1.80 5.73 403.33 1.01 sigma[0] 11.30 1.74 11.10 8.37 13.97 655.02 1.00 sigma[1] 17.91 2.72 17.71 13.80 22.56 534.41 1.00 Number of divergences: 0
numpyro_glm.plot_diagnostic(mcmc, ['mean', 'sigma', 'nu'])
<xarray.Dataset> Dimensions: (chain: 1, draw: 750, mean_dim_0: 2, sigma_dim_0: 2) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 743 744 745 746 747 748 749 * mean_dim_0 (mean_dim_0) int64 0 1 * sigma_dim_0 (sigma_dim_0) int64 0 1 Data variables: mean (chain, draw, mean_dim_0) float32 98.21 108.0 ... 94.3 103.5 nu (chain, draw) float32 4.904 4.618 2.711 ... 2.909 2.738 2.962 sigma (chain, draw, sigma_dim_0) float32 11.06 19.12 ... 10.23 17.3 Attributes: created_at: 2023-05-29T23:56:52.320366 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([0, 1])
array([0, 1])
array([[[ 98.21455 , 107.99677 ], [ 98.949814, 109.36058 ], [ 98.70745 , 102.67616 ], ..., [103.06653 , 106.81406 ], [ 95.39924 , 107.23023 ], [ 94.29891 , 103.466194]]], dtype=float32)
array([[ 4.90352 , 4.617958 , 2.7109103, 2.9921021, 3.342796 , 6.3293333, 3.3233557, 4.0153794, 3.9350548, 2.6246867, 7.292013 , 3.497366 , 2.326993 , 2.8841724, 2.2890396, 2.4631345, 4.2260084, 4.054314 , 4.1366997, 4.4986525, 2.7777195, 5.1717806, 2.021302 , 5.061619 , 5.94571 , 3.14477 , 2.4800138, 2.869721 , 2.4497128, 2.564416 , 2.9521792, 2.1315536, 2.0498023, 3.8954706, 2.0831642, 2.2922893, 3.3183112, 3.122521 , 3.0120344, 4.1892695, 2.5077238, 2.5695941, 2.2404354, 5.06931 , 3.70294 , 3.3386614, 2.5039854, 2.3118012, 3.1710594, 2.3974266, 2.5402591, 1.959218 , 3.2319102, 3.2446227, 5.831307 , 3.9336195, 3.8203804, 2.7046685, 3.9451866, 3.540849 , 3.7571568, 4.5666227, 3.1252599, 2.6628523, 5.4824295, 3.590702 , 3.329475 , 4.470755 , 3.7153351, 3.0417259, 7.0245333, 2.8254995, 4.068351 , 2.6363196, 6.1589856, 3.6985042, 4.0642147, 2.3869684, 3.4503007, 2.638822 , 2.5350208, 2.5853083, 2.5492935, 2.7262902, 5.2686357, 4.664855 , 2.334038 , 3.1042252, 2.6422722, 2.2850242, 2.8939147, 2.500326 , 4.2968397, 2.880906 , 3.437034 , 3.8430223, 6.010608 , 2.493542 , 3.2565322, 3.7440398, ... 
4.7892776, 4.238168 , 1.6762153, 3.2976294, 4.928863 , 2.3080652, 2.6339862, 4.841153 , 2.2997952, 1.9764836, 5.1290674, 2.9647245, 3.793073 , 3.8011038, 4.1641083, 4.1243477, 6.333407 , 2.8298962, 3.2586808, 3.278459 , 3.8523297, 3.7573292, 3.8873568, 3.3295789, 3.7441294, 3.8092866, 3.505789 , 3.3353996, 3.1752791, 2.874391 , 3.7779677, 2.5379913, 6.0346184, 3.2346685, 2.308152 , 3.2602334, 9.1525135, 8.748209 , 4.2054396, 6.5660415, 4.016606 , 7.6482983, 4.4520617, 4.8964925, 3.7809463, 1.9816743, 3.64076 , 5.9388037, 2.9451606, 3.6500442, 4.868866 , 4.1009865, 5.2459726, 2.457327 , 2.44139 , 7.370051 , 2.4880009, 3.4568632, 3.994024 , 4.073059 , 5.5238442, 4.672189 , 6.579589 , 8.251197 , 3.1003256, 7.0111647, 2.9936523, 3.945934 , 2.7069554, 7.7330813, 3.0613446, 3.1943877, 5.15743 , 2.1961014, 2.0084407, 4.7910404, 4.3722544, 4.7984753, 4.1131425, 2.9987164, 3.3225734, 2.3282595, 2.9577506, 4.8833666, 5.5551987, 7.0451884, 2.3839269, 2.004416 , 2.6664572, 3.5382771, 2.3100057, 2.0453377, 2.9086428, 2.7375379, 2.9617047]], dtype=float32)
array([[[11.057621 , 19.119974 ], [12.7135935, 21.039143 ], [10.781828 , 14.102758 ], ..., [10.174658 , 16.26035 ], [11.224569 , 15.996319 ], [10.232676 , 17.299318 ]]], dtype=float32)
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 740, 741, 742, 743, 744, 745, 746, 747, 748, 749], dtype='int64', name='draw', length=750))
PandasIndex(Int64Index([0, 1], dtype='int64', name='mean_dim_0'))
PandasIndex(Int64Index([0, 1], dtype='int64', name='sigma_dim_0'))
<xarray.Dataset> Dimensions: (chain: 1, draw: 750, y_dim_0: 120) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749 * y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 112 113 114 115 116 117 118 119 Data variables: y (chain, draw, y_dim_0) float32 -3.979 -3.922 ... -3.338 -5.223 Attributes: created_at: 2023-05-29T23:56:52.493138 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119])
array([[[-3.9789479, -3.9219544, -4.3142 , ..., -3.5036287, -3.5036287, -5.6269355], [-4.092538 , -4.0267043, -4.4053917, ..., -3.6455019, -3.6455019, -5.424239 ], [-3.6572642, -3.718939 , -4.011526 , ..., -3.5696988, -3.5696988, -5.58177 ], ..., [-3.8501332, -3.7921975, -4.282725 , ..., -3.8901894, -3.8901894, -6.1303153], [-3.8524227, -3.7809741, -4.3153687, ..., -3.457511 , -3.457511 , -5.203177 ], [-3.8573813, -3.8802965, -4.126558 , ..., -3.3382494, -3.3382494, -5.222832 ]]], dtype=float32)
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 740, 741, 742, 743, 744, 745, 746, 747, 748, 749], dtype='int64', name='draw', length=750))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 110, 111, 112, 113, 114, 115, 116, 117, 118, 119], dtype='int64', name='y_dim_0', length=120))
<xarray.Dataset> Dimensions: (chain: 1, draw: 750) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749 Data variables: diverging (chain, draw) bool False False False False ... False False False Attributes: created_at: 2023-05-29T23:56:52.322597 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 747, 748, 749])
array([[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, ... 
False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]])
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 740, 741, 742, 743, 744, 745, 746, 747, 748, 749], dtype='int64', name='draw', length=750))
<xarray.Dataset> Dimensions: (y_dim_0: 120) Coordinates: * y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 112 113 114 115 116 117 118 119 Data variables: y (y_dim_0) int32 102 107 92 101 110 68 119 ... 82 138 99 93 93 72 Attributes: created_at: 2023-05-29T23:56:52.493967 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119])
array([102, 107, 92, 101, 110, 68, 119, 106, 99, 103, 90, 93, 79, 89, 137, 119, 126, 110, 71, 114, 100, 95, 91, 99, 97, 106, 106, 129, 115, 124, 137, 73, 69, 95, 102, 116, 111, 134, 102, 110, 139, 112, 122, 84, 129, 112, 127, 106, 113, 109, 208, 114, 107, 50, 169, 133, 50, 97, 139, 72, 100, 144, 112, 109, 98, 106, 101, 100, 111, 117, 104, 106, 89, 84, 88, 94, 78, 108, 102, 95, 99, 90, 116, 97, 107, 102, 91, 94, 95, 86, 108, 115, 108, 88, 102, 102, 120, 112, 100, 105, 105, 88, 82, 111, 96, 92, 109, 91, 92, 123, 61, 59, 105, 184, 82, 138, 99, 93, 93, 72], dtype=int32)
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 110, 111, 112, 113, 114, 115, 116, 117, 118, 119], dtype='int64', name='y_dim_0', length=120))
Plot the resulting posteriors.
# Draw one posterior figure per group. The integer index picks that group's
# entry out of the vector-valued `mean` and `sigma` posterior variables
# (0 -> Placebo, 1 -> Smart Drug).
for group_name, group_idx, group_title in (
        ('Placebo', 0, 'Placebo Posteriors'),
        ('Smart Drug', 1, 'Smart Drug Posteriors')):
    group_scores = iq_data.loc[iq_data['Group'] == group_name, 'Score'].values
    fig = plots.plot_st_2(
        mcmc,
        group_scores,
        mean_coords={'mean_dim_0': group_idx},
        sigma_coords={'sigma_dim_0': group_idx},
        figtitle=group_title)
Next, we plot the posterior differences between the two groups: the difference of means, the difference of scales, and the effect size.
# Compare the Smart Drug group (index 1) against the Placebo group (index 0)
# using the joint posterior draws, so each derived quantity below is itself a
# full posterior distribution rather than a point estimate.
idata = az.from_numpyro(mcmc)
posteriors = idata.posterior

mean_difference = (posteriors['mean'].sel(mean_dim_0=1)
                   - posteriors['mean'].sel(mean_dim_0=0))
sigma_difference = (posteriors['sigma'].sel(sigma_dim_0=1)
                    - posteriors['sigma'].sel(sigma_dim_0=0))
sigmas_squared = (posteriors['sigma'].sel(sigma_dim_0=1)**2
                  + posteriors['sigma'].sel(sigma_dim_0=0)**2)
# Kruschke's effect size standardizes the mean difference by the root of the
# *average* of the two variances: (mu1 - mu0) / sqrt((s1^2 + s0^2) / 2).
# Omitting the "/ 2" would shrink every effect size by a factor of sqrt(2).
effect_size = mean_difference / np.sqrt(sigmas_squared / 2)

# One panel per quantity: (posterior draws, ROPE, title, x-axis label).
# Raw strings keep the LaTeX backslashes (\mu, \sigma, \sqrt) from being
# parsed as (invalid) Python escape sequences.
panels = [
    (mean_difference, (-0.5, 0.5), 'Difference of Means',
     r'$\mu[1] - \mu[0]$'),
    (sigma_difference, (-0.5, 0.5), 'Difference of Scales',
     r'$\sigma[1] - \sigma[0]$'),
    (effect_size, (-0.1, 0.1), 'Effect Size',
     r'$(\mu[1] - \mu[0]) / \sqrt{(\sigma[1]^2 + \sigma[0]^2) / 2}$'),
]

fig, axes = plt.subplots(ncols=3, figsize=(18, 6))
fig.suptitle('Difference between Smart Drug and Placebo')
for ax, (values, rope, title, xlabel) in zip(axes, panels):
    # 95% HDI, reference line at zero (no difference), ROPE around zero,
    # and the posterior mode as the point estimate.
    az.plot_posterior(
        values,
        hdi_prob=0.95,
        ref_val=0,
        rope=rope,
        point_estimate='mode',
        kind='hist',
        ax=ax)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
fig.tight_layout()