%cd ..
%load_ext autoreload
%autoreload 2
/home/runner/work/numpyro-doing-bayesian/numpyro-doing-bayesian
import arviz as az
import jax.numpy as jnp
import jax.random as random
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from numpyro.infer import MCMC, NUTS
import numpyro_glm
import numpyro_glm.metric.models as glm_metric

Chapter 17: Metric Predicted Variable with one Metric Predictor

Simple Linear Regression

# Simulate data from y ~ Normal(intercept + slope * x, sd) and visualise the
# normal likelihood spread around the regression line.
true_intercept, true_slope, true_sd = 10, 2, 2

x = np.random.uniform(-10, 10, size=500)
y = np.random.normal(true_intercept + true_slope * x, true_sd)

fig, ax = plt.subplots()
ax.scatter(x, y, c='black', s=4)
ax.set_title('Normal PDF around linear function')

xline = np.linspace(-10, 10, 1000)
ax.plot(xline, true_intercept + true_slope * xline, lw=4, c='#87ceeb')

# Draw the normal density sideways at a few x positions. The density must be
# centred on the regression line at that x (mu = intercept + slope * x), not
# on x itself, and spans +/- 3 sd. It is peak-normalised (max = 1) and then
# scaled so that its peak reaches `pdf_width` units along the x axis.
pdf_width = 2.0
for xinterval in [-7.5, -2.5, 2.5, 7.5]:
    mu = true_intercept + true_slope * xinterval
    y_ = np.linspace(mu - 3 * true_sd, mu + 3 * true_sd, 1000)
    pdf_shape = np.exp(-0.5 * ((y_ - mu) / true_sd) ** 2)
    # Baseline of the sideways density, then the density curve itself.
    ax.plot([xinterval, xinterval], [y_[0], y_[-1]], c='#1f77b4', lw=1)
    ax.plot(xinterval + pdf_width * pdf_shape, y_, c='#1f77b4', lw=2)

Robust Linear Regression

# Load the 30-person height/weight sample and display per-column summaries.
HT_WT_30_CSV = 'datasets/HtWtData30.csv'
height_weight_30_data = pd.read_csv(HT_WT_30_CSV)
height_weight_30_data.describe()
male height weight
count 30.000000 30.000000 30.000000
mean 0.466667 66.983333 154.013333
std 0.507416 4.201074 31.702570
min 0.000000 57.500000 96.500000
25% 0.000000 64.100000 128.850000
50% 0.000000 66.450000 147.900000
75% 1.000000 69.575000 179.650000
max 1.000000 76.000000 215.100000

We will test the model with raw data (no standardization applied to either y or x).

# Robust (Student-t likelihood) simple regression fitted on unstandardised data.
# NOTE(review): height is passed first, and the observed `y` stored in the
# resulting InferenceData holds the height values. So this predicts height
# from weight. Kruschke's version of this example predicts weight from
# height; confirm the argument order is intended.
heights = jnp.array(height_weight_30_data.height.values)
weights = jnp.array(height_weight_30_data.weight.values)

mcmc_key = random.PRNGKey(0)
kernel = NUTS(glm_metric.one_metric_predictor_robust_no_standardization)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=20000)
mcmc.run(mcmc_key, heights, weights)
mcmc.print_summary()
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        b0     48.72      3.84     48.95     42.70     55.15   6216.35      1.00
        b1      0.12      0.02      0.12      0.08      0.16   6252.35      1.00
        nu     23.76     25.95     14.17      1.11     57.71   8596.60      1.00
     sigma      3.57      0.75      3.52      2.30      4.71   6426.80      1.00

Number of divergences: 0
numpyro_glm.plot_diagnostic(mcmc, ['b0', 'b1', 'nu', 'sigma'])
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:     (chain: 1, draw: 20000, mean_dim_0: 30)
      Coordinates:
        * chain       (chain) int64 0
        * draw        (draw) int64 0 1 2 3 4 5 ... 19994 19995 19996 19997 19998 19999
        * mean_dim_0  (mean_dim_0) int64 0 1 2 3 4 5 6 7 8 ... 22 23 24 25 26 27 28 29
      Data variables:
          b0          (chain, draw) float32 41.41 43.14 43.32 ... 48.34 45.5 50.65
          b1          (chain, draw) float32 0.1631 0.148 0.1502 ... 0.1378 0.1047
          mean        (chain, draw, mean_dim_0) float32 63.66 76.49 ... 69.36 62.29
          nu          (chain, draw) float32 50.76 97.48 102.5 ... 2.813 11.31 3.153
          sigma       (chain, draw) float32 4.595 4.4 4.371 ... 2.978 2.951 3.091
      Attributes:
          created_at:                 2023-05-29T23:52:41.500946
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 20000, y_dim_0: 30)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 ... 19994 19995 19996 19997 19998 19999
        * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 21 22 23 24 25 26 27 28 29
      Data variables:
          y        (chain, draw, y_dim_0) float32 -2.452 -6.907 ... -2.201 -3.235
      Attributes:
          created_at:                 2023-05-29T23:52:41.687343
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:    (chain: 1, draw: 20000)
      Coordinates:
        * chain      (chain) int64 0
        * draw       (draw) int64 0 1 2 3 4 5 ... 19994 19995 19996 19997 19998 19999
      Data variables:
          diverging  (chain, draw) bool False False False False ... False False False
      Attributes:
          created_at:                 2023-05-29T23:52:41.513411
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:  (y_dim_0: 30)
      Coordinates:
        * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 21 22 23 24 25 26 27 28 29
      Data variables:
          y        (y_dim_0) float32 64.0 62.3 67.9 64.2 64.8 ... 66.4 65.7 68.3 66.9
      Attributes:
          created_at:                 2023-05-29T23:52:41.688187
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

