%cd ..
%load_ext autoreload
%autoreload 2
/home/runner/work/numpyro-doing-bayesian/numpyro-doing-bayesian
import arviz as az
import jax.numpy as jnp
import jax.random as random
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from numpyro.infer import MCMC, NUTS
import numpyro_glm
import numpyro_glm.metric.models as glm_metric
# Simulate y = intercept + slope * x + Normal(0, noise_sd) noise, then show the
# data, the true regression line, and the normal density of y at a few x values.
true_intercept, true_slope, noise_sd = 10, 2, 2
x = np.random.uniform(-10, 10, size=500)
y = np.random.normal(true_intercept + true_slope * x, noise_sd)
fig, ax = plt.subplots()
ax.scatter(x, y, c='black', s=4)
ax.set_title('Normal PDF around linear function')
xline = np.linspace(-10, 10, 1000)
ax.plot(xline, true_intercept + true_slope * xline, lw=4, c='#87ceeb')
# At each chosen x, draw the normal density of y sideways. The y-grid spans
# +/- 3 SD around the line's value at that x (not around x itself), and the
# curve is scaled so its peak sits `pdf_width` x-units to the right of x.
pdf_width = 2.0
for xinterval in [-7.5, -2.5, 2.5, 7.5]:
    y_mean = true_intercept + true_slope * xinterval
    y_ = np.linspace(y_mean - 3 * noise_sd, y_mean + 3 * noise_sd, 1000)
    # Unnormalized Gaussian shape (peak ~1); only its shape matters for the sketch.
    density = np.exp(-0.5 * ((y_ - y_mean) / noise_sd) ** 2)
    ax.plot([xinterval, xinterval], [y_[0], y_[-1]], c='b', lw=0.75)
    ax.plot(xinterval + pdf_width * density, y_, c='b', lw=1.5)
