# Move to the parent directory (the repository root, per the printed path) so
# relative paths such as 'datasets/TwoGroupIQ.csv' resolve.
%cd ..
# Automatically reload edited modules (e.g. numpyro_glm) before executing code.
%load_ext autoreload
%autoreload 2
/home/runner/work/numpyro-doing-bayesian/numpyro-doing-bayesian
import arviz as az
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
import numpy as np
import numpyro
from numpyro.infer import MCMC, NUTS
import numpyro_glm
import numpyro_glm.metric.models as glm_metric
import numpyro_glm.metric.plots as plots
import pandas as pd
from scipy.stats import norm, t

Chapter 16: Metric-Predicted Variable on One or Two Groups

Estimating the Mean and Standard Deviation of a Normal Distribution

Metric Model

# Draw the graphical model of the one-group model; the five ones are only
# placeholder observations used to trace the model structure.
numpyro.render_model(
    glm_metric.one_group,
    model_args=(jnp.ones(5),),
    render_params=True,
)

Synthetic Data

# Ground-truth parameters of the synthetic Normal sample.
MEAN = 5
STD = 3
N = 100
# Fixed seed: the MCMC below already uses a fixed PRNGKey, so seeding the data
# too makes the whole fit reproducible across reruns.
SEED = 0

rng = np.random.default_rng(SEED)
y = rng.normal(MEAN, STD, N)

# Overlay the generating Normal density on a density-scaled histogram of y.
fig, ax = plt.subplots()
ax.hist(y, density=True, label='Histogram of $y$')
# Evaluate the PDF over whatever x-range the histogram ended up using.
x_lo, x_hi = ax.get_xlim()
grid = np.linspace(x_lo, x_hi, 1000)
true_density = norm.pdf(grid, loc=MEAN, scale=STD)
ax.plot(grid, true_density, label='Normal PDF')
ax.legend()
fig.tight_layout()

Now we apply the metric model to that data to see whether it recovers the true parameters.

# Fit the one-group model to the synthetic sample with NUTS.
kernel = NUTS(glm_metric.one_group)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc_key = random.PRNGKey(0)  # fixed key for reproducible sampling
observed = jnp.array(y)
mcmc.run(mcmc_key, observed)
mcmc.print_summary()
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      mean      4.96      0.31      4.97      4.53      5.51    517.16      1.00
       std      3.05      0.22      3.03      2.73      3.45    746.19      1.00

Number of divergences: 0

Plot diagnostics to check whether the MCMC chains are well-behaved.

# Project helper producing diagnostic plots; it returns an arviz.InferenceData,
# which is left as the trailing expression so the notebook displays it.
diagnostic_params = ['mean', 'std']
numpyro_glm.plot_diagnostic(mcmc, diagnostic_params)
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 750)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
      Data variables:
          mean     (chain, draw) float32 5.209 4.529 5.347 5.227 ... 4.773 4.991 4.968
          std      (chain, draw) float32 2.917 3.103 2.868 2.874 ... 2.672 3.41 3.339
      Attributes:
          created_at:                 2023-05-29T23:56:31.752566
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 750, y_dim_0: 100)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
        * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99
      Data variables:
          y        (chain, draw, y_dim_0) float32 -2.681 -1.99 ... -2.229 -2.194
      Attributes:
          created_at:                 2023-05-29T23:56:31.843696
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:    (chain: 1, draw: 750)
      Coordinates:
        * chain      (chain) int64 0
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749
      Data variables:
          diverging  (chain, draw) bool False False False False ... False False False
      Attributes:
          created_at:                 2023-05-29T23:56:31.764392
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:  (y_dim_0: 100)
      Coordinates:
        * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99
      Data variables:
          y        (y_dim_0) float32 1.78 5.272 6.843 6.101 ... 6.247 7.294 6.49 6.21
      Attributes:
          created_at:                 2023-05-29T23:56:31.844516
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