Here is the model with both y and x standardized.

# Same robust regression, but the model standardises y and x internally.
# NOTE(review): as in the raw-data fit, height is passed first, so the model
# treats height as the predicted variable. Confirm this is intended.
heights = jnp.array(height_weight_30_data.height.values)
weights = jnp.array(height_weight_30_data.weight.values)

mcmc_key = random.PRNGKey(0)
kernel = NUTS(glm_metric.one_metric_predictor_robust)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=20000)
mcmc.run(mcmc_key, heights, weights)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        nu     28.91     28.39     19.63      1.26     65.21  15679.32      1.00
       zb0      0.03      0.17      0.03     -0.25      0.29  19537.09      1.00
       zb1      0.58      0.18      0.58      0.27      0.85  18687.78      1.00
    zsigma      0.84      0.15      0.83      0.60      1.08  12668.37      1.00

Number of divergences: 0
numpyro_glm.plot_diagnostic(mcmc, ['zb0', 'zb1', 'nu', 'zsigma', 'b0'])
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:      (chain: 1, draw: 20000, zmean_dim_0: 30)
      Coordinates:
        * chain        (chain) int64 0
        * draw         (draw) int64 0 1 2 3 4 5 ... 19995 19996 19997 19998 19999
        * zmean_dim_0  (zmean_dim_0) int64 0 1 2 3 4 5 6 7 ... 22 23 24 25 26 27 28 29
      Data variables:
          b0           (chain, draw) float32 54.14 59.34 55.78 ... 53.82 57.74 53.81
          b1           (chain, draw) float32 0.08307 0.05473 ... 0.0589 0.08952
          nu           (chain, draw) float32 8.131 44.38 25.45 ... 7.959 19.05 13.96
          sigma        (chain, draw) float32 2.563 3.453 3.461 ... 2.761 3.062 3.668
          zb0          (chain, draw) float32 -0.0121 0.1893 ... -0.04201 0.1493
          zb1          (chain, draw) float32 0.6268 0.413 0.5652 ... 0.4445 0.6756
          zmean        (chain, draw, zmean_dim_0) float32 -0.3663 1.216 ... -0.7808
          zsigma       (chain, draw) float32 0.6205 0.8361 0.838 ... 0.7412 0.8879
      Attributes:
          created_at:                 2023-05-29T23:52:45.902172
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:   (chain: 1, draw: 20000, yz_dim_0: 30)
      Coordinates:
        * chain     (chain) int64 0
        * draw      (draw) int64 0 1 2 3 4 5 6 ... 19994 19995 19996 19997 19998 19999
        * yz_dim_0  (yz_dim_0) int64 0 1 2 3 4 5 6 7 8 ... 21 22 23 24 25 26 27 28 29
      Data variables:
          yz        (chain, draw, yz_dim_0) float32 -0.6535 -5.115 ... -0.9072 -1.201
      Attributes:
          created_at:                 2023-05-29T23:52:46.094067
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:    (chain: 1, draw: 20000)
      Coordinates:
        * chain      (chain) int64 0
        * draw       (draw) int64 0 1 2 3 4 5 ... 19994 19995 19996 19997 19998 19999
      Data variables:
          diverging  (chain, draw) bool False False False False ... False False False
      Attributes:
          created_at:                 2023-05-29T23:52:45.906057
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:   (yz_dim_0: 30)
      Coordinates:
        * yz_dim_0  (yz_dim_0) int64 0 1 2 3 4 5 6 7 8 ... 21 22 23 24 25 26 27 28 29
      Data variables:
          yz        (yz_dim_0) float32 -0.7223 -1.134 0.2219 ... 0.3188 -0.02018
      Attributes:
          created_at:                 2023-05-29T23:52:46.094933
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

Hierarchical Regression on Individuals within Groups

# Load the hierarchical regression data. `Subj` is made categorical so its
# `.cat.codes` can be used as integer group indices.
hier_linear_reg_data = (
    pd.read_csv('datasets/HierLinRegressData.csv')
    .astype({'Subj': 'category'})
)
hier_linear_reg_data.describe()
X Y
count 132.000000 132.000000
mean 67.084848 153.328788
std 7.146830 32.041570
min 49.100000 70.000000
25% 62.500000 131.300000
50% 66.250000 153.600000
75% 71.400000 172.550000
max 86.200000 240.400000
# Robust hierarchical regression with per-subject intercepts and slopes.
subject = hier_linear_reg_data.Subj.cat
y_obs = jnp.array(hier_linear_reg_data.Y.values)
x_obs = jnp.array(hier_linear_reg_data.X.values)
subject_codes = jnp.array(subject.codes.values)
n_subjects = subject.categories.size