# Load the 30-person height/weight sample and show its summary statistics.
height_weight_csv = 'datasets/HtWtData30.csv'
height_weight_30_data = pd.read_csv(height_weight_csv)
height_weight_30_data.describe()
male | height | weight | |
---|---|---|---|
count | 30.000000 | 30.000000 | 30.000000 |
mean | 0.466667 | 66.983333 | 154.013333 |
std | 0.507416 | 4.201074 | 31.702570 |
min | 0.000000 | 57.500000 | 96.500000 |
25% | 0.000000 | 64.100000 | 128.850000 |
50% | 0.000000 | 66.450000 | 147.900000 |
75% | 1.000000 | 69.575000 | 179.650000 |
max | 1.000000 | 76.000000 | 215.100000 |
First, we fit the model to the raw data (no standardization applied to either y or x).
# Fit the robust single-predictor regression on raw (unstandardized) data.
# NOTE(review): the height column is passed first, and the recorded observed
# `y` values are heights, so this regresses height on weight -- confirm that
# this direction is intended.
observed_height = jnp.array(height_weight_30_data.height.values)
observed_weight = jnp.array(height_weight_30_data.weight.values)
mcmc_key = random.PRNGKey(0)
kernel = NUTS(model=glm_metric.one_metric_predictor_robust_no_standardization)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=20000)
mcmc.run(mcmc_key, observed_height, observed_weight)
mcmc.print_summary()
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
mean std median 5.0% 95.0% n_eff r_hat b0 48.72 3.84 48.95 42.70 55.15 6216.35 1.00 b1 0.12 0.02 0.12 0.08 0.16 6252.35 1.00 nu 23.76 25.95 14.17 1.11 57.71 8596.60 1.00 sigma 3.57 0.75 3.52 2.30 4.71 6426.80 1.00 Number of divergences: 0
# Diagnostic plots for the raw-scale regression parameters.
raw_scale_params = ['b0', 'b1', 'nu', 'sigma']
numpyro_glm.plot_diagnostic(mcmc, raw_scale_params)
<xarray.Dataset> Dimensions: (chain: 1, draw: 20000, mean_dim_0: 30) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 ... 19994 19995 19996 19997 19998 19999 * mean_dim_0 (mean_dim_0) int64 0 1 2 3 4 5 6 7 8 ... 22 23 24 25 26 27 28 29 Data variables: b0 (chain, draw) float32 41.41 43.14 43.32 ... 48.34 45.5 50.65 b1 (chain, draw) float32 0.1631 0.148 0.1502 ... 0.1378 0.1047 mean (chain, draw, mean_dim_0) float32 63.66 76.49 ... 69.36 62.29 nu (chain, draw) float32 50.76 97.48 102.5 ... 2.813 11.31 3.153 sigma (chain, draw) float32 4.595 4.4 4.371 ... 2.978 2.951 3.091 Attributes: created_at: 2023-05-29T23:52:41.500946 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 19997, 19998, 19999])
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
array([[41.410774, 43.14378 , 43.31898 , ..., 48.3416 , 45.503754, 50.65438 ]], dtype=float32)
array([[0.1631024 , 0.14796117, 0.1502028 , ..., 0.12550929, 0.1377722 , 0.104709 ]], dtype=float32)
array([[[63.65794 , 76.4941 , 69.72535 , ..., 65.533615, 70.54086 , 59.531452], [63.325684, 74.97023 , 68.82984 , ..., 65.02724 , 69.56965 , 59.582264], [63.80664 , 75.6276 , 69.39419 , ..., 65.53397 , 70.1452 , 60.00651 ], ..., [65.46107 , 75.338646, 70.13001 , ..., 66.90442 , 70.75756 , 62.28568 ], [64.29588 , 75.13856 , 69.421005, ..., 65.880264, 70.10987 , 60.810246], [64.93669 , 73.177284, 68.83186 , ..., 66.14084 , 69.35541 , 62.28755 ]]], dtype=float32)
array([[ 50.764534 , 97.48212 , 102.51505 , ..., 2.8131351, 11.306918 , 3.1532779]], dtype=float32)
array([[4.594737 , 4.39954 , 4.37117 , ..., 2.9783728, 2.950543 , 3.0909798]], dtype=float32)
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 19990, 19991, 19992, 19993, 19994, 19995, 19996, 19997, 19998, 19999], dtype='int64', name='draw', length=20000))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype='int64', name='mean_dim_0'))
<xarray.Dataset> Dimensions: (chain: 1, draw: 20000, y_dim_0: 30) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 ... 19994 19995 19996 19997 19998 19999 * y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 21 22 23 24 25 26 27 28 29 Data variables: y (chain, draw, y_dim_0) float32 -2.452 -6.907 ... -2.201 -3.235 Attributes: created_at: 2023-05-29T23:52:41.687343 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 19997, 19998, 19999])
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
array([[[-2.451594 , -6.9073153, -2.5291097, ..., -2.4494371, -2.5697544, -3.7278814], [-2.4148343, -6.423703 , -2.4255276, ..., -2.4147794, -2.4450195, -3.7809734], [-2.3973503, -6.88908 , -2.4553216, ..., -2.3970907, -2.48625 , -3.6370234], ..., [-2.2540216, -6.0169435, -2.4439926, ..., -2.2052567, -2.5107615, -3.2737486], [-2.028479 , -8.076545 , -2.1659567, ..., -2.02504 , -2.2244449, -3.9904213], [-2.1851215, -5.4372835, -2.184517 , ..., -2.13886 , -2.200901 , -3.2349527]]], dtype=float32)
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 19990, 19991, 19992, 19993, 19994, 19995, 19996, 19997, 19998, 19999], dtype='int64', name='draw', length=20000))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype='int64', name='y_dim_0'))
<xarray.Dataset> Dimensions: (chain: 1, draw: 20000) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 ... 19994 19995 19996 19997 19998 19999 Data variables: diverging (chain, draw) bool False False False False ... False False False Attributes: created_at: 2023-05-29T23:52:41.513411 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 19997, 19998, 19999])
array([[False, False, False, ..., False, False, False]])
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 19990, 19991, 19992, 19993, 19994, 19995, 19996, 19997, 19998, 19999], dtype='int64', name='draw', length=20000))
<xarray.Dataset> Dimensions: (y_dim_0: 30) Coordinates: * y_dim_0 (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 21 22 23 24 25 26 27 28 29 Data variables: y (y_dim_0) float32 64.0 62.3 67.9 64.2 64.8 ... 66.4 65.7 68.3 66.9 Attributes: created_at: 2023-05-29T23:52:41.688187 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
array([64. , 62.3, 67.9, 64.2, 64.8, 57.5, 65.6, 70.2, 63.9, 71.1, 66.5, 68.1, 62.9, 75.1, 64.6, 69.2, 68.1, 72.6, 63.2, 64.1, 64.1, 71.5, 76. , 69.7, 73.3, 61.7, 66.4, 65.7, 68.3, 66.9], dtype=float32)
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype='int64', name='y_dim_0'))
Next, we fit the same model with both y and x standardized.
# Fit the robust single-predictor regression that works on standardized data
# (its summary reports z-scaled parameters zb0, zb1, zsigma).
observed_height = jnp.array(height_weight_30_data.height.values)
observed_weight = jnp.array(height_weight_30_data.weight.values)
mcmc_key = random.PRNGKey(0)
kernel = NUTS(model=glm_metric.one_metric_predictor_robust)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=20000)
mcmc.run(mcmc_key, observed_height, observed_weight)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat nu 28.91 28.39 19.63 1.26 65.21 15679.32 1.00 zb0 0.03 0.17 0.03 -0.25 0.29 19537.09 1.00 zb1 0.58 0.18 0.58 0.27 0.85 18687.78 1.00 zsigma 0.84 0.15 0.83 0.60 1.08 12668.37 1.00 Number of divergences: 0
# Diagnostic plots for the standardized-scale parameters (plus intercept b0).
standardized_params = ['zb0', 'zb1', 'nu', 'zsigma', 'b0']
numpyro_glm.plot_diagnostic(mcmc, standardized_params)
<xarray.Dataset> Dimensions: (chain: 1, draw: 20000, zmean_dim_0: 30) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 ... 19995 19996 19997 19998 19999 * zmean_dim_0 (zmean_dim_0) int64 0 1 2 3 4 5 6 7 ... 22 23 24 25 26 27 28 29 Data variables: b0 (chain, draw) float32 54.14 59.34 55.78 ... 53.82 57.74 53.81 b1 (chain, draw) float32 0.08307 0.05473 ... 0.0589 0.08952 nu (chain, draw) float32 8.131 44.38 25.45 ... 7.959 19.05 13.96 sigma (chain, draw) float32 2.563 3.453 3.461 ... 2.761 3.062 3.668 zb0 (chain, draw) float32 -0.0121 0.1893 ... -0.04201 0.1493 zb1 (chain, draw) float32 0.6268 0.413 0.5652 ... 0.4445 0.6756 zmean (chain, draw, zmean_dim_0) float32 -0.3663 1.216 ... -0.7808 zsigma (chain, draw) float32 0.6205 0.8361 0.838 ... 0.7412 0.8879 Attributes: created_at: 2023-05-29T23:52:45.902172 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 19997, 19998, 19999])
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
array([[54.13991 , 59.33543 , 55.783638, ..., 53.81867 , 57.738403, 53.81255 ]], dtype=float32)
array([[0.08306717, 0.0547338 , 0.07489246, ..., 0.08649762, 0.05890025, 0.08952103]], dtype=float32)
array([[ 8.131379, 44.376976, 25.454565, ..., 7.959042, 19.045025, 13.961558]], dtype=float32)
array([[2.5627956, 3.4533787, 3.4614177, ..., 2.7611022, 3.0616624, 3.6676064]], dtype=float32)
array([[-0.01209797, 0.18928441, 0.08104113, ..., 0.03803969, -0.04200768, 0.1492924 ]], dtype=float32)
array([[0.6268499 , 0.41303775, 0.56516105, ..., 0.6527371 , 0.44447905, 0.67555267]], dtype=float32)
array([[[-0.36631748, 1.2164077 , 0.38180685, ..., -0.13504256, 0.4823612 , -0.8751222 ], [-0.04411441, 0.9987592 , 0.4488323 , ..., 0.10827498, 0.51508856, -0.37937102], [-0.23831932, 1.1886485 , 0.43618146, ..., -0.02980437, 0.52684015, -0.6970521 ], ..., [-0.3308081 , 1.3172792 , 0.44821173, ..., -0.08998217, 0.5529187 , -0.8606251 ], [-0.2931733 , 0.82908607, 0.23729753, ..., -0.12918372, 0.30859736, -0.65395033], [-0.23244801, 1.4732461 , 0.57380146, ..., 0.01679569, 0.6821683 , -0.7807841 ]]], dtype=float32)
array([[0.6204621 , 0.8360755 , 0.83802176, ..., 0.6684729 , 0.7412396 , 0.8879408 ]], dtype=float32)
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 19990, 19991, 19992, 19993, 19994, 19995, 19996, 19997, 19998, 19999], dtype='int64', name='draw', length=20000))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype='int64', name='zmean_dim_0'))
<xarray.Dataset> Dimensions: (chain: 1, draw: 20000, yz_dim_0: 30) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 ... 19994 19995 19996 19997 19998 19999 * yz_dim_0 (yz_dim_0) int64 0 1 2 3 4 5 6 7 8 ... 21 22 23 24 25 26 27 28 29 Data variables: yz (chain, draw, yz_dim_0) float32 -0.6535 -5.115 ... -0.9072 -1.201 Attributes: created_at: 2023-05-29T23:52:46.094067 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 19997, 19998, 19999])
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
array([[[-0.65347993, -5.115089 , -0.50944793, ..., -0.51710135, -0.511185 , -1.4304454 ], [-1.0794393 , -3.8496172 , -0.7831586 , ..., -0.87356246, -0.77370465, -0.83970374], [-0.92422324, -4.240124 , -0.78596765, ..., -0.8102995 , -0.7840403 , -1.0867834 ], ..., [-0.73648375, -4.9790025 , -0.6115416 , ..., -0.60845196, -0.61603576, -1.3590188 ], [-0.8074536 , -3.7748375 , -0.6328528 , ..., -0.6641354 , -0.63272566, -1.0101484 ], [-0.9792839 , -4.4151945 , -0.9016542 , ..., -0.8905166 , -0.90719205, -1.2011554 ]]], dtype=float32)
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 19990, 19991, 19992, 19993, 19994, 19995, 19996, 19997, 19998, 19999], dtype='int64', name='draw', length=20000))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype='int64', name='yz_dim_0'))
<xarray.Dataset> Dimensions: (chain: 1, draw: 20000) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 ... 19994 19995 19996 19997 19998 19999 Data variables: diverging (chain, draw) bool False False False False ... False False False Attributes: created_at: 2023-05-29T23:52:45.906057 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([0])
array([ 0, 1, 2, ..., 19997, 19998, 19999])
array([[False, False, False, ..., False, False, False]])
PandasIndex(Int64Index([0], dtype='int64', name='chain'))
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 19990, 19991, 19992, 19993, 19994, 19995, 19996, 19997, 19998, 19999], dtype='int64', name='draw', length=20000))
<xarray.Dataset> Dimensions: (yz_dim_0: 30) Coordinates: * yz_dim_0 (yz_dim_0) int64 0 1 2 3 4 5 6 7 8 ... 21 22 23 24 25 26 27 28 29 Data variables: yz (yz_dim_0) float32 -0.7223 -1.134 0.2219 ... 0.3188 -0.02018 Attributes: created_at: 2023-05-29T23:52:46.094933 arviz_version: 0.12.1 inference_library: numpyro inference_library_version: 0.9.2
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
array([-0.7222768 , -1.1338532 , 0.22192772, -0.67385685, -0.5285932 , -2.2959504 , -0.33491138, 0.77876496, -0.7464868 , 0.99665856, -0.11701774, 0.27034768, -0.9885904 , 1.9650731 , -0.577015 , 0.5366613 , 0.27034768, 1.359814 , -0.91595954, -0.69806683, -0.69806683, 1.0935004 , 2.1829667 , 0.6577131 , 1.5292877 , -1.279115 , -0.14122774, -0.31070137, 0.31876954, -0.02017592], dtype=float32)
PandasIndex(Int64Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype='int64', name='yz_dim_0'))
# Load the hierarchical regression data; subject IDs become a categorical
# column so their integer codes can index per-subject parameters.
hier_linear_reg_data = pd.read_csv(
    'datasets/HierLinRegressData.csv'
).astype({'Subj': 'category'})
hier_linear_reg_data.describe()
X | Y | |
---|---|---|
count | 132.000000 | 132.000000 |
mean | 67.084848 | 153.328788 |
std | 7.146830 | 32.041570 |
min | 49.100000 | 70.000000 |
25% | 62.500000 | 131.300000 |
50% | 66.250000 | 153.600000 |
75% | 71.400000 | 172.550000 |
max | 86.200000 | 240.400000 |
# Fit the hierarchical robust regression: Y on X with per-subject coefficients.
subject_cat = hier_linear_reg_data.Subj.cat
model_args = (
    jnp.array(hier_linear_reg_data.Y.values),
    jnp.array(hier_linear_reg_data.X.values),
    jnp.array(subject_cat.codes.values),
    subject_cat.categories.size,
)
mcmc_key = random.PRNGKey(0)
kernel = NUTS(model=glm_metric.hierarchical_one_metric_predictor_multi_groups_robust)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=20000)
mcmc.run(mcmc_key, *model_args)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat nu 39.88 31.29 30.97 3.52 82.04 25047.26 1.00 zb0_[0] 0.66 0.34 0.66 0.12 1.24 5833.33 1.00 zb0_[1] 1.07 0.36 1.05 0.47 1.65 5239.55 1.00 zb0_[2] -0.70 0.36 -0.70 -1.30 -0.13 5389.81 1.00 zb0_[3] 0.64 0.44 0.64 -0.08 1.36 10989.48 1.00 zb0_[4] -0.73 0.38 -0.73 -1.32 -0.09 6746.15 1.00 zb0_[5] -2.17 0.49 -2.14 -2.98 -1.39 5626.90 1.00 zb0_[6] -0.14 0.41 -0.14 -0.82 0.52 8740.15 1.00 zb0_[7] -0.21 0.32 -0.20 -0.75 0.31 6024.53 1.00 zb0_[8] -0.56 0.34 -0.56 -1.08 0.01 5614.86 1.00 zb0_[9] -1.29 0.39 -1.27 -1.92 -0.65 4831.80 1.00 zb0_[10] 0.72 0.34 0.71 0.18 1.30 5819.55 1.00 zb0_[11] -0.46 0.36 -0.45 -1.05 0.14 7087.15 1.00 zb0_[12] 1.32 0.43 1.30 0.64 2.05 7005.42 1.00 zb0_[13] 1.01 0.36 1.00 0.42 1.61 5923.62 1.00 zb0_[14] 0.18 0.44 0.18 -0.55 0.90 10327.46 1.00 zb0_[15] 0.32 0.34 0.32 -0.28 0.83 6205.49 1.00 zb0_[16] -0.05 0.29 -0.05 -0.53 0.41 4622.13 1.00 zb0_[17] -0.83 0.32 -0.82 -1.34 -0.30 4492.31 1.00 zb0_[18] 0.25 0.48 0.25 -0.50 1.06 11627.68 1.00 zb0_[19] 0.73 0.36 0.72 0.12 1.30 6564.84 1.00 zb0_[20] 1.48 0.46 1.46 0.72 2.23 7476.43 1.00 zb0_[21] -0.71 0.33 -0.70 -1.24 -0.18 4987.13 1.00 zb0_[22] -0.51 0.31 -0.50 -1.02 0.02 4961.07 1.00 zb0_[23] 0.96 0.69 0.98 -0.22 2.06 15145.85 1.00 zb0_[24] -1.08 0.70 -1.11 -2.22 0.09 14422.36 1.00 zb0_mean 0.08 0.23 0.08 -0.30 0.45 2846.28 1.00 zb0_std 1.06 0.20 1.03 0.73 1.35 4242.65 1.00 zb1_[0] 0.16 0.91 0.16 -1.38 1.61 24686.58 1.00 zb1_[1] -0.62 0.88 -0.65 -2.08 0.83 19244.29 1.00 zb1_[2] -0.81 0.84 -0.84 -2.20 0.55 13748.94 1.00 zb1_[3] 0.03 0.99 0.04 -1.63 1.63 24966.31 1.00 zb1_[4] -0.30 0.84 -0.31 -1.70 1.07 23699.34 1.00 zb1_[5] 0.82 0.88 0.85 -0.66 2.25 13349.34 1.00 zb1_[6] -0.03 0.92 -0.04 -1.54 1.49 25261.45 1.00 zb1_[7] 0.31 0.87 0.33 -1.10 1.77 24399.64 1.00 zb1_[8] -0.14 0.87 -0.14 -1.59 1.26 25134.93 1.00 zb1_[9] -0.71 0.83 -0.74 -2.08 0.61 14346.51 1.00 zb1_[10] 0.20 0.95 0.22 -1.36 1.75 26909.07 1.00 zb1_[11] 0.05 0.91 0.05 -1.46 1.52 29476.07 1.00 zb1_[12] 
0.15 0.92 0.16 -1.44 1.62 20162.71 1.00 zb1_[13] 0.03 0.80 0.03 -1.27 1.35 23669.26 1.00 zb1_[14] -0.14 0.97 -0.14 -1.72 1.47 29731.50 1.00 zb1_[15] -0.05 0.86 -0.06 -1.48 1.36 22293.17 1.00 zb1_[16] -0.12 0.82 -0.13 -1.46 1.21 24933.42 1.00 zb1_[17] 0.49 0.82 0.50 -0.82 1.86 26560.87 1.00 zb1_[18] 0.15 0.99 0.15 -1.48 1.77 29347.23 1.00 zb1_[19] 0.11 0.96 0.11 -1.44 1.67 27526.88 1.00 zb1_[20] 0.45 0.88 0.47 -0.96 1.91 15378.57 1.00 zb1_[21] 0.43 0.88 0.44 -1.05 1.84 23065.02 1.00 zb1_[22] 0.63 0.81 0.65 -0.68 1.98 16946.28 1.00 zb1_[23] -0.48 0.93 -0.49 -2.03 1.04 18804.50 1.00 zb1_[24] -0.57 0.93 -0.60 -2.14 0.90 15822.44 1.00 zb1_mean 0.70 0.10 0.70 0.54 0.88 10261.21 1.00 zb1_std 0.23 0.12 0.22 0.00 0.38 4588.93 1.00 zsigma 0.59 0.05 0.59 0.51 0.68 9970.60 1.00 Number of divergences: 0
# Load median income by family size and state (the first CSV line is skipped).
# State is categorical so its codes can index per-state parameters.
income_data_3yr = pd.read_csv(
    'datasets/IncomeFamszState3yr.csv',
    skiprows=1,
).astype({'State': 'category'})
income_data_3yr.describe()
FamilySize | MedianIncome | SampErr | |
---|---|---|---|
count | 312.000000 | 312.000000 | 312.000000 |
mean | 4.500000 | 66824.772436 | 2590.810897 |
std | 1.710569 | 14774.402348 | 2410.533770 |
min | 2.000000 | 18860.000000 | 240.000000 |
25% | 3.000000 | 57346.500000 | 1001.500000 |
50% | 4.500000 | 65097.000000 | 1805.000000 |
75% | 6.000000 | 73992.750000 | 3297.750000 |
max | 7.000000 | 124167.000000 | 15121.000000 |
# One connected line per state: median income as a function of family size.
fig, ax = plt.subplots()
for state_name in income_data_3yr['State'].unique():
    rows = income_data_3yr.loc[income_data_3yr['State'] == state_name]
    rows = rows.sort_values('FamilySize')
    ax.plot(rows['FamilySize'], rows['MedianIncome'], 'o-')