# Posterior summary plots for the fit to the synthetic Normal sample.
fig = plots.plot_st(
    mcmc,
    y,
)

Smart Drug Group IQ

# Load the two-group IQ data, treat Group as categorical, and keep only the
# "Smart Drug" rows; describe() is the trailing expression shown as output.
iq_data = pd.read_csv('datasets/TwoGroupIQ.csv')
iq_data['Group'] = iq_data['Group'].astype('category')
is_smart_drug = iq_data['Group'] == 'Smart Drug'
smart_group_data = iq_data.loc[is_smart_drug]
smart_group_data.describe()
Score
count 63.000000
mean 107.841270
std 25.445201
min 50.000000
25% 96.000000
50% 107.000000
75% 119.000000
max 208.000000

Next, we apply the one-group model to the Smart Drug scores and plot the results.

# Fit the one-group model to the Smart Drug IQ scores with NUTS.
kernel = NUTS(glm_metric.one_group)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc_key = random.PRNGKey(0)  # fixed key for reproducible sampling
smart_scores = jnp.array(smart_group_data.Score.values)
mcmc.run(mcmc_key, smart_scores)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      mean    107.80      3.24    107.89    102.57    113.19    539.80      1.00
       std     26.08      2.35     25.88     22.73     30.35    642.03      1.00

Number of divergences: 0
# Project helper producing diagnostic plots; its returned arviz.InferenceData is
# left as the trailing expression so the notebook displays it.
diagnostic_params = ['mean', 'std']
numpyro_glm.plot_diagnostic(mcmc, diagnostic_params)
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 750)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
      Data variables:
          mean     (chain, draw) float32 110.0 105.3 110.2 109.4 ... 106.8 108.7 108.2
          std      (chain, draw) float32 24.79 26.16 24.03 23.97 ... 22.13 29.33 28.75
      Attributes:
          created_at:                 2023-05-29T23:56:35.414985
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 750, y_dim_0: 63)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
        * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62
      Data variables:
          y        (chain, draw, y_dim_0) float32 -4.181 -4.136 ... -5.051 -4.286
      Attributes:
          created_at:                 2023-05-29T23:56:35.499221
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:    (chain: 1, draw: 750)
      Coordinates:
        * chain      (chain) int64 0
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749
      Data variables:
          diverging  (chain, draw) bool False False False False ... False False False
      Attributes:
          created_at:                 2023-05-29T23:56:35.416760
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:  (y_dim_0: 63)
      Coordinates:
        * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62
      Data variables:
          y        (y_dim_0) int32 102 107 92 101 110 68 119 ... 97 139 72 100 144 112
      Attributes:
          created_at:                 2023-05-29T23:56:35.500056
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

# Reference values the posteriors are compared against (presumably the
# general-population IQ norms of mean 100 / SD 15, and a null effect size).
comparison_values = dict(
    mean_comp_val=100,
    std_comp_val=15,
    effsize_comp_val=0,
)
fig = plots.plot_st(mcmc, smart_group_data.Score.values, **comparison_values)

Outliers and Robust Estimation: $t$ Distribution

Robust Metric Model

# Draw the graphical model of the robust (Student-t) one-group model; the five
# ones are only placeholder observations used to trace the model structure.
numpyro.render_model(
    glm_metric.one_group_robust,
    model_args=(jnp.ones(5),),
    render_params=True,
)

Synthetic Data

# Ground-truth parameters of the synthetic Student-t sample.
MEAN = 5
SIGMA = 3
NORMALITY = 3  # degrees of freedom (nu); small values give heavy tails
N = 1000
# Fixed seed: the MCMC below already uses a fixed PRNGKey, so seeding the data
# too makes the whole fit reproducible across reruns.
SEED = 1

rng = np.random.default_rng(SEED)
# Location-scale transform of a standard Student-t draw.
y = rng.standard_t(NORMALITY, size=N) * SIGMA + MEAN