mcmc_key = random.PRNGKey(0)
kernel = NUTS(glm_metric.hierarchical_one_metric_predictor_multi_groups_robust)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=20000)
mcmc.run(mcmc_key, y_obs, x_obs, subject_codes, n_subjects)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        nu     39.88     31.29     30.97      3.52     82.04  25047.26      1.00
   zb0_[0]      0.66      0.34      0.66      0.12      1.24   5833.33      1.00
   zb0_[1]      1.07      0.36      1.05      0.47      1.65   5239.55      1.00
   zb0_[2]     -0.70      0.36     -0.70     -1.30     -0.13   5389.81      1.00
   zb0_[3]      0.64      0.44      0.64     -0.08      1.36  10989.48      1.00
   zb0_[4]     -0.73      0.38     -0.73     -1.32     -0.09   6746.15      1.00
   zb0_[5]     -2.17      0.49     -2.14     -2.98     -1.39   5626.90      1.00
   zb0_[6]     -0.14      0.41     -0.14     -0.82      0.52   8740.15      1.00
   zb0_[7]     -0.21      0.32     -0.20     -0.75      0.31   6024.53      1.00
   zb0_[8]     -0.56      0.34     -0.56     -1.08      0.01   5614.86      1.00
   zb0_[9]     -1.29      0.39     -1.27     -1.92     -0.65   4831.80      1.00
  zb0_[10]      0.72      0.34      0.71      0.18      1.30   5819.55      1.00
  zb0_[11]     -0.46      0.36     -0.45     -1.05      0.14   7087.15      1.00
  zb0_[12]      1.32      0.43      1.30      0.64      2.05   7005.42      1.00
  zb0_[13]      1.01      0.36      1.00      0.42      1.61   5923.62      1.00
  zb0_[14]      0.18      0.44      0.18     -0.55      0.90  10327.46      1.00
  zb0_[15]      0.32      0.34      0.32     -0.28      0.83   6205.49      1.00
  zb0_[16]     -0.05      0.29     -0.05     -0.53      0.41   4622.13      1.00
  zb0_[17]     -0.83      0.32     -0.82     -1.34     -0.30   4492.31      1.00
  zb0_[18]      0.25      0.48      0.25     -0.50      1.06  11627.68      1.00
  zb0_[19]      0.73      0.36      0.72      0.12      1.30   6564.84      1.00
  zb0_[20]      1.48      0.46      1.46      0.72      2.23   7476.43      1.00
  zb0_[21]     -0.71      0.33     -0.70     -1.24     -0.18   4987.13      1.00
  zb0_[22]     -0.51      0.31     -0.50     -1.02      0.02   4961.07      1.00
  zb0_[23]      0.96      0.69      0.98     -0.22      2.06  15145.85      1.00
  zb0_[24]     -1.08      0.70     -1.11     -2.22      0.09  14422.36      1.00
  zb0_mean      0.08      0.23      0.08     -0.30      0.45   2846.28      1.00
   zb0_std      1.06      0.20      1.03      0.73      1.35   4242.65      1.00
   zb1_[0]      0.16      0.91      0.16     -1.38      1.61  24686.58      1.00
   zb1_[1]     -0.62      0.88     -0.65     -2.08      0.83  19244.29      1.00
   zb1_[2]     -0.81      0.84     -0.84     -2.20      0.55  13748.94      1.00
   zb1_[3]      0.03      0.99      0.04     -1.63      1.63  24966.31      1.00
   zb1_[4]     -0.30      0.84     -0.31     -1.70      1.07  23699.34      1.00
   zb1_[5]      0.82      0.88      0.85     -0.66      2.25  13349.34      1.00
   zb1_[6]     -0.03      0.92     -0.04     -1.54      1.49  25261.45      1.00
   zb1_[7]      0.31      0.87      0.33     -1.10      1.77  24399.64      1.00
   zb1_[8]     -0.14      0.87     -0.14     -1.59      1.26  25134.93      1.00
   zb1_[9]     -0.71      0.83     -0.74     -2.08      0.61  14346.51      1.00
  zb1_[10]      0.20      0.95      0.22     -1.36      1.75  26909.07      1.00
  zb1_[11]      0.05      0.91      0.05     -1.46      1.52  29476.07      1.00
  zb1_[12]      0.15      0.92      0.16     -1.44      1.62  20162.71      1.00
  zb1_[13]      0.03      0.80      0.03     -1.27      1.35  23669.26      1.00
  zb1_[14]     -0.14      0.97     -0.14     -1.72      1.47  29731.50      1.00
  zb1_[15]     -0.05      0.86     -0.06     -1.48      1.36  22293.17      1.00
  zb1_[16]     -0.12      0.82     -0.13     -1.46      1.21  24933.42      1.00
  zb1_[17]      0.49      0.82      0.50     -0.82      1.86  26560.87      1.00
  zb1_[18]      0.15      0.99      0.15     -1.48      1.77  29347.23      1.00
  zb1_[19]      0.11      0.96      0.11     -1.44      1.67  27526.88      1.00
  zb1_[20]      0.45      0.88      0.47     -0.96      1.91  15378.57      1.00
  zb1_[21]      0.43      0.88      0.44     -1.05      1.84  23065.02      1.00
  zb1_[22]      0.63      0.81      0.65     -0.68      1.98  16946.28      1.00
  zb1_[23]     -0.48      0.93     -0.49     -2.03      1.04  18804.50      1.00
  zb1_[24]     -0.57      0.93     -0.60     -2.14      0.90  15822.44      1.00
  zb1_mean      0.70      0.10      0.70      0.54      0.88  10261.21      1.00
   zb1_std      0.23      0.12      0.22      0.00      0.38   4588.93      1.00
    zsigma      0.59      0.05      0.59      0.51      0.68   9970.60      1.00