ax.set(
    title='Median Income of Various States',
    xlabel='Family Size',
    ylabel='Median Income',
)
fig.tight_layout()
# Fit the hierarchical quadratic-trend robust model: median income vs family
# size, with per-state coefficients and the reported sampling error per row.
state_cat = income_data_3yr['State'].cat
model_args = (
    jnp.array(income_data_3yr['MedianIncome'].values),
    jnp.array(income_data_3yr['FamilySize'].values),
    jnp.array(state_cat.codes.values),
    state_cat.categories.size,
    jnp.array(income_data_3yr['SampErr'].values),
)
mcmc_key = random.PRNGKey(0)
kernel = NUTS(
    model=glm_metric.hierarchical_quadtrend_one_metric_predictor_multi_groups_robust)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000)
mcmc.run(mcmc_key, *model_args)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat b0_z_[0] -0.79 0.25 -0.80 -1.17 -0.37 1450.18 1.00 b0_z_[1] 1.15 0.28 1.15 0.69 1.60 1368.14 1.00 b0_z_[2] -0.75 0.22 -0.75 -1.11 -0.39 1152.45 1.00 b0_z_[3] -1.19 0.23 -1.19 -1.57 -0.83 1141.32 1.00 b0_z_[4] -0.08 0.27 -0.11 -0.52 0.41 1429.61 1.00 b0_z_[5] 0.31 0.24 0.29 -0.10 0.69 1280.56 1.00 b0_z_[6] 1.93 0.34 1.92 1.39 2.51 1360.38 1.00 b0_z_[7] 0.70 0.27 0.70 0.24 1.13 1432.02 1.00 b0_z_[8] 0.22 0.45 0.21 -0.51 0.95 4897.73 1.00 b0_z_[9] -0.72 0.22 -0.72 -1.08 -0.37 1135.64 1.00 b0_z_[10] -0.68 0.23 -0.69 -1.04 -0.30 1229.01 1.00 b0_z_[11] 1.06 0.25 1.05 0.65 1.45 1091.51 1.00 b0_z_[12] -0.79 0.22 -0.79 -1.17 -0.43 1265.59 1.00 b0_z_[13] 0.32 0.18 0.32 0.02 0.61 791.32 1.00 b0_z_[14] -0.32 0.20 -0.32 -0.66 0.01 1022.26 1.00 b0_z_[15] 0.05 0.20 0.05 -0.26 0.38 1068.64 1.00 b0_z_[16] 0.05 0.19 0.05 -0.27 0.36 929.39 1.00 b0_z_[17] -0.72 0.22 -0.72 -1.08 -0.36 1140.51 1.00 b0_z_[18] -0.46 0.23 -0.46 -0.83 -0.09 1262.86 1.00 b0_z_[19] -0.02 0.24 -0.02 -0.41 0.39 1533.15 1.00 b0_z_[20] 1.86 0.28 1.85 1.38 2.31 1036.08 1.00 b0_z_[21] 2.01 0.35 2.01 1.44 2.61 1632.92 1.00 b0_z_[22] -0.19 0.19 -0.20 -0.50 0.13 954.04 1.00 b0_z_[23] 0.88 0.21 0.88 0.53 1.21 858.27 1.00 b0_z_[24] -1.43 0.26 -1.42 -1.87 -1.00 1397.35 1.00 b0_z_[25] -0.19 0.20 -0.18 -0.52 0.14 1017.89 1.00 b0_z_[26] -0.42 0.26 -0.42 -0.82 0.02 1843.30 1.00 b0_z_[27] 0.16 0.19 0.16 -0.15 0.47 939.60 1.00 b0_z_[28] -0.33 0.25 -0.33 -0.74 0.07 1341.16 1.00 b0_z_[29] 1.52 0.31 1.52 1.04 2.04 1387.94 1.00 b0_z_[30] 2.18 0.31 2.18 1.66 2.67 1082.44 1.00 b0_z_[31] -1.10 0.24 -1.10 -1.49 -0.73 1347.06 1.00 b0_z_[32] 0.59 0.20 0.58 0.27 0.92 868.09 1.00 b0_z_[33] -0.80 0.22 -0.81 -1.16 -0.45 1061.66 1.00 b0_z_[34] 0.46 0.25 0.46 0.08 0.89 1313.16 1.00 b0_z_[35] -0.20 0.20 -0.21 -0.51 0.13 941.25 1.00 b0_z_[36] -0.86 0.22 -0.86 -1.22 -0.50 1158.80 1.00 b0_z_[37] -0.41 0.20 -0.41 -0.73 -0.08 1037.83 1.00 b0_z_[38] 0.50 0.18 0.50 0.22 0.81 715.06 1.00 b0_z_[39] -3.22 0.37 
-3.21 -3.83 -2.63 1134.21 1.00 b0_z_[40] 0.72 0.30 0.72 0.23 1.20 1908.49 1.00 b0_z_[41] -0.83 0.23 -0.83 -1.23 -0.46 1199.08 1.00 b0_z_[42] -0.08 0.21 -0.08 -0.45 0.25 1081.10 1.00 b0_z_[43] -0.77 0.21 -0.77 -1.13 -0.44 1116.83 1.00 b0_z_[44] -0.77 0.20 -0.78 -1.10 -0.45 942.33 1.00 b0_z_[45] -0.10 0.18 -0.09 -0.38 0.19 791.24 1.00 b0_z_[46] 0.30 0.26 0.30 -0.10 0.74 1446.57 1.00 b0_z_[47] 0.97 0.23 0.96 0.61 1.34 929.54 1.00 b0_z_[48] 0.34 0.24 0.33 -0.03 0.76 1479.49 1.00 b0_z_[49] -0.69 0.22 -0.69 -1.05 -0.32 1336.65 1.00 b0_z_[50] 0.23 0.20 0.23 -0.09 0.56 965.97 1.00 b0_z_[51] 0.36 0.24 0.36 -0.01 0.77 1435.54 1.00 b0_z_mean 0.35 0.14 0.35 0.13 0.58 545.08 1.00 b0_z_std 0.93 0.10 0.92 0.77 1.08 1199.85 1.00 b1_z_[0] -0.68 0.59 -0.70 -1.61 0.34 8455.76 1.00 b1_z_[1] -0.26 0.72 -0.26 -1.48 0.87 11167.06 1.00 b1_z_[2] -1.22 0.58 -1.22 -2.18 -0.28 8356.65 1.00 b1_z_[3] -0.40 0.59 -0.40 -1.33 0.59 8836.51 1.00 b1_z_[4] -0.87 0.79 -0.93 -2.05 0.38 2606.24 1.00 b1_z_[5] -0.99 0.60 -1.01 -2.01 -0.06 7509.17 1.00 b1_z_[6] 1.34 0.82 1.37 -0.03 2.66 6927.01 1.00 b1_z_[7] 0.48 0.83 0.50 -0.86 1.83 9103.59 1.00 b1_z_[8] -0.50 0.94 -0.52 -2.07 1.00 13566.42 1.00 b1_z_[9] -0.18 0.55 -0.16 -1.07 0.73 6259.88 1.00 b1_z_[10] -0.53 0.56 -0.52 -1.45 0.36 8877.50 1.00 b1_z_[11] 1.05 0.88 1.05 -0.35 2.51 7673.05 1.00 b1_z_[12] -0.09 0.68 -0.09 -1.16 1.05 9239.28 1.00 b1_z_[13] -0.26 0.49 -0.25 -1.04 0.58 6958.81 1.00 b1_z_[14] 0.28 0.52 0.29 -0.56 1.14 6029.79 1.00 b1_z_[15] 0.01 0.53 0.01 -0.85 0.88 7031.98 1.00 b1_z_[16] -0.54 0.60 -0.54 -1.54 0.41 9005.49 1.00 b1_z_[17] 0.07 0.61 0.07 -0.89 1.11 6967.88 1.00 b1_z_[18] 0.07 0.61 0.06 -0.95 1.04 8288.69 1.00 b1_z_[19] 0.38 0.74 0.39 -0.83 1.59 10146.34 1.00 b1_z_[20] 0.98 0.63 0.99 -0.06 2.00 7663.19 1.00 b1_z_[21] 1.82 0.76 1.88 0.63 3.05 4077.76 1.00 b1_z_[22] -0.47 0.47 -0.49 -1.22 0.32 4714.90 1.00 b1_z_[23] -0.26 0.56 -0.29 -1.08 0.74 5557.14 1.00 b1_z_[24] -0.71 0.64 -0.72 -1.74 0.36 9272.17 1.00 b1_z_[25] 0.24 0.52 0.24 
-0.62 1.09 6501.39 1.00 b1_z_[26] 0.13 0.78 0.15 -1.20 1.36 9618.27 1.00 b1_z_[27] -0.18 0.65 -0.19 -1.20 0.90 8060.66 1.00 b1_z_[28] 0.02 0.78 0.03 -1.27 1.30 9766.62 1.00 b1_z_[29] 1.06 0.86 1.08 -0.33 2.53 9035.78 1.00 b1_z_[30] 1.68 0.68 1.71 0.66 2.82 4323.27 1.00 b1_z_[31] -0.75 0.62 -0.75 -1.68 0.36 9415.75 1.00 b1_z_[32] 1.28 0.58 1.31 0.36 2.22 6385.65 1.00 b1_z_[33] -0.98 0.46 -0.97 -1.73 -0.23 7322.08 1.00 b1_z_[34] 0.10 0.74 0.10 -1.18 1.25 9993.30 1.00 b1_z_[35] -0.02 0.45 -0.03 -0.74 0.72 7152.94 1.00 b1_z_[36] -0.63 0.60 -0.62 -1.58 0.35 8027.27 1.00 b1_z_[37] -0.35 0.59 -0.34 -1.28 0.63 7082.06 1.00 b1_z_[38] 0.53 0.49 0.53 -0.26 1.35 6231.43 1.00 b1_z_[39] -0.20 0.46 -0.19 -0.92 0.57 6402.48 1.00 b1_z_[40] 0.44 0.79 0.46 -0.83 1.73 9408.99 1.00 b1_z_[41] -0.74 0.61 -0.76 -1.66 0.34 8277.29 1.00 b1_z_[42] -0.37 0.71 -0.39 -1.57 0.74 9494.07 1.00 b1_z_[43] -0.38 0.59 -0.39 -1.32 0.59 6960.49 1.00 b1_z_[44] -1.36 0.44 -1.33 -2.04 -0.64 4641.72 1.00 b1_z_[45] 1.49 0.64 1.53 0.49 2.56 4863.52 1.00 b1_z_[46] 0.13 0.72 0.13 -1.13 1.25 11238.97 1.00 b1_z_[47] 1.20 0.64 1.22 0.21 2.25 6999.60 1.00 b1_z_[48] -0.86 0.65 -0.87 -1.96 0.15 7828.48 1.00 b1_z_[49] 0.41 0.69 0.41 -0.73 1.53 7888.90 1.00 b1_z_[50] -0.31 0.49 -0.33 -1.06 0.52 5807.68 1.00 b1_z_[51] -0.15 0.76 -0.15 -1.38 1.09 9394.14 1.00 b1_z_mean 0.15 0.03 0.15 0.10 0.20 2850.12 1.00 b1_z_std 0.16 0.03 0.16 0.11 0.21 2665.44 1.00 b2_z_[0] -0.05 0.69 -0.04 -1.18 1.04 6944.49 1.00 b2_z_[1] 0.16 0.75 0.16 -1.02 1.45 9196.56 1.00 b2_z_[2] 0.92 0.62 0.93 -0.13 1.90 6788.75 1.00 b2_z_[3] 0.75 0.59 0.76 -0.18 1.71 7206.57 1.00 b2_z_[4] 1.08 0.66 1.07 -0.00 2.18 5140.49 1.00 b2_z_[5] 0.18 0.67 0.21 -0.91 1.25 5093.21 1.00 b2_z_[6] -0.90 0.81 -0.89 -2.23 0.41 7453.28 1.00 b2_z_[7] -0.23 0.76 -0.24 -1.50 0.98 6694.09 1.00 b2_z_[8] 0.34 0.97 0.34 -1.29 1.90 12643.94 1.00 b2_z_[9] 0.77 0.66 0.80 -0.30 1.86 5325.88 1.00 b2_z_[10] 0.59 0.61 0.63 -0.46 1.53 6364.62 1.00 b2_z_[11] -0.27 0.74 -0.30 -1.45 0.97 
7226.46 1.00 b2_z_[12] 0.72 0.66 0.72 -0.37 1.79 7921.14 1.00 b2_z_[13] -0.26 0.50 -0.24 -1.04 0.57 6080.19 1.00 b2_z_[14] 0.15 0.58 0.18 -0.82 1.06 5499.07 1.00 b2_z_[15] 0.05 0.59 0.07 -0.86 1.05 6387.36 1.00 b2_z_[16] -0.32 0.57 -0.32 -1.25 0.61 7478.96 1.00 b2_z_[17] -0.06 0.62 -0.03 -1.11 0.90 6841.69 1.00 b2_z_[18] -0.53 0.65 -0.52 -1.64 0.50 7603.70 1.00 b2_z_[19] -0.37 0.72 -0.37 -1.51 0.86 7161.02 1.00 b2_z_[20] -0.52 0.66 -0.52 -1.60 0.55 8369.23 1.00 b2_z_[21] -1.62 0.84 -1.64 -3.00 -0.27 6539.84 1.00 b2_z_[22] -0.79 0.53 -0.77 -1.70 0.04 4437.85 1.00 b2_z_[23] -1.35 0.56 -1.34 -2.22 -0.41 5550.22 1.00 b2_z_[24] 0.46 0.70 0.49 -0.69 1.59 8079.32 1.00 b2_z_[25] -0.31 0.57 -0.31 -1.23 0.62 6176.96 1.00 b2_z_[26] 0.40 0.72 0.41 -0.80 1.56 7537.81 1.00 b2_z_[27] -0.39 0.58 -0.40 -1.37 0.51 5999.67 1.00 b2_z_[28] 0.92 0.74 0.94 -0.31 2.13 6995.62 1.00 b2_z_[29] -1.18 0.87 -1.23 -2.56 0.29 6873.78 1.00 b2_z_[30] -1.82 0.72 -1.84 -3.03 -0.71 6347.00 1.00 b2_z_[31] 1.30 0.71 1.34 0.18 2.51 7236.81 1.00 b2_z_[32] -0.33 0.60 -0.27 -1.30 0.64 5778.77 1.00 b2_z_[33] 0.43 0.56 0.48 -0.50 1.33 5003.79 1.00 b2_z_[34] -0.38 0.73 -0.39 -1.60 0.77 7147.15 1.00 b2_z_[35] -0.35 0.56 -0.31 -1.28 0.55 6618.10 1.00 b2_z_[36] 0.70 0.64 0.73 -0.31 1.74 6186.56 1.00 b2_z_[37] 0.60 0.59 0.62 -0.35 1.58 6766.83 1.00 b2_z_[38] -1.38 0.49 -1.36 -2.14 -0.54 5191.64 1.00 b2_z_[39] 1.13 0.46 1.12 0.37 1.88 5407.33 1.00 b2_z_[40] -0.62 0.78 -0.64 -1.92 0.63 7858.76 1.00 b2_z_[41] 0.46 0.67 0.43 -0.63 1.56 6198.23 1.00 b2_z_[42] -0.15 0.68 -0.15 -1.27 0.98 7564.69 1.00 b2_z_[43] 0.15 0.63 0.16 -0.89 1.17 5738.07 1.00 b2_z_[44] 1.14 0.49 1.16 0.34 1.93 4734.09 1.00 b2_z_[45] 1.37 0.61 1.39 0.43 2.41 4903.99 1.00 b2_z_[46] -0.16 0.76 -0.17 -1.41 1.07 8424.95 1.00 b2_z_[47] 0.08 0.66 0.10 -1.04 1.13 5549.73 1.00 b2_z_[48] 0.07 0.64 0.08 -0.99 1.13 7276.88 1.00 b2_z_[49] -0.21 0.65 -0.20 -1.26 0.86 7048.58 1.00 b2_z_[50] -0.72 0.56 -0.70 -1.64 0.16 5802.88 1.00 b2_z_[51] 0.31 0.72 0.32 -0.82 
1.54 8773.60 1.00 b2_z_mean -0.39 0.03 -0.39 -0.44 -0.34 2734.47 1.00 b2_z_std 0.15 0.03 0.15 0.10 0.19 2256.88 1.00 nu 2.22 0.67 2.08 1.30 3.12 2639.97 1.00 sigma_z 0.24 0.04 0.24 0.18 0.31 1708.79 1.00 Number of divergences: 0
def choose_credible_parabola_parameters(idata, hdi=0.95, *, a='b0', b='b1', c='b2',
                                        n_curves=20, dim=None, rng=None):
    """Yield random posterior draws of parabola coefficients lying in their HDIs.

    The coefficients describe ``y = a + b * x + c * x**2`` (``a`` is the
    intercept, ``c`` the quadratic term), which is how callers evaluate them.
    A joint draw is eligible only if each of the three parameters falls inside
    its own marginal ``hdi`` highest-density interval. Eligible draws are then
    sampled without replacement.

    Parameters
    ----------
    idata : arviz.InferenceData
        Must have a ``posterior`` group containing variables ``a``, ``b``, ``c``.
    hdi : float
        Probability mass of the HDI, passed to ``az.hdi`` as ``hdi_prob``.
    a, b, c : str
        Posterior variable names of the intercept, linear and quadratic terms.
    n_curves : int
        Maximum number of draws to yield. Fewer are yielded when fewer
        draws lie inside all three HDIs.
    dim : dict, optional
        Coordinate selection passed to ``.sel`` so each variable is reduced to
        ``(chain, draw)``, e.g. ``dict(state=3)``.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to NumPy's global random state.

    Yields
    ------
    tuple
        ``(a_value, b_value, c_value)`` taken from one posterior draw.

    Raises
    ------
    ValueError
        If a variable still has dimensions besides ``chain``/``draw`` after
        applying ``dim``.
    """
    posterior = idata.posterior
    selection = {} if dim is None else dim
    # Deduplicate while keeping order, in case the same variable is reused.
    names = list(dict.fromkeys((a, b, c)))

    # Select first, then compute HDIs only for the needed variables. The
    # posterior may also hold large deterministic arrays we don't need.
    selected = posterior[names].sel(selection)
    for name in names:
        extra_dims = set(selected[name].dims) - {'chain', 'draw'}
        if extra_dims:
            raise ValueError(
                f"'{name}' still has dimensions {sorted(map(str, extra_dims))} "
                "after selection; pass `dim` to pick a single element."
            )
    hdi_bounds = az.hdi(selected, hdi_prob=hdi)

    # Flatten (chain, draw) so one flat index refers to the same draw for all
    # variables.
    draws = {name: selected[name].values.flatten() for name in names}
    in_hdi = np.ones(draws[names[0]].shape, dtype=bool)
    for name in names:
        low, high = hdi_bounds[name].values
        in_hdi &= (draws[name] >= low) & (draws[name] <= high)

    candidates = np.flatnonzero(in_hdi)
    chooser = np.random if rng is None else rng
    # Sampling without replacement cannot take more draws than exist.
    n_chosen = min(n_curves, candidates.size)
    for idx in chooser.choice(candidates, n_chosen, replace=False):
        yield draws[a][idx], draws[b][idx], draws[c][idx]