# Plot the histogram of y against the Student-t PDF that generated it, and
# against a Normal PDF whose mean/std are estimated from the same sample, to
# contrast the heavy tails with a Normal fit.
fig, ax = plt.subplots()
ax.hist(y, density=True, bins=100, label='Histogram of $y$')
xmin, xmax = ax.get_xlim()
p = np.linspace(xmin, xmax, 1000)
ax.plot(p, t.pdf(p, loc=MEAN, scale=SIGMA, df=NORMALITY), label='Student-$t$ PDF')
# y.std() is NumPy's ddof=0 (population) standard deviation.
ax.plot(p, norm.pdf(p, loc=y.mean(), scale=y.std()),
        label='Normal PDF using\ndata mean and std')
ax.legend()
fig.tight_layout()

We apply the robust metric model to our synthetic data to see whether it recovers the original parameters.

# Fit the robust (Student-t) one-group model to the synthetic sample with NUTS.
kernel = NUTS(glm_metric.one_group_robust)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc_key = random.PRNGKey(0)  # fixed key for reproducible sampling
observed = jnp.array(y)
mcmc.run(mcmc_key, observed)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      mean      5.19      0.12      5.19      5.01      5.38    528.65      1.01
        nu      3.02      0.32      3.00      2.48      3.54    389.25      1.00
     sigma      2.97      0.12      2.97      2.78      3.17    351.15      1.00

Number of divergences: 0
# Project helper producing diagnostic plots; its returned arviz.InferenceData is
# left as the trailing expression so the notebook displays it.
diagnostic_params = ['mean', 'sigma', 'nu']
numpyro_glm.plot_diagnostic(mcmc, diagnostic_params)
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 750)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
      Data variables:
          mean     (chain, draw) float32 5.225 5.124 5.198 5.033 ... 5.201 5.154 5.112
          nu       (chain, draw) float32 3.03 2.811 3.201 2.794 ... 3.153 3.204 3.224
          sigma    (chain, draw) float32 3.008 2.894 3.055 2.892 ... 3.009 3.049 3.047
      Attributes:
          created_at:                 2023-05-29T23:56:40.651587
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 750, y_dim_0: 1000)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
        * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999
      Data variables:
          y        (chain, draw, y_dim_0) float32 -4.239 -2.245 ... -2.194 -3.763
      Attributes:
          created_at:                 2023-05-29T23:56:40.806417
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:    (chain: 1, draw: 750)
      Coordinates:
        * chain      (chain) int64 0
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749
      Data variables:
          diverging  (chain, draw) bool False False False False ... False False False
      Attributes:
          created_at:                 2023-05-29T23:56:40.653699
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:  (y_dim_0: 1000)
      Coordinates:
        * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999
      Data variables:
          y        (y_dim_0) float32 -1.97 6.648 -0.6971 5.801 ... 11.52 6.219 -0.8494
      Attributes:
          created_at:                 2023-05-29T23:56:40.807396
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

# Posterior summary plots for the robust fit to the synthetic Student-t sample.
fig = plots.plot_st_2(
    mcmc,
    y,
)

Smart Drug Group Data

# Fit the robust (Student-t) one-group model to the Smart Drug IQ scores.
kernel = NUTS(glm_metric.one_group_robust)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc_key = random.PRNGKey(0)  # fixed key for reproducible sampling
smart_scores = jnp.array(smart_group_data.Score.values)
mcmc.run(mcmc_key, smart_scores)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      mean    107.38      2.96    107.35    102.54    112.11    499.25      1.00
        nu      9.63     13.31      5.73      1.20     19.70    158.39      1.01
     sigma     19.83      3.53     19.78     14.20     25.66    235.03      1.00

