neuralGAM is a fully explainable deep learning framework based on Generalized Additive Models (GAMs), which trains a separate neural network to estimate the contribution of each feature to the response variable.
The networks are trained independently, leveraging the local scoring and backfitting algorithms to ensure that the resulting Generalized Additive Model converges and remains additive. The result is a highly accurate and interpretable deep learning model, suitable for high-risk AI applications where decision-making must rely on accountable and interpretable algorithms.
The full methodology for training Generalized Additive Models with deep neural networks is published in the following paper:
Ortega-Fernandez, I., Sestelo, M. & Villanueva, N.M. Explainable generalized additive neural networks with independent neural network training. Statistics and Computing 34, 6 (2024). https://doi.org/10.1007/s11222-023-10320-5
A Python implementation is also available on GitHub.
neuralGAM is built on deep neural networks and depends on the Tensorflow and Keras packages. Therefore, a working Python (>3.10) installation with those packages is required.
We provide a helper function to set up a working Python installation from RStudio, which creates a miniconda environment with all the required packages:
library(neuralGAM)
install_neuralGAM()
In the following example, we use synthetic data to showcase the performance of neuralGAM by fitting a model with a single hidden layer of 128 units per network.
n <- 5000
seed <- 42
set.seed(seed)

x1 <- runif(n, -2.5, 2.5)
x2 <- runif(n, -2.5, 2.5)
x3 <- runif(n, -2.5, 2.5)

f1 <- x1**2
f2 <- 2 * x2
f3 <- sin(x3)
f1 <- f1 - mean(f1)
f2 <- f2 - mean(f2)
f3 <- f3 - mean(f3)

eta0 <- 2 + f1 + f2 + f3
y <- eta0 + rnorm(n, 0.25)
train <- data.frame(x1, x2, x3, y)
ngam <- neuralGAM(
  y ~ s(x1) + x2 + s(x3),
  data = train,
  num_units = 128, family = "gaussian",
  activation = "relu",
  learning_rate = 0.001, bf_threshold = 0.001,
  max_iter_backfitting = 10, max_iter_ls = 10,
  uncertainty_method = "epistemic", forward_passes = 100,
  seed = seed
)
summary(ngam)
You can then use the plot function to visualize the learnt partial effects:
plot(ngam)
Or use the custom autoplot function for more advanced graphics based on the ggplot2 library, including confidence intervals (if available):
autoplot(ngam, which="terms", term = "x1", interval = "confidence")
To obtain predictions from new data, use the predict function:
n <- 5000
x1 <- runif(n, -2.5, 2.5)
x2 <- runif(n, -2.5, 2.5)
x3 <- runif(n, -2.5, 2.5)

test <- data.frame(x1, x2, x3)
# Obtain the linear predictor
eta <- predict(ngam, newdata = test, type = "link")

# Obtain the predicted response; se.fit = TRUE also returns standard errors
yhat <- predict(ngam, newdata = test, type = "response", se.fit = TRUE)
head(yhat$fit)
head(yhat$se.fit)

# Obtain each component of the linear predictor
terms <- predict(ngam, newdata = test, type = "terms")

# Obtain only certain terms:
terms <- predict(ngam, newdata = test, type = "terms", terms = c("x1", "x2"))
Since version 2.0, neuralGAM allows specifying hyperparameters per smooth term inside s(), overriding the global defaults, and summary() now prints each smooth term's configuration:
ngam <- neuralGAM(
  y ~ s(x1, num_units = 32) + x2 + s(x3, activation = "tanh"),
  data = train,
  num_units = 64, # default for terms without explicit num_units
  seed = seed
)
Per-term configuration supports custom initializers and regularizers for both weights and biases, enabling fine control over model complexity and stability. For example, you can configure one of the neural networks to use L2 regularization and He initialization by passing Keras functions directly (e.g. keras::regularizer_l2()):
ngam <- neuralGAM(
  y ~ s(
    x1,
    kernel_initializer = keras::initializer_he_normal(),
    bias_initializer = keras::initializer_zeros(),
    kernel_regularizer = keras::regularizer_l2(0.01),
    bias_regularizer = keras::regularizer_l1(0.001)
  ) + s(x2),
  data = train,
  num_units = 64,
  activation = "relu",
  seed = seed
)
The summary() now prints each smooth term's configuration along with the essential parameters of each network's architecture.
Enable confidence intervals by setting uncertainty_method and specifying a significance level via alpha (e.g. 0.05 for 95% coverage). For epistemic uncertainty, forward_passes > 150 is recommended.
ngam <- neuralGAM(
  y ~ s(x1) + s(x2),
  data = train,
  uncertainty_method = "epistemic",
  forward_passes = 100,
  alpha = 0.05,
  num_units = 1024,
  seed = seed
)

pred <- predict(ngam, newdata = test, type = "response")
head(pred)
You can monitor the validation loss during training using the validation_split parameter, and then visualize how the loss evolves per backfitting iteration using the plot_history() function.
ngam <- neuralGAM(y ~ s(x1) + x2 + s(x3),
  data = train,
  num_units = 1024, family = "gaussian",
  activation = "relu",
  learning_rate = 0.001, bf_threshold = 0.001,
  max_iter_backfitting = 10, max_iter_ls = 10,
  validation_split = 0.2,
  seed = seed)
# Plot loss per backfitting iteration
plot_history(ngam)
plot_history(ngam, select = "x1") # Plot just x1
plot_history(ngam, metric = "val_loss") # Plot only validation loss
Moreover, after fitting a neuralGAM model, it is important to evaluate whether the model assumptions are reasonable and whether predictions are well calibrated. The helper function diagnose() provides a 2×2 diagnostic panel similar to gratia::appraise() for mgcv models.
The four panels are:
QQ plot of residuals (top-left): compares sample residuals to theoretical quantiles. A straight line indicates a good fit; deviations suggest skewness, heavy tails, or outliers.
Residuals vs linear predictor η (top-right): shows residuals against the fitted linear predictor, with a LOESS smoother. A flat trend near 0 is ideal; systematic curvature means the model missed a trend, and funnel shapes suggest heteroscedasticity.
Histogram of residuals (bottom-left): displays the distribution of residuals, ideally symmetric and centered at 0. Skewness or multimodality may indicate model misspecification.
Observed vs fitted (bottom-right): compares predicted values with observed outcomes. For continuous data, points should align with the 45° line; for binary outcomes, this acts as a calibration plot, where predicted probabilities should match observed frequencies.
Notes
residual_type can be deviance (default), pearson, or quantile. Quantile (Dunn–Smyth) residuals are recommended for discrete families (binomial, Poisson) because they are continuous and approximately normal.
qq_method controls the reference quantiles:
uniform: fast, the default.
simulate: most faithful and provides bands, but slower.
normal: a fallback.
Together, these diagnostic plots help assess whether residuals behave like noise, whether systematic trends remain, and whether predictions are unbiased and calibrated.
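A minimal sketch of how these options combine in a diagnose() call, using the diagnose() helper and the residual_type and qq_method arguments described above (exact defaults may differ from this sketch):

```r
# Default 2x2 diagnostic panel (deviance residuals)
diagnose(ngam)

# Quantile (Dunn-Smyth) residuals with simulated reference quantiles,
# recommended for discrete families such as binomial or Poisson
diagnose(ngam, residual_type = "quantile", qq_method = "simulate")
```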
If you use neuralGAM in your research, please cite the following paper:
Ortega-Fernandez, I., Sestelo, M. & Villanueva, N.M. Explainable generalized additive neural networks with independent neural network training. Statistics and Computing 34, 6 (2024). https://doi.org/10.1007/s11222-023-10320-5
@article{ortega2024explainable,
author = {Ortega-Fernandez, Ines and Sestelo, Marta and Villanueva, Nora M},
doi = {10.1007/s11222-023-10320-5},
issn = {1573-1375},
journal = {Statistics and Computing},
number = {1},
pages = {6},
title = {{Explainable generalized additive neural networks with independent neural network training}},
url = {https://doi.org/10.1007/s11222-023-10320-5},
volume = {34},
year = {2023}
}