# Convert the MCMC run to InferenceData. Per-state parameters get a `state`
# dimension whose coordinates are the integer category codes 0..n_states-1.
per_state_vars = ('b0', 'b1', 'b2', 'b0_z', 'b1_z', 'b2_z')
idata = az.from_numpyro(
    mcmc,
    coords={'state': np.arange(income_data_3yr['State'].cat.categories.size)},
    dims={name: ['state'] for name in per_state_vars},
)
# Overlay credible parabolas built from the b*_mean parameters (presumably the
# across-state means) on the median-income figure drawn earlier.
n_posterior_curves = 20
x = np.linspace(1, 7.5, 1000)
mean_curve_params = choose_credible_parabola_parameters(
    idata, a='b0_mean', b='b1_mean', c='b2_mean', n_curves=n_posterior_curves)
for intercept, linear, quadratic in mean_curve_params:
    ax.plot(x, intercept + linear * x + quadratic * x**2, c='b')
fig
# One panel per state: credible posterior parabolas (blue) with that state's
# observed median incomes (black dots). The grid is sized to fit every state.
# A fixed 5x5 grid would silently drop states beyond the 25th.
state_names = income_data_3yr['State'].cat.categories
n_cols = 5
n_rows = -(-state_names.size // n_cols)  # ceiling division
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols,
                         figsize=(4 * n_cols, 4 * n_rows))
flat_axes = axes.flatten()
# The `state` coordinate in idata is the category code, i.e. the position in
# `state_names`.
for state_idx, (state, ax) in enumerate(zip(state_names, flat_axes)):
    # Plot posterior curves.
    for b0, b1, b2 in choose_credible_parabola_parameters(
            idata, a='b0', b='b1', c='b2', dim=dict(state=state_idx)):
        ax.plot(x, b0 + b1 * x + b2 * x**2, c='b')
    # Superimpose state median income.
    state_data = income_data_3yr[income_data_3yr['State'] == state]
    state_data = state_data.sort_values('FamilySize')
    ax.plot(state_data['FamilySize'], state_data['MedianIncome'], 'ko')
    ax.set_title(state)
# Hide the unused panels in the last row.
for ax in flat_axes[state_names.size:]:
    ax.set_axis_off()
fig.tight_layout()