Number of divergences: 0

Quadratic Trend and Weighted Data

# Load the 3-year family-size/median-income table. The first line of the CSV
# is skipped, and `State` is made categorical so it can serve as a group index.
income_data_3yr = (
    pd.read_csv('datasets/IncomeFamszState3yr.csv', skiprows=1)
    .astype({'State': 'category'})
)
income_data_3yr.describe()
FamilySize MedianIncome SampErr
count 312.000000 312.000000 312.000000
mean 4.500000 66824.772436 2590.810897
std 1.710569 14774.402348 2410.533770
min 2.000000 18860.000000 240.000000
25% 3.000000 57346.500000 1001.500000
50% 4.500000 65097.000000 1805.000000
75% 6.000000 73992.750000 3297.750000
max 7.000000 124167.000000 15121.000000
fig, ax = plt.subplots()

# One connected line per state: median income against family size.
for state_name in income_data_3yr['State'].unique():
    rows = income_data_3yr.loc[income_data_3yr['State'] == state_name]
    rows = rows.sort_values('FamilySize')
    ax.plot(rows['FamilySize'], rows['MedianIncome'], 'o-')

ax.set(
    title='Median Income of Various States',
    xlabel='Family Size',
    ylabel='Median Income',
)
fig.tight_layout()
# Robust hierarchical quadratic-trend model of median income vs. family size,
# with states as groups. `SampErr` is passed as the final argument. It is
# presumably used to weight each observation; confirm in the model definition.
state_col = income_data_3yr['State'].cat
median_income = jnp.array(income_data_3yr['MedianIncome'].values)
family_size = jnp.array(income_data_3yr['FamilySize'].values)
state_codes = jnp.array(state_col.codes.values)
n_state_groups = state_col.categories.size
sample_error = jnp.array(income_data_3yr['SampErr'].values)

mcmc_key = random.PRNGKey(0)
kernel = NUTS(
    glm_metric.hierarchical_quadtrend_one_metric_predictor_multi_groups_robust)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000)
mcmc.run(mcmc_key, median_income, family_size, state_codes,
         n_state_groups, sample_error)