Number of divergences: 0
# Project helper producing diagnostic plots; its returned arviz.InferenceData is
# left as the trailing expression so the notebook displays it.
diagnostic_params = ['mean', 'sigma', 'nu']
numpyro_glm.plot_diagnostic(mcmc, diagnostic_params)
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 750)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
      Data variables:
          mean     (chain, draw) float32 105.6 106.3 108.6 104.4 ... 108.3 106.8 104.6
          nu       (chain, draw) float32 6.38 6.304 6.77 2.442 ... 9.179 8.01 13.55
          sigma    (chain, draw) float32 22.7 23.93 23.56 16.29 ... 22.12 23.89 20.27
      Attributes:
          created_at:                 2023-05-29T23:56:45.027781
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 750, y_dim_0: 63)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
        * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62
      Data variables:
          y        (chain, draw, y_dim_0) float32 -4.095 -4.083 ... -5.734 -4.017
      Attributes:
          created_at:                 2023-05-29T23:56:45.160203
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:    (chain: 1, draw: 750)
      Coordinates:
        * chain      (chain) int64 0
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749
      Data variables:
          diverging  (chain, draw) bool False False False False ... False False False
      Attributes:
          created_at:                 2023-05-29T23:56:45.029825
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:  (y_dim_0: 63)
      Coordinates:
        * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62
      Data variables:
          y        (y_dim_0) int32 102 107 92 101 110 68 119 ... 97 139 72 100 144 112
      Attributes:
          created_at:                 2023-05-29T23:56:45.161050
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

# Robust-model posteriors against the same reference values (mean 100,
# sigma 15, null effect size), then pairwise scatter plots of the parameters.
comparison_values = dict(
    mean_comp_val=100,
    sigma_comp_val=15,
    effsize_comp_val=0,
)
fig = plots.plot_st_2(mcmc, smart_group_data.Score.values, **comparison_values)
pair_params = ['mean', 'sigma', 'nu']
fig = numpyro_glm.plot_pairwise_scatter(mcmc, pair_params)

Two Groups

# Fit the robust multi-group model (per-group mean and sigma, one shared nu)
# to all IQ scores, indexing groups by their categorical codes.
scores = jnp.array(iq_data['Score'].values)
group_codes = jnp.array(iq_data['Group'].cat.codes.values)
n_groups = len(iq_data['Group'].cat.categories)

kernel = NUTS(glm_metric.multi_groups_robust)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc_key = random.PRNGKey(0)  # fixed key for reproducible sampling
mcmc.run(mcmc_key, scores, group_codes, n_groups)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
   mean[0]     99.29      1.76     99.20     96.70    102.33    806.66      1.00
   mean[1]    107.16      2.62    107.14    102.90    111.36    910.67      1.00
        nu      3.86      1.68      3.45      1.80      5.73    403.33      1.01
  sigma[0]     11.30      1.74     11.10      8.37     13.97    655.02      1.00
  sigma[1]     17.91      2.72     17.71     13.80     22.56    534.41      1.00

Number of divergences: 0
# Project helper producing diagnostic plots; its returned arviz.InferenceData is
# left as the trailing expression so the notebook displays it.
diagnostic_params = ['mean', 'sigma', 'nu']
numpyro_glm.plot_diagnostic(mcmc, diagnostic_params)
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:      (chain: 1, draw: 750, mean_dim_0: 2, sigma_dim_0: 2)
      Coordinates:
        * chain        (chain) int64 0
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 743 744 745 746 747 748 749
        * mean_dim_0   (mean_dim_0) int64 0 1
        * sigma_dim_0  (sigma_dim_0) int64 0 1
      Data variables:
          mean         (chain, draw, mean_dim_0) float32 98.21 108.0 ... 94.3 103.5
          nu           (chain, draw) float32 4.904 4.618 2.711 ... 2.909 2.738 2.962
          sigma        (chain, draw, sigma_dim_0) float32 11.06 19.12 ... 10.23 17.3
      Attributes:
          created_at:                 2023-05-29T23:56:52.320366
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 750, y_dim_0: 120)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 742 743 744 745 746 747 748 749
        * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 112 113 114 115 116 117 118 119
      Data variables:
          y        (chain, draw, y_dim_0) float32 -3.979 -3.922 ... -3.338 -5.223
      Attributes:
          created_at:                 2023-05-29T23:56:52.493138
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:    (chain: 1, draw: 750)
      Coordinates:
        * chain      (chain) int64 0
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 742 743 744 745 746 747 748 749
      Data variables:
          diverging  (chain, draw) bool False False False False ... False False False
      Attributes:
          created_at:                 2023-05-29T23:56:52.322597
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

    • <xarray.Dataset>
      Dimensions:  (y_dim_0: 120)
      Coordinates:
        * y_dim_0  (y_dim_0) int64 0 1 2 3 4 5 6 7 ... 112 113 114 115 116 117 118 119
      Data variables:
          y        (y_dim_0) int32 102 107 92 101 110 68 119 ... 82 138 99 93 93 72
      Attributes:
          created_at:                 2023-05-29T23:56:52.493967
          arviz_version:              0.12.1
          inference_library:          numpyro
          inference_library_version:  0.9.2

