HBR with SHASH likelihood
=========================

Welcome to this tutorial notebook, which walks through fitting and evaluating normative models with a Hierarchical Bayesian Regression (HBR) model using a SHASH likelihood.

Let’s jump right in.

Imports
~~~~~~~

.. code:: ipython3

    import warnings
    import logging

    import pandas as pd
    import matplotlib.pyplot as plt
    from pcntoolkit import (
        HBR,
        BsplineBasisFunction,
        NormativeModel,
        NormData,
        load_fcon1000,
        SHASHbLikelihood,
        make_prior,
        plot_centiles_advanced,
        plot_qq,
        plot_ridge,
    )

    import numpy as np
    import pcntoolkit.util.output
    import seaborn as sns

    sns.set_style("darkgrid")

    # Suppress some annoying warnings and logs
    pymc_logger = logging.getLogger("pymc")
    pymc_logger.setLevel(logging.WARNING)
    pymc_logger.propagate = False

    warnings.simplefilter(action="ignore", category=FutureWarning)
    pd.options.mode.chained_assignment = None  # default='warn'
    pcntoolkit.util.output.Output.set_show_messages(False)

Load data
---------

First we download a small example dataset from GitHub.

.. code:: ipython3

    # Download an example dataset
    norm_data: NormData = load_fcon1000()

    # Select only a few features
    features_to_model = [
        "WM-hypointensities",
        "Right-Lateral-Ventricle",
        "Right-Amygdala",
        "CortexVol",
    ]
    norm_data = norm_data.sel({"response_vars": features_to_model})

    # Split into train and test sets
    train, test = norm_data.train_test_split()

.. code:: ipython3

    # Visualize the data
    feature_to_plot = features_to_model[0]
    df = train.to_dataframe()
    fig, ax = plt.subplots(1, 2, figsize=(15, 5))

    sns.countplot(data=df, y=("batch_effects", "site"), hue=("batch_effects", "sex"), ax=ax[0], orient="h")
    ax[0].legend(title="Sex")
    ax[0].set_title("Count of sites")
    # The bars are horizontal (orient="h"): counts run along x, sites along y
    ax[0].set_xlabel("Count")
    ax[0].set_ylabel("Site")

    sns.scatterplot(
        data=df,
        x=("X", "age"),
        y=("Y", feature_to_plot),
        hue=("batch_effects", "site"),
        style=("batch_effects", "sex"),
        ax=ax[1],
    )
    ax[1].legend([], [])
    ax[1].set_title(f"Scatter plot of age vs {feature_to_plot}")
    ax[1].set_xlabel("Age")
    ax[1].set_ylabel(feature_to_plot)

    plt.show()

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_6_0.png

Creating a Normative model
--------------------------

A normative model has a regression model for each response variable. We provide a template regression model, which is copied for each response variable.

A template regression model can be anything that extends the ``RegressionModel`` class. We provide a number of built-in regression models, but you can also create your own.

Here we use the ``HBR`` class, which implements a Hierarchical Bayesian Regression model.

Likelihoods
~~~~~~~~~~~

``HBR`` models are composed of a likelihood and a number of priors on the parameters of the likelihood.

The PCNtoolkit offers a number of likelihood functions:

1. NormalLikelihood: Good for modeling data that is (approximately) normally distributed.
2. SHASHbLikelihood: Good for modeling data that is heavily skewed or heavy-tailed.
3. BetaLikelihood: Good for modeling data that is bounded, e.g. between 0 and 1.

Likelihood parameters
~~~~~~~~~~~~~~~~~~~~~

Each of these likelihoods takes its own set of parameters, and for each parameter we have to set a prior:

1. NormalLikelihood:

   - ``mu``: The mean of the normal distribution.
   - ``sigma``: The standard deviation of the normal distribution.

2. SHASHbLikelihood:

   - ``mu``: The mean of the SHASH distribution.
   - ``sigma``: The standard deviation of the SHASH distribution.
   - ``epsilon``: The skewness parameter of the SHASH distribution.
   - ``delta``: The tail thickness (or kurtosis) of the SHASH distribution.

3. BetaLikelihood:

   - ``alpha``: The first shape parameter of the beta distribution.
   - ``beta``: The second shape parameter of the beta distribution.

Configuring likelihood parameters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Each likelihood parameter needs to be configured. The defaults should work reasonably well for most cases, at least where the data is standardized.

Here’s a quick guide to configuring the likelihood parameters yourself, using the ``make_prior`` function.

1. Is your parameter a function of the covariates? If so, set the ``linear`` parameter to ``True``.

   1. You can then choose the basis expansion to use for the parameter: BsplineBasisFunction, LinearBasisFunction, or PolynomialBasisFunction.
   2. Also, determine whether the slope and intercept of the prior have a random effect or not.

   Here’s an example of a linear prior with a B-spline basis expansion and a random effect in the intercept.

   .. code:: python

       mu = make_prior(
           'mu',
           linear=True,
           basis_function=BsplineBasisFunction(degree=3, nknots=5),
           intercept=make_prior('intercept_mu', random=True),
       )

2. If your parameter is not a function of the covariates, you have to decide whether the parameter itself has a random effect or not. Here’s an example of a prior with a random effect.

   .. code:: python

       epsilon = make_prior('epsilon', random=True)

3. Some parameters (such as sigma) need to be strictly positive, which we can enforce with a mapping. Here’s an example of a prior with a mapping to the positive real line.

   .. code:: python

       # The mapping_params are (horizontal shift, scaling, vertical shift)
       sigma = make_prior('sigma', mapping='softplus', mapping_params=(0, 5, 0))

.. code:: ipython3

    # Mini demo of the mapping params
    xsp = np.linspace(-7, 7, 100)
    softplus = lambda x: np.log(1 + np.exp(x))
    # (a, b, c) = (horizontal shift, scaling, vertical shift)
    parameterized_softplus = lambda x, a, b, c: softplus((x - a) / b) * b + c
    plt.plot(xsp, parameterized_softplus(xsp, 0, 1, 0), label="no mapping")
    plt.plot(xsp, parameterized_softplus(xsp, 1.5, 1, 0), label="horizontal shift of 1.5")
    plt.plot(xsp, parameterized_softplus(xsp, 0, 1, 1), label="vertical shift of 1")
    plt.plot(xsp, parameterized_softplus(xsp, 0, 2, 0), label="scale with a factor of 2")
    plt.legend()
    plt.show()

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_8_0.png

4. Any parameter that is not linear (i.e., not a function of the covariates) can be further configured with ``dist_name`` and ``dist_params``. Here’s an example of a prior with a gamma distribution.

   .. code:: python

       alpha = make_prior('alpha', dist_name='gamma', dist_params=(1, 1))

We currently support the following distributions:

- Normal
- HalfNormal
- LogNormal
- Uniform
- Gamma

The order of the entries in ``dist_params`` matters: it follows the order of the parameters of the corresponding distribution in PyMC.

Creating a HBR model
~~~~~~~~~~~~~~~~~~~~

Here’s a thoroughly commented example of an HBR model with a SHASH likelihood, which we will use to model our response variables.

.. code:: ipython3

    # The SHASHb likelihood is a bit more flexible than the Normal likelihood, and takes four parameters: mu, sigma, epsilon, and delta.
    # Mu and sigma fulfill the same role as in the Normal likelihood, namely the mean and standard deviation of the distribution.
    # Epsilon and delta are parameters that control the skewness and kurtosis of the distribution.

    # SHASHb model in which epsilon and delta do not depend on the covariates
    mu = make_prior(
        # Mu is linear because we want to allow the mean to vary as a function of the covariates.
        linear=True,
        # The slope coefficients are assumed to be normally distributed, with a mean of 0 and a standard deviation of 10.
        slope=make_prior(dist_name="Normal", dist_params=(0.0, 10.0)),
        # The intercept is random, because we expect the intercept to vary between sites and sexes.
        intercept=make_prior(
            random=True,
            # Mu is the mean of the intercept, which is normally distributed with a mean of 0 and a standard deviation of 1.
            mu=make_prior(dist_name="Normal", dist_params=(0.0, 1.0)),
            # Sigma is the scale at which the intercepts vary. It is a positive parameter, so we have to map it to the positive domain.
            sigma=make_prior(dist_name="Normal", dist_params=(0.0, 1.0), mapping="softplus", mapping_params=(0.0, 3.0)),
        ),
        # We use a B-spline basis function to allow for non-linearity in the mean.
        basis_function=BsplineBasisFunction(basis_column=0, nknots=5, degree=3),
    )
    sigma = make_prior(
        # Sigma is also linear, because we want to allow the standard deviation to vary as a function of the covariates: heteroskedasticity.
        linear=True,
        # The slope coefficients are assumed to be normally distributed, with a mean of 0 and a standard deviation of 2.
        slope=make_prior(dist_name="Normal", dist_params=(0.0, 2.0)),
        # The intercept is not random, because we assume the intercept of the standard deviation to be the same for all sites and sexes.
        intercept=make_prior(dist_name="Normal", dist_params=(1.0, 1.0)),
        # We use a B-spline basis function to allow for non-linearity in the standard deviation.
        basis_function=BsplineBasisFunction(basis_column=0, nknots=5, degree=3),
        # We use a softplus mapping to ensure that sigma is strictly positive.
        mapping="softplus",
        # We scale the softplus mapping by a factor of 3, to avoid spikes in the resulting density.
        # The parameters (a, b, c) provided to a mapping f are used as: f_abc(x) = f((x - a) / b) * b + c
        # This basically provides an affine transformation of the softplus function.
        # a -> horizontal shift
        # b -> scaling
        # c -> vertical shift
        # You can leave c out, and it will default to 0.
        mapping_params=(0.0, 3.0),
    )

    epsilon = make_prior(
        # Epsilon is assumed to follow a normal distribution, with a mean of 0 and a standard deviation of 1.
        dist_name="Normal",
        dist_params=(0.0, 1.0),
    )

    delta = make_prior(
        # Delta is sampled from a normal distribution, with a mean of 1 and a standard deviation of 1, and then mapped to the positive real line using a softplus function.
        dist_name="Normal",
        dist_params=(1.0, 1.0),
        mapping="softplus",  # We apply a softplus mapping to the delta parameter, to ensure that it is strictly positive.
        mapping_params=(
            0.0,  # Horizontal shift
            3.0,  # Scale for smoothness
            0.6,  # We need to provide a vertical shift as well, because the SHASH mapping goes a bit wild with low values for delta
        ),
    )

    shashb1_regression_model = HBR(
        name="template",
        cores=16,
        progressbar=True,
        draws=1500,
        tune=500,
        chains=4,
        nuts_sampler="nutpie",
        likelihood=SHASHbLikelihood(mu, sigma, epsilon, delta),
    )

After specifying the regression model, we can configure a normative model.

A normative model has a number of configuration options:

- ``savemodel``: Whether to save the model after fitting.
- ``evaluate_model``: Whether to evaluate the model after fitting.
- ``saveresults``: Whether to save the results after evaluation.
- ``saveplots``: Whether to save the plots after fitting.
- ``save_dir``: The directory to save the model, results, and plots.
- ``inscaler``: The scaler to use for the input data.
- ``outscaler``: The scaler to use for the output data.

.. code:: ipython3

    model = NormativeModel(
        # The regression model to use for the normative model.
        template_regression_model=shashb1_regression_model,
        # Whether to save the model after fitting.
        savemodel=True,
        # Whether to evaluate the model after fitting.
        evaluate_model=True,
        # Whether to save the results after evaluation.
        saveresults=True,
        # Whether to save the plots after fitting.
        saveplots=False,
        # The directory to save the model, results, and plots.
        save_dir="resources/hbr_SHASH/save_dir",
        # The scaler to use for the input data. One of "standardize", "minmax", "robminmax", "none"
        inscaler="standardize",
        # The scaler to use for the output data. One of "standardize", "minmax", "robminmax", "none"
        outscaler="standardize",
    )

Fit the model
-------------

With all that configured, we can fit the model.

The ``fit_predict`` function will fit the model, evaluate it, and save the results and plots (if so configured).

After that, it will compute Z-scores and centiles for the test set.

All results can be found in the save directory.

.. code:: ipython3

    model.fit_predict(train, test)

.. raw:: html

    Sampler progress for the four sampling runs. Each run reports: Total Chains: 4, Active Chains: 0, Finished Chains: 4, Sampling for 2 minutes.

    Run  Chain  Draws  Divergences  Step Size  Gradients/Draw
    1    1      2000   0            0.01       511
    1    2      2000   0            0.01       1023
    1    3      2000   0            0.02       1023
    1    4      2000   0            0.01       767
    2    1      2000   0            0.01       1023
    2    2      2000   0            0.01       767
    2    3      2000   0            0.01       511
    2    4      2000   0            0.01       1023
    3    1      2000   0            0.02       511
    3    2      2000   0            0.01       1023
    3    3      2000   0            0.01       255
    3    4      2000   0            0.01       1023
    4    1      2000   0            0.02       511
    4    2      2000   0            0.02       255
    4    3      2000   0            0.02       127
    4    4      2000   0            0.01       511

.. raw:: html
<xarray.NormData> Size: 98kB
    Dimensions:            (observations: 216, response_vars: 4, covariates: 1,
                            batch_effect_dims: 2, centile: 5, statistic: 11)
    Coordinates:
      * observations       (observations) int64 2kB 756 769 692 616 ... 751 470 1043
      * response_vars      (response_vars) <U23 368B 'WM-hypointensities' ... 'Co...
      * covariates         (covariates) <U3 12B 'age'
      * batch_effect_dims  (batch_effect_dims) <U4 32B 'sex' 'site'
      * centile            (centile) float64 40B 0.05 0.25 0.5 0.75 0.95
      * statistic          (statistic) <U8 352B 'EXPV' 'MACE' ... 'SMSE' 'ShapiroW'
    Data variables:
        subject_ids        (observations) object 2kB 'Munchen_sub96752' ... 'Quee...
        Y                  (observations, response_vars) float64 7kB 2.721e+03 .....
        X                  (observations, covariates) float64 2kB 63.0 ... 23.0
        batch_effects      (observations, batch_effect_dims) <U17 29kB 'F' ... 'Q...
        Z                  (observations, response_vars) float64 7kB 0.5221 ... -...
        centiles           (centile, observations, response_vars) float64 35kB 1....
        logp               (observations, response_vars) float64 7kB -1.623 ... -...
        Yhat               (observations, response_vars) float64 7kB 2.362e+03 .....
        statistics         (response_vars, statistic) float64 352B 0.3927 ... 0.995
    Attributes:
        real_ids:                       True
        is_scaled:                      False
        name:                           fcon1000_test
        unique_batch_effects:           {np.str_('sex'): [np.str_('F'), np.str_('...
        batch_effect_counts:            defaultdict(<function NormData.register_b...
        covariate_ranges:               {np.str_('age'): {'mean': np.float64(28.2...
        batch_effect_covariate_ranges:  {np.str_('sex'): {np.str_('F'): {np.str_(...
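
The ``centile`` coordinate in the output above lists 0.05, 0.25, 0.5, 0.75 and 0.95. As a rough illustration of how centiles and Z-scores relate, here is a minimal sketch that assumes the Z-scores of a well-calibrated model follow a standard normal distribution. It uses only the Python standard library; ``standard_normal`` and ``z_to_centile`` are hypothetical helper names, not part of the PCNtoolkit API.

.. code:: python

    from statistics import NormalDist

    # Assumption: under a well-calibrated model, Z-scores are standard normal
    standard_normal = NormalDist(mu=0.0, sigma=1.0)

    # The centiles reported in the NormData output above
    centiles = [0.05, 0.25, 0.5, 0.75, 0.95]

    # The Z-score that corresponds to each centile (inverse normal CDF)
    z_at_centile = [standard_normal.inv_cdf(c) for c in centiles]


    def z_to_centile(z):
        """Map a Z-score back to the centile it sits on (normal CDF)."""
        return standard_normal.cdf(z)


    for c, z in zip(centiles, z_at_centile):
        print(f"centile {c:.2f} -> Z = {z:+.3f}")

Under this assumption, the 5th and 95th centiles sit at roughly Z = -1.645 and Z = +1.645, which is a handy reference when reading the centile and QQ plots.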
Plot the results
----------------

The PCNtoolkit offers a number of different plotting functions:

1. plot_centiles: Plot the predicted centiles for a model on top of harmonized scatter data.
2. plot_centiles_advanced: A more advanced version of plot_centiles, with more configuration options, coloring, and conditionals.
3. plot_qq: Plot the QQ-plot of the predicted Z-scores.
4. plot_ridge: Plot density plots of the predicted Z-scores.

Let’s start with the centiles.

.. code:: ipython3

    plot_centiles_advanced(
        model,
        centiles=[0.05, 0.5, 0.95],  # Plot these centiles, the default is [0.05, 0.25, 0.5, 0.75, 0.95]
        scatter_data=train,  # Scatter this data along with the centiles
        batch_effects={"site": ["Beijing_Zang", "AnnArbor_a"], "sex": ["M"]},  # Highlight these groups
        show_other_data=True,  # Scatter data not in those groups as smaller black circles
        # Harmonize the scatter data: 'remove' the batch effects from the data by simulating
        # what the data would have looked like if all data was from the same batch.
        harmonize_data=True,
        conditionals=[30]
    )

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_16_0.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_16_1.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_16_2.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_16_3.png

Now let’s look at the QQ plots.

.. code:: ipython3

    plot_qq(test, plot_id_line=True)

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_18_0.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_18_1.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_18_2.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_18_3.png

We can also split the QQ plots by batch effects:

.. code:: ipython3

    plot_qq(test, plot_id_line=True, hue_data="sex", split_data="sex")
    sns.set_theme(style="darkgrid", rc={"axes.facecolor": (0, 0, 0, 0)})

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_20_0.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_20_1.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_20_2.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_20_3.png

And finally the ridge plot:

.. code:: ipython3

    plot_ridge(
        train, "Z", split_by="sex"
    )  # We can also show the 'Y' variable, and that will show the marginal distribution of the response variable, per batch effect.

.. parsed-literal::

    /opt/anaconda3/envs/ptk/lib/python3.12/site-packages/seaborn/axisgrid.py:123: UserWarning: Tight layout not applied. tight_layout cannot make Axes height small enough to accommodate all Axes decorations.
      self._figure.tight_layout(*args, **kwargs)
    /opt/anaconda3/envs/ptk/lib/python3.12/site-packages/seaborn/axisgrid.py:123: UserWarning: Tight layout not applied. The bottom and top margins cannot be made large enough to accommodate all Axes decorations.
      self._figure.tight_layout(*args, **kwargs)
    /opt/anaconda3/envs/ptk/lib/python3.12/site-packages/pcntoolkit/util/plotter.py:817: UserWarning: Tight layout not applied. tight_layout cannot make Axes height small enough to accommodate all Axes decorations.
      plt.tight_layout()

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_22_1.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_22_3.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_22_5.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_22_7.png

Evaluation statistics are stored in the NormData object:

.. code:: ipython3

    display(train.get_statistics_df())
    display(test.get_statistics_df())

.. raw:: html

    statistic                    EXPV      MACE      MAPE        MSLL       NLL        R2          RMSE       Rho          Rho_p      SMSE  ShapiroW
    response_vars
    CortexVol                0.519196  0.010719  0.057414  -11.208387  1.072185  0.519179  36152.571241  0.710485  2.199838e-133  0.480821  0.995229
    Right-Amygdala           0.383195  0.013689  0.088632   -5.738379  1.181817  0.383064    192.435690  0.599840   2.338456e-85  0.616936  0.995682
    Right-Lateral-Ventricle  0.214931  0.013968  0.400007   -8.643762  1.031061  0.213011   3415.620496  0.371987   1.116620e-29  0.786989  0.987490
    WM-hypointensities       0.350038  0.014664  0.320848   -7.437633  0.686587  0.349278    658.816151  0.488287   7.601376e-53  0.650722  0.986337

.. raw:: html

    statistic                    EXPV      MACE      MAPE        MSLL       NLL        R2          RMSE       Rho         Rho_p      SMSE  ShapiroW
    response_vars
    CortexVol                0.421982  0.026852  0.063421  -11.112723  1.104837  0.413094  37502.992757  0.641003  2.177593e-26  0.586906  0.995037
    Right-Amygdala           0.302032  0.017963  0.094731   -5.663263  1.219263  0.300213    197.373557  0.524081  1.229031e-16  0.699787  0.991782
    Right-Lateral-Ventricle  0.226325  0.015185  0.431631   -8.528336  1.162170  0.226195   3440.421655  0.262138  9.672659e-05  0.773805  0.984698
    WM-hypointensities       0.392663  0.018519  0.347739   -7.076143  0.748407  0.383796    475.098802  0.512638  7.123854e-16  0.616204  0.980840

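
To make a few of the columns above more concrete, here is a small sketch of how RMSE, SMSE, R2 and EXPV are conventionally defined. These are the textbook definitions and are an assumption on our part: PCNtoolkit's own implementation may differ in detail. The function names and the toy data are hypothetical and only serve as an illustration.

.. code:: python

    import math
    from statistics import fmean, pvariance


    def rmse(y, yhat):
        """Root mean squared error, in the units of the response variable."""
        return math.sqrt(fmean((a - b) ** 2 for a, b in zip(y, yhat)))


    def smse(y, yhat):
        """Standardized mean squared error: MSE divided by the variance of y."""
        mse = fmean((a - b) ** 2 for a, b in zip(y, yhat))
        return mse / pvariance(y)


    def r2(y, yhat):
        """Coefficient of determination; with the definitions above, R2 = 1 - SMSE."""
        return 1.0 - smse(y, yhat)


    def expv(y, yhat):
        """Explained variance: 1 - Var(residuals) / Var(y)."""
        residuals = [a - b for a, b in zip(y, yhat)]
        return 1.0 - pvariance(residuals) / pvariance(y)


    # Toy data, purely illustrative
    y = [2.0, 4.0, 6.0, 8.0]
    yhat = [2.5, 3.5, 6.5, 7.5]
    print(rmse(y, yhat), smse(y, yhat), r2(y, yhat), expv(y, yhat))

Under these definitions R2 and SMSE sum to one, which matches every row of the tables above (for example 0.519179 + 0.480821 for CortexVol on the training set). RMSE is reported in the original units of each response variable, so it is not comparable across features.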
What’s next?
------------

Now that we have a normative hierarchical Bayesian regression model, we can use it to:

- Make predictions on new data
- Harmonize data: ‘remove’ the batch effects from the data by simulating what the data would have looked like if all of it had come from the same batch
- Synthesize new data
- Extend the model using data from new batches

Predicting
~~~~~~~~~~

.. code:: ipython3

    model.predict(test)

.. raw:: html
<xarray.NormData> Size: 98kB
    Dimensions:            (observations: 216, response_vars: 4, covariates: 1,
                            batch_effect_dims: 2, statistic: 11, centile: 5)
    Coordinates:
      * observations       (observations) int64 2kB 756 769 692 616 ... 751 470 1043
      * response_vars      (response_vars) <U23 368B 'WM-hypointensities' ... 'Co...
      * covariates         (covariates) <U3 12B 'age'
      * batch_effect_dims  (batch_effect_dims) <U4 32B 'sex' 'site'
      * statistic          (statistic) <U8 352B 'EXPV' 'MACE' ... 'SMSE' 'ShapiroW'
      * centile            (centile) float64 40B 0.05 0.25 0.5 0.75 0.95
    Data variables:
        subject_ids        (observations) object 2kB 'Munchen_sub96752' ... 'Quee...
        Y                  (observations, response_vars) float64 7kB 2.721e+03 .....
        X                  (observations, covariates) float64 2kB 63.0 ... 23.0
        batch_effects      (observations, batch_effect_dims) <U17 29kB 'F' ... 'Q...
        Z                  (observations, response_vars) float64 7kB 0.5221 ... -...
        logp               (observations, response_vars) float64 7kB -1.623 ... -...
        Yhat               (observations, response_vars) float64 7kB 2.362e+03 .....
        statistics         (response_vars, statistic) float64 352B 0.3927 ... 0.995
        centiles           (centile, observations, response_vars) float64 35kB 1....
    Attributes:
        real_ids:                       True
        is_scaled:                      False
        name:                           fcon1000_test
        unique_batch_effects:           {np.str_('sex'): [np.str_('F'), np.str_('...
        batch_effect_counts:            defaultdict(<function NormData.register_b...
        covariate_ranges:               {np.str_('age'): {'mean': np.float64(28.2...
        batch_effect_covariate_ranges:  {np.str_('sex'): {np.str_('F'): {np.str_(...
Harmonize
~~~~~~~~~

.. code:: ipython3

    # Harmonizing is also easy:
    reference_batch_effect = {
        "site": "Beijing_Zang",
        "sex": "M",
    }  # Set a pseudo-batch effect. I.e., this means 'pretend that all data was from this site and sex'

    # The plot below reads the "Y_harmonized" column, so harmonize has to run first
    model.harmonize(train, reference_batch_effect=reference_batch_effect)  # <- easy

    plt.style.use("seaborn-v0_8")
    df = train.to_dataframe()

    fig, ax = plt.subplots(1, 2, figsize=(13, 5), sharey=True)

    sns.scatterplot(data=df, x=("X", "age"), y=("Y", feature_to_plot), hue=("batch_effects", "site"), ax=ax[0])
    sns.scatterplot(data=df, x=("X", "age"), y=("Y_harmonized", feature_to_plot), hue=("batch_effects", "site"), ax=ax[1])
    ax[0].title.set_text("Unharmonized")
    ax[1].title.set_text("Harmonized")
    ax[0].legend([], [])
    ax[1].legend([], [])
    ax[0].set_xlabel("Age")
    ax[0].set_ylabel(feature_to_plot)
    ax[1].set_xlabel("Age")
    ax[1].set_ylabel(feature_to_plot)
    plt.tight_layout()
    plt.show()

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_29_0.png

Synthesize
~~~~~~~~~~

Our models can synthesize new data that follows the learned distribution.

The model learns not only the distribution of the response variables given a covariate, but also the ranges of the covariates *within* each batch effect. So if we have fitted a model on a number of sites, and subjects from site A have an age between 10 and 20, then the synthesized pseudo-subjects from site A will also have an age between 10 and 20.

In addition, the batch effects are sampled at the frequency with which they occur in the original data. So if the train data contained twice as many subjects from site A as from site B, then the synthesized data will also contain about twice as many pseudo-subjects from site A as from site B.

.. code:: ipython3

    # Generate 1000 synthetic datapoints from scratch
    synthetic_data = model.synthesize(covariate_range_per_batch_effect=True, n_samples=1000)  # <- also easy

    # Show the synthetic data along with the centiles
    plot_centiles_advanced(
        model,
        covariate="age",  # Which covariate to plot on the x-axis
        scatter_data=synthetic_data,
    )

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_31_0.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_31_1.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_31_2.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_31_3.png

.. code:: ipython3

    # Synthesize new Y data for existing X data
    new_test_data = test.copy()

    # Remove the Y data, this way we will synthesize new Y data for the existing X data
    if hasattr(new_test_data, "Y"):
        del new_test_data["Y"]

    synthetic = model.synthesize(new_test_data)  # <- will fill in the missing Y data

.. code:: ipython3

    plot_centiles_advanced(
        model,
        centiles=[0.05, 0.5, 0.95],  # Plot arbitrary centiles
        covariate="age",  # Which covariate to plot on the x-axis
        scatter_data=synthetic,  # Scatter the synthetic data points
        batch_effects="all",  # You can set this to "all" to show all batch effects
        show_other_data=False,  # Show data points that do not match any batch effects
        show_centile_labels=True,
        harmonize_data=True,  # Set this to False to see the difference
        show_legend=False,  # Don't show the legend because it crowds the plot
    )

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_33_0.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_33_1.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_33_2.png

.. image:: 04_HBR_SHASH_files/04_HBR_SHASH_33_3.png

Next steps
----------

Please see the other tutorials for more examples. We also recommend reading the documentation! As this toolkit is still in development, the documentation may not be up to date. If you find any issues, please let us know!

Also, feel free to contact us on GitHub if you have any questions or suggestions.

Have fun modeling!

Bonus content
~~~~~~~~~~~~~

Here is another model configuration using a SHASH likelihood, but this one also has a linear regression in epsilon and delta. If you have a feature that is heavily skewed and for which the skewness also changes with the covariates, this is the model for you:

.. code:: ipython3

    # Here's a model with a SHASHb likelihood, with a linear regression in all four parameters, so including epsilon and delta.
    # This is a very flexible model, but it will also take a lot longer to run.

    mu = make_prior(
        linear=True,
        slope=make_prior(dist_name="Normal", dist_params=(0.0, 10.0)),
        intercept=make_prior(
            random=True,
            mu=make_prior(dist_name="Normal", dist_params=(1.0, 1.0)),
            sigma=make_prior(dist_name="Gamma", dist_params=(3.0, 1.0)),
        ),
        basis_function=BsplineBasisFunction(basis_column=0, nknots=5, degree=3),
    )
    sigma = make_prior(
        linear=True,
        slope=make_prior(dist_name="Normal", dist_params=(0.0, 2.0)),
        intercept=make_prior(dist_name="Normal", dist_params=(1.0, 1.0)),
        basis_function=BsplineBasisFunction(basis_column=0, nknots=5, degree=3),
        mapping="softplus",
        mapping_params=(0.0, 3.0),
    )

    epsilon = make_prior(
        linear=True,
        slope=make_prior(dist_name="Normal", dist_params=(0.0, 1.0)),
        intercept=make_prior(dist_name="Normal", dist_params=(0.0, 1.0)),
        basis_function=BsplineBasisFunction(basis_column=0, nknots=5, degree=3),
    )

    delta = make_prior(
        linear=True,
        slope=make_prior(dist_name="Normal", dist_params=(0.0, 1.0)),
        intercept=make_prior(dist_name="Normal", dist_params=(1.0, 1.0)),
        basis_function=BsplineBasisFunction(basis_column=0, nknots=5, degree=3),
        mapping="softplus",
        mapping_params=(
            0.0,
            3.0,  # Scale for smoothness
            0.6,  # We need to provide a vertical shift as well, because the SHASH mapping goes a bit wild with low values for delta
        ),
    )

    shashb2_regression_model = HBR(
        name="template",
        cores=16,
        progressbar=True,
        draws=1500,
        tune=500,
        chains=4,
        nuts_sampler="nutpie",
        likelihood=SHASHbLikelihood(mu, sigma, epsilon, delta),
    )
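
To build some intuition for what epsilon and delta do, here is a minimal sketch of the basic sinh-arcsinh transformation of a standard normal variable, ``x = sinh((arcsinh(z) + epsilon) / delta)``. This is the textbook form and is shown only as an assumption-laden illustration: the SHASHb likelihood in PCNtoolkit additionally involves ``mu`` and ``sigma`` and may be parameterized differently. The names ``shash_transform`` and ``sample_skewness`` are hypothetical helpers, not PCNtoolkit functions.

.. code:: python

    import math
    import random


    def shash_transform(z, epsilon=0.0, delta=1.0):
        """Map a standard-normal draw z through the sinh-arcsinh transformation.

        epsilon shifts mass to one side (skewness); delta controls tail weight.
        With epsilon=0 and delta=1 the transformation is the identity.
        """
        return math.sinh((math.asinh(z) + epsilon) / delta)


    def sample_skewness(xs):
        """Third standardized moment of a sample."""
        n = len(xs)
        mean = sum(xs) / n
        m2 = sum((x - mean) ** 2 for x in xs) / n
        m3 = sum((x - mean) ** 3 for x in xs) / n
        return m3 / m2**1.5


    rng = random.Random(0)  # fixed seed so the sketch is reproducible
    z_draws = [rng.gauss(0.0, 1.0) for _ in range(20000)]

    symmetric = [shash_transform(z) for z in z_draws]  # epsilon=0, delta=1
    right_skewed = [shash_transform(z, epsilon=1.0) for z in z_draws]

    print("skewness with epsilon=0:", sample_skewness(symmetric))
    print("skewness with epsilon=1:", sample_skewness(right_skewed))

With ``epsilon=0`` and ``delta=1`` the draws are left unchanged, so the sample skewness stays close to zero, while a positive ``epsilon`` produces a clearly right-skewed sample. Letting epsilon and delta depend on the covariates, as in the configuration above, lets this shape change with age.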