mcmc.print_summary()
                 mean       std    median      5.0%     95.0%     n_eff     r_hat
   b0_z_[0]     -0.79      0.25     -0.80     -1.17     -0.37   1450.18      1.00
   b0_z_[1]      1.15      0.28      1.15      0.69      1.60   1368.14      1.00
   b0_z_[2]     -0.75      0.22     -0.75     -1.11     -0.39   1152.45      1.00
   b0_z_[3]     -1.19      0.23     -1.19     -1.57     -0.83   1141.32      1.00
   b0_z_[4]     -0.08      0.27     -0.11     -0.52      0.41   1429.61      1.00
   b0_z_[5]      0.31      0.24      0.29     -0.10      0.69   1280.56      1.00
   b0_z_[6]      1.93      0.34      1.92      1.39      2.51   1360.38      1.00
   b0_z_[7]      0.70      0.27      0.70      0.24      1.13   1432.02      1.00
   b0_z_[8]      0.22      0.45      0.21     -0.51      0.95   4897.73      1.00
   b0_z_[9]     -0.72      0.22     -0.72     -1.08     -0.37   1135.64      1.00
  b0_z_[10]     -0.68      0.23     -0.69     -1.04     -0.30   1229.01      1.00
  b0_z_[11]      1.06      0.25      1.05      0.65      1.45   1091.51      1.00
  b0_z_[12]     -0.79      0.22     -0.79     -1.17     -0.43   1265.59      1.00
  b0_z_[13]      0.32      0.18      0.32      0.02      0.61    791.32      1.00
  b0_z_[14]     -0.32      0.20     -0.32     -0.66      0.01   1022.26      1.00
  b0_z_[15]      0.05      0.20      0.05     -0.26      0.38   1068.64      1.00
  b0_z_[16]      0.05      0.19      0.05     -0.27      0.36    929.39      1.00
  b0_z_[17]     -0.72      0.22     -0.72     -1.08     -0.36   1140.51      1.00
  b0_z_[18]     -0.46      0.23     -0.46     -0.83     -0.09   1262.86      1.00
  b0_z_[19]     -0.02      0.24     -0.02     -0.41      0.39   1533.15      1.00
  b0_z_[20]      1.86      0.28      1.85      1.38      2.31   1036.08      1.00
  b0_z_[21]      2.01      0.35      2.01      1.44      2.61   1632.92      1.00
  b0_z_[22]     -0.19      0.19     -0.20     -0.50      0.13    954.04      1.00
  b0_z_[23]      0.88      0.21      0.88      0.53      1.21    858.27      1.00
  b0_z_[24]     -1.43      0.26     -1.42     -1.87     -1.00   1397.35      1.00
  b0_z_[25]     -0.19      0.20     -0.18     -0.52      0.14   1017.89      1.00
  b0_z_[26]     -0.42      0.26     -0.42     -0.82      0.02   1843.30      1.00
  b0_z_[27]      0.16      0.19      0.16     -0.15      0.47    939.60      1.00
  b0_z_[28]     -0.33      0.25     -0.33     -0.74      0.07   1341.16      1.00
  b0_z_[29]      1.52      0.31      1.52      1.04      2.04   1387.94      1.00
  b0_z_[30]      2.18      0.31      2.18      1.66      2.67   1082.44      1.00
  b0_z_[31]     -1.10      0.24     -1.10     -1.49     -0.73   1347.06      1.00
  b0_z_[32]      0.59      0.20      0.58      0.27      0.92    868.09      1.00
  b0_z_[33]     -0.80      0.22     -0.81     -1.16     -0.45   1061.66      1.00
  b0_z_[34]      0.46      0.25      0.46      0.08      0.89   1313.16      1.00
  b0_z_[35]     -0.20      0.20     -0.21     -0.51      0.13    941.25      1.00
  b0_z_[36]     -0.86      0.22     -0.86     -1.22     -0.50   1158.80      1.00
  b0_z_[37]     -0.41      0.20     -0.41     -0.73     -0.08   1037.83      1.00
  b0_z_[38]      0.50      0.18      0.50      0.22      0.81    715.06      1.00
  b0_z_[39]     -3.22      0.37     -3.21     -3.83     -2.63   1134.21      1.00
  b0_z_[40]      0.72      0.30      0.72      0.23      1.20   1908.49      1.00
  b0_z_[41]     -0.83      0.23     -0.83     -1.23     -0.46   1199.08      1.00
  b0_z_[42]     -0.08      0.21     -0.08     -0.45      0.25   1081.10      1.00
  b0_z_[43]     -0.77      0.21     -0.77     -1.13     -0.44   1116.83      1.00
  b0_z_[44]     -0.77      0.20     -0.78     -1.10     -0.45    942.33      1.00
  b0_z_[45]     -0.10      0.18     -0.09     -0.38      0.19    791.24      1.00
  b0_z_[46]      0.30      0.26      0.30     -0.10      0.74   1446.57      1.00
  b0_z_[47]      0.97      0.23      0.96      0.61      1.34    929.54      1.00
  b0_z_[48]      0.34      0.24      0.33     -0.03      0.76   1479.49      1.00
  b0_z_[49]     -0.69      0.22     -0.69     -1.05     -0.32   1336.65      1.00
  b0_z_[50]      0.23      0.20      0.23     -0.09      0.56    965.97      1.00
  b0_z_[51]      0.36      0.24      0.36     -0.01      0.77   1435.54      1.00
  b0_z_mean      0.35      0.14      0.35      0.13      0.58    545.08      1.00
   b0_z_std      0.93      0.10      0.92      0.77      1.08   1199.85      1.00
   b1_z_[0]     -0.68      0.59     -0.70     -1.61      0.34   8455.76      1.00
   b1_z_[1]     -0.26      0.72     -0.26     -1.48      0.87  11167.06      1.00
   b1_z_[2]     -1.22      0.58     -1.22     -2.18     -0.28   8356.65      1.00
   b1_z_[3]     -0.40      0.59     -0.40     -1.33      0.59   8836.51      1.00
   b1_z_[4]     -0.87      0.79     -0.93     -2.05      0.38   2606.24      1.00
   b1_z_[5]     -0.99      0.60     -1.01     -2.01     -0.06   7509.17      1.00
   b1_z_[6]      1.34      0.82      1.37     -0.03      2.66   6927.01      1.00
   b1_z_[7]      0.48      0.83      0.50     -0.86      1.83   9103.59      1.00
   b1_z_[8]     -0.50      0.94     -0.52     -2.07      1.00  13566.42      1.00
   b1_z_[9]     -0.18      0.55     -0.16     -1.07      0.73   6259.88      1.00
  b1_z_[10]     -0.53      0.56     -0.52     -1.45      0.36   8877.50      1.00
  b1_z_[11]      1.05      0.88      1.05     -0.35      2.51   7673.05      1.00
  b1_z_[12]     -0.09      0.68     -0.09     -1.16      1.05   9239.28      1.00
  b1_z_[13]     -0.26      0.49     -0.25     -1.04      0.58   6958.81      1.00
  b1_z_[14]      0.28      0.52      0.29     -0.56      1.14   6029.79      1.00
  b1_z_[15]      0.01      0.53      0.01     -0.85      0.88   7031.98      1.00
  b1_z_[16]     -0.54      0.60     -0.54     -1.54      0.41   9005.49      1.00
  b1_z_[17]      0.07      0.61      0.07     -0.89      1.11   6967.88      1.00
  b1_z_[18]      0.07      0.61      0.06     -0.95      1.04   8288.69      1.00
  b1_z_[19]      0.38      0.74      0.39     -0.83      1.59  10146.34      1.00
  b1_z_[20]      0.98      0.63      0.99     -0.06      2.00   7663.19      1.00
  b1_z_[21]      1.82      0.76      1.88      0.63      3.05   4077.76      1.00
  b1_z_[22]     -0.47      0.47     -0.49     -1.22      0.32   4714.90      1.00
  b1_z_[23]     -0.26      0.56     -0.29     -1.08      0.74   5557.14      1.00
  b1_z_[24]     -0.71      0.64     -0.72     -1.74      0.36   9272.17      1.00
  b1_z_[25]      0.24      0.52      0.24     -0.62      1.09   6501.39      1.00
  b1_z_[26]      0.13      0.78      0.15     -1.20      1.36   9618.27      1.00
  b1_z_[27]     -0.18      0.65     -0.19     -1.20      0.90   8060.66      1.00
  b1_z_[28]      0.02      0.78      0.03     -1.27      1.30   9766.62      1.00
  b1_z_[29]      1.06      0.86      1.08     -0.33      2.53   9035.78      1.00
  b1_z_[30]      1.68      0.68      1.71      0.66      2.82   4323.27      1.00
  b1_z_[31]     -0.75      0.62     -0.75     -1.68      0.36   9415.75      1.00
  b1_z_[32]      1.28      0.58      1.31      0.36      2.22   6385.65      1.00
  b1_z_[33]     -0.98      0.46     -0.97     -1.73     -0.23   7322.08      1.00
  b1_z_[34]      0.10      0.74      0.10     -1.18      1.25   9993.30      1.00
  b1_z_[35]     -0.02      0.45     -0.03     -0.74      0.72   7152.94      1.00
  b1_z_[36]     -0.63      0.60     -0.62     -1.58      0.35   8027.27      1.00
  b1_z_[37]     -0.35      0.59     -0.34     -1.28      0.63   7082.06      1.00
  b1_z_[38]      0.53      0.49      0.53     -0.26      1.35   6231.43      1.00
  b1_z_[39]     -0.20      0.46     -0.19     -0.92      0.57   6402.48      1.00
  b1_z_[40]      0.44      0.79      0.46     -0.83      1.73   9408.99      1.00
  b1_z_[41]     -0.74      0.61     -0.76     -1.66      0.34   8277.29      1.00
  b1_z_[42]     -0.37      0.71     -0.39     -1.57      0.74   9494.07      1.00
  b1_z_[43]     -0.38      0.59     -0.39     -1.32      0.59   6960.49      1.00
  b1_z_[44]     -1.36      0.44     -1.33     -2.04     -0.64   4641.72      1.00
  b1_z_[45]      1.49      0.64      1.53      0.49      2.56   4863.52      1.00
  b1_z_[46]      0.13      0.72      0.13     -1.13      1.25  11238.97      1.00
  b1_z_[47]      1.20      0.64      1.22      0.21      2.25   6999.60      1.00
  b1_z_[48]     -0.86      0.65     -0.87     -1.96      0.15   7828.48      1.00
  b1_z_[49]      0.41      0.69      0.41     -0.73      1.53   7888.90      1.00
  b1_z_[50]     -0.31      0.49     -0.33     -1.06      0.52   5807.68      1.00
  b1_z_[51]     -0.15      0.76     -0.15     -1.38      1.09   9394.14      1.00
  b1_z_mean      0.15      0.03      0.15      0.10      0.20   2850.12      1.00
   b1_z_std      0.16      0.03      0.16      0.11      0.21   2665.44      1.00
   b2_z_[0]     -0.05      0.69     -0.04     -1.18      1.04   6944.49      1.00
   b2_z_[1]      0.16      0.75      0.16     -1.02      1.45   9196.56      1.00
   b2_z_[2]      0.92      0.62      0.93     -0.13      1.90   6788.75      1.00
   b2_z_[3]      0.75      0.59      0.76     -0.18      1.71   7206.57      1.00
   b2_z_[4]      1.08      0.66      1.07     -0.00      2.18   5140.49      1.00
   b2_z_[5]      0.18      0.67      0.21     -0.91      1.25   5093.21      1.00
   b2_z_[6]     -0.90      0.81     -0.89     -2.23      0.41   7453.28      1.00
   b2_z_[7]     -0.23      0.76     -0.24     -1.50      0.98   6694.09      1.00
   b2_z_[8]      0.34      0.97      0.34     -1.29      1.90  12643.94      1.00
   b2_z_[9]      0.77      0.66      0.80     -0.30      1.86   5325.88      1.00
  b2_z_[10]      0.59      0.61      0.63     -0.46      1.53   6364.62      1.00
  b2_z_[11]     -0.27      0.74     -0.30     -1.45      0.97   7226.46      1.00
  b2_z_[12]      0.72      0.66      0.72     -0.37      1.79   7921.14      1.00
  b2_z_[13]     -0.26      0.50     -0.24     -1.04      0.57   6080.19      1.00
  b2_z_[14]      0.15      0.58      0.18     -0.82      1.06   5499.07      1.00
  b2_z_[15]      0.05      0.59      0.07     -0.86      1.05   6387.36      1.00
  b2_z_[16]     -0.32      0.57     -0.32     -1.25      0.61   7478.96      1.00
  b2_z_[17]     -0.06      0.62     -0.03     -1.11      0.90   6841.69      1.00
  b2_z_[18]     -0.53      0.65     -0.52     -1.64      0.50   7603.70      1.00
  b2_z_[19]     -0.37      0.72     -0.37     -1.51      0.86   7161.02      1.00
  b2_z_[20]     -0.52      0.66     -0.52     -1.60      0.55   8369.23      1.00
  b2_z_[21]     -1.62      0.84     -1.64     -3.00     -0.27   6539.84      1.00
  b2_z_[22]     -0.79      0.53     -0.77     -1.70      0.04   4437.85      1.00
  b2_z_[23]     -1.35      0.56     -1.34     -2.22     -0.41   5550.22      1.00
  b2_z_[24]      0.46      0.70      0.49     -0.69      1.59   8079.32      1.00
  b2_z_[25]     -0.31      0.57     -0.31     -1.23      0.62   6176.96      1.00
  b2_z_[26]      0.40      0.72      0.41     -0.80      1.56   7537.81      1.00
  b2_z_[27]     -0.39      0.58     -0.40     -1.37      0.51   5999.67      1.00
  b2_z_[28]      0.92      0.74      0.94     -0.31      2.13   6995.62      1.00
  b2_z_[29]     -1.18      0.87     -1.23     -2.56      0.29   6873.78      1.00
  b2_z_[30]     -1.82      0.72     -1.84     -3.03     -0.71   6347.00      1.00
  b2_z_[31]      1.30      0.71      1.34      0.18      2.51   7236.81      1.00
  b2_z_[32]     -0.33      0.60     -0.27     -1.30      0.64   5778.77      1.00
  b2_z_[33]      0.43      0.56      0.48     -0.50      1.33   5003.79      1.00
  b2_z_[34]     -0.38      0.73     -0.39     -1.60      0.77   7147.15      1.00
  b2_z_[35]     -0.35      0.56     -0.31     -1.28      0.55   6618.10      1.00
  b2_z_[36]      0.70      0.64      0.73     -0.31      1.74   6186.56      1.00
  b2_z_[37]      0.60      0.59      0.62     -0.35      1.58   6766.83      1.00
  b2_z_[38]     -1.38      0.49     -1.36     -2.14     -0.54   5191.64      1.00
  b2_z_[39]      1.13      0.46      1.12      0.37      1.88   5407.33      1.00
  b2_z_[40]     -0.62      0.78     -0.64     -1.92      0.63   7858.76      1.00
  b2_z_[41]      0.46      0.67      0.43     -0.63      1.56   6198.23      1.00
  b2_z_[42]     -0.15      0.68     -0.15     -1.27      0.98   7564.69      1.00
  b2_z_[43]      0.15      0.63      0.16     -0.89      1.17   5738.07      1.00
  b2_z_[44]      1.14      0.49      1.16      0.34      1.93   4734.09      1.00
  b2_z_[45]      1.37      0.61      1.39      0.43      2.41   4903.99      1.00
  b2_z_[46]     -0.16      0.76     -0.17     -1.41      1.07   8424.95      1.00
  b2_z_[47]      0.08      0.66      0.10     -1.04      1.13   5549.73      1.00
  b2_z_[48]      0.07      0.64      0.08     -0.99      1.13   7276.88      1.00
  b2_z_[49]     -0.21      0.65     -0.20     -1.26      0.86   7048.58      1.00
  b2_z_[50]     -0.72      0.56     -0.70     -1.64      0.16   5802.88      1.00
  b2_z_[51]      0.31      0.72      0.32     -0.82      1.54   8773.60      1.00
  b2_z_mean     -0.39      0.03     -0.39     -0.44     -0.34   2734.47      1.00
   b2_z_std      0.15      0.03      0.15      0.10      0.19   2256.88      1.00
         nu      2.22      0.67      2.08      1.30      3.12   2639.97      1.00
    sigma_z      0.24      0.04      0.24      0.18      0.31   1708.79      1.00