Plot the resulting posteriors.

# One posterior figure per group. Posterior index 0 is paired with Placebo and
# index 1 with Smart Drug (presumably the alphabetical category order produced
# by .astype('category') — confirm if the categories ever change).
group_panels = (
    (0, 'Placebo', 'Placebo Posteriors'),
    (1, 'Smart Drug', 'Smart Drug Posteriors'),
)
for group_pos, group_name, panel_title in group_panels:
    group_scores = iq_data[iq_data['Group'] == group_name]['Score'].values
    fig = plots.plot_st_2(
        mcmc,
        group_scores,
        mean_coords=dict(mean_dim_0=group_pos),
        sigma_coords=dict(sigma_dim_0=group_pos),
        figtitle=panel_title)

Finally, we plot the posterior differences between the two groups: difference of means, difference of scales, and effect size.

# Posterior contrasts "Smart Drug minus Placebo": index 1 is Smart Drug and
# index 0 is Placebo, matching the figure title below.
idata = az.from_numpyro(mcmc)
posteriors = idata.posterior

sigma_placebo = posteriors['sigma'].sel(dict(sigma_dim_0=0))
sigma_smart = posteriors['sigma'].sel(dict(sigma_dim_0=1))

mean_difference = (posteriors['mean'].sel(dict(mean_dim_0=1))
                   - posteriors['mean'].sel(dict(mean_dim_0=0)))
sigma_difference = sigma_smart - sigma_placebo
sigmas_squared = sigma_smart**2 + sigma_placebo**2
# Kruschke's effect size standardizes by the root of the *average* group
# variance, sqrt((sigma[1]^2 + sigma[0]^2) / 2). Dividing by the root of the
# plain sum would shrink every effect size by a factor of sqrt(2).
effect_size = mean_difference / np.sqrt(sigmas_squared / 2)


def _plot_difference(ax, values, rope, title, xlabel):
    """Draw a posterior histogram with 95% HDI, mode, ROPE and a zero reference on ax."""
    az.plot_posterior(
        values,
        hdi_prob=0.95,
        ref_val=0,
        rope=rope,
        point_estimate='mode',
        kind='hist',
        ax=ax)
    ax.set_title(title)
    ax.set_xlabel(xlabel)


fig, axes = plt.subplots(ncols=3, figsize=(18, 6))
fig.suptitle('Difference between Smart Drug and Placebo')

# Raw strings: '\m' and '\s' are invalid Python escapes (SyntaxWarning on 3.12+).
_plot_difference(axes[0], mean_difference, (-0.5, 0.5),
                 'Difference of Means', r'$\mu[1] - \mu[0]$')
_plot_difference(axes[1], sigma_difference, (-0.5, 0.5),
                 'Difference of Scales', r'$\sigma[1] - \sigma[0]$')
_plot_difference(axes[2], effect_size, (-0.1, 0.1),
                 'Effect Size',
                 r'$(\mu[1] - \mu[0]) / \sqrt{(\sigma[1]^2 + \sigma[0]^2) / 2}$')

fig.tight_layout()