Number of divergences: 0
def choose_credible_parabola_parameters(idata, hdi=0.95, *, a='b0', b='b1', c='b2', n_curves=20, dim=None, rng=None):
    """
    Randomly pick posterior draws of parabola coefficients for plotting.

    The parabola is ``y = a + b * x + c * x**2``:
    - ``a`` names the intercept variable.
    - ``b`` names the linear coefficient.
    - ``c`` names the quadratic coefficient.

    A draw is eligible only if all three of its coefficients lie inside
    their own marginal ``hdi`` highest-density intervals.

    Parameters
    ----------
    idata : arviz.InferenceData
        Must have a ``posterior`` group containing variables ``a``, ``b``
        and ``c``.
    hdi : float
        Probability mass of the HDI, passed to ``az.hdi`` as ``hdi_prob``.
    a, b, c : str
        Posterior variable names of the intercept, linear and quadratic
        coefficients.
    n_curves : int
        Maximum number of draws to yield. Fewer are yielded when fewer
        draws fall inside all three HDIs.
    dim : dict, optional
        Label-based selector passed to ``.sel`` on every variable (e.g.
        ``dict(state=3)``). After selection, each variable must have only
        chain/draw dimensions.
    rng : numpy.random.Generator, optional
        Source of randomness for the draw selection. Defaults to the global
        ``np.random`` state.

    Yields
    ------
    tuple
        ``(a_value, b_value, c_value)`` of one selected posterior draw.
    """
    posterior = idata.posterior
    var_names = [a, b, c]
    # Compute HDIs only for the three coefficients. Computing them for every
    # posterior variable (e.g. per-observation means) is wasted work.
    hdi_posterior = az.hdi(posterior, hdi_prob=hdi, var_names=var_names)

    samples = []
    in_all_hdis = None
    for name in var_names:
        values = posterior[name].sel(dim).values.flatten()
        low, high = hdi_posterior[name].sel(dim).values
        in_hdi = (values >= low) & (values <= high)
        in_all_hdis = in_hdi if in_all_hdis is None else in_all_hdis & in_hdi
        samples.append(values)

    idx_in_hdi = np.flatnonzero(in_all_hdis)

    # Sampling without replacement cannot request more items than exist, so
    # cap the count instead of raising when few draws pass the HDI filter.
    n_chosen = min(n_curves, idx_in_hdi.size)
    choose = np.random.choice if rng is None else rng.choice
    for idx in choose(idx_in_hdi, n_chosen, replace=False):
        yield tuple(values[idx] for values in samples)


# Convert the quadratic-trend fit to InferenceData. A `state` dimension is
# attached to the per-state coefficients so they can be selected by index.
n_state_labels = len(income_data_3yr['State'].cat.categories)
idata = az.from_numpyro(
    mcmc,
    coords={'state': np.arange(n_state_labels)},
    dims={name: ['state']
          for name in ('b0', 'b1', 'b2', 'b0_z', 'b1_z', 'b2_z')},
)

n_posterior_curves = 20
x = np.linspace(1, 7.5, 1000)

# Overlay credible group-level parabolas on the state-income axes created
# earlier (`ax` still refers to that plot at this point).
credible_means = choose_credible_parabola_parameters(
    idata, a='b0_mean', b='b1_mean', c='b2_mean', n_curves=n_posterior_curves)
for intercept, linear, quadratic in credible_means:
    ax.plot(x, intercept + linear * x + quadratic * x**2, c='b')

fig
# One panel per state: credible state-level parabolas with that state's data.
# The grid is sized from the number of states so that none is silently
# dropped. The previous fixed 5x5 grid, zipped with all states, only ever
# showed the first 25 of the 52 groups.
state_categories = income_data_3yr['State'].cat.categories
n_states_to_plot = state_categories.size  # lower this to show a subset
n_cols = 5
n_rows = -(-n_states_to_plot // n_cols)  # ceiling division
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols,
                         figsize=(4 * n_cols, 4 * n_rows), squeeze=False)

for state_idx, (state, ax) in enumerate(
        zip(state_categories[:n_states_to_plot], axes.flat)):
    # Plot posterior curves. `state_idx` matches the integer `state` coordinate
    # attached to the per-state coefficients in `idata`.
    for b0, b1, b2 in choose_credible_parabola_parameters(
            idata, a='b0', b='b1', c='b2', dim=dict(state=state_idx)):
        ax.plot(x, b0 + b1 * x + b2 * x**2, c='b')

    # Superimpose state median income.
    state_data = income_data_3yr[income_data_3yr['State'] == state]
    state_data = state_data.sort_values('FamilySize')
    ax.plot(state_data['FamilySize'], state_data['MedianIncome'], 'ko')

    ax.set_title(state)

# Hide the empty panels left over in the last grid row.
for unused_ax in axes.flat[n_states_to_plot:]:
    unused_ax.set_visible(False)

fig.tight_layout()