.. Places parent toc into the sidebar

:parenttoc: True

.. include:: ../../includes/big_toc_css.rst

=================================
Generalized Additive Model
=================================

In a generalized additive model (GAM), the main effect of each feature is modeled by a non-parametric function:

.. math::

   \begin{align}
   g(\mathbb{E}(y|\textbf{x})) = f_{1}(x_{1}) + f_{2}(x_{2}) + \ldots + f_{d}(x_{d}) + \mu. \tag{1}
   \end{align}

In this equation, :math:`g` is the link function, :math:`\mu` is the intercept, and each :math:`f_{j}` is a shape function, i.e., an unknown, smooth transformation of the feature :math:`x_{j}`. Shape functions can be estimated using a variety of methods, including smoothing splines, ensemble trees, or neural networks. Compared with a generalized linear model (GLM), shape functions make a GAM more flexible, and often more predictive, because smooth functions can capture non-linear patterns in the data. To ensure model identifiability, the output of each shape function is assumed to have zero mean.

Model Training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We demonstrate how to train a GAM using the `CaliforniaHousing` dataset, assuming that the data has already been prepared. Next, we import the corresponding estimator, i.e., `GAMRegressor`_ or `GAMClassifier`_, both of which are based on the Python package `pygam`_. As this is a regression problem, we use the `GAMRegressor`_ object with some customized hyperparameter settings.

.. jupyter-input::

    from piml.models import GAMRegressor

    # `exp` is the PiML experiment object with the data already loaded and prepared.
    exp.model_train(model=GAMRegressor(spline_order=1, n_splines=20, lam=0.6), name="GAM")

Below we briefly introduce some of its key hyperparameters.

.. _GAMRegressor: ../../modules/generated/piml.models.GAMRegressor.html
.. _GAMClassifier: ../../modules/generated/piml.models.GAMClassifier.html
.. _pygam: https://pygam.readthedocs.io/en/latest/

- `spline_order`: The degree of the piecewise polynomial representation.

  - order 0: piecewise constant spline
  - order 1: piecewise linear spline
  - order 2: piecewise quadratic spline
  - order 3: piecewise cubic spline

- `n_splines`: The number of knots in the spline transformation, i.e., the number of anchor points used to estimate the shape function. More knots let the estimated shape function capture more intricate patterns in the data, but they also make it less smooth, which may lead to overfitting the training data and poorer generalization to new data.

- `lam`: The regularization strength, which controls the smoothness of the estimated shape function. The penalty helps prevent overfitting and encourages the model to learn simpler, more generalizable patterns. Increasing `lam` makes the estimated shape function smoother, although an overly smooth shape function may be less predictive. Conversely, decreasing `lam` makes the shape function rougher, so it can capture more intricate patterns at the risk of overfitting.

Global Interpretation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A fitted GAM is commonly interpreted globally in two ways: the main effect plot and the feature importance plot.

Main Effect Plot
""""""""""""""""""""""""""""""""

This plot shows the estimated effect of each feature on the predicted response while controlling for the effects of the other features in the model.

- In GLM, the effect of each feature is a linear function whose slope is given by the regression coefficient.
- In GAM, the effect of each feature can take any shape, so the plot can be used to identify non-linear relationships between the features and the response and to assess the strength and direction of each relationship.
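
To make the contrast between the two bullets concrete, the following is a minimal sketch in plain Python of a GLM-style linear effect next to a GAM-style piecewise-linear effect (the kind of shape produced by `spline_order=1`). The knot positions and values are made up for illustration and are not taken from any fitted model, and the helper names `linear_effect`, `piecewise_linear_effect`, and `center` are hypothetical; they are not part of PiML or pygam.

```python
from bisect import bisect_right


def linear_effect(x, coef):
    """GLM-style effect: a straight line whose slope is the coefficient."""
    return coef * x


def piecewise_linear_effect(x, knots, values):
    """GAM-style effect: linear interpolation between values at sorted knots."""
    # Outside the knot range, hold the boundary value constant (an assumption).
    if x <= knots[0]:
        return values[0]
    if x >= knots[-1]:
        return values[-1]
    i = bisect_right(knots, x) - 1  # index of the knot at or just left of x
    t = (x - knots[i]) / (knots[i + 1] - knots[i])
    return values[i] + t * (values[i + 1] - values[i])


def center(effects):
    """Subtract the sample mean so the effect has zero mean (identifiability)."""
    mean = sum(effects) / len(effects)
    return [e - mean for e in effects]


# Hypothetical knots: a steep rise at first, then nearly flat.
knots = [0.0, 3.0, 6.0, 9.0, 15.0]
values = [0.0, 1.0, 2.0, 2.5, 2.6]

xs = [1.0, 2.0, 4.5, 8.0, 12.0]
glm_effects = center([linear_effect(x, 0.2) for x in xs])
gam_effects = center([piecewise_linear_effect(x, knots, values) for x in xs])

for x, e_lin, e_gam in zip(xs, glm_effects, gam_effects):
    print(f"x={x:5.1f}  linear={e_lin:+.3f}  piecewise={e_gam:+.3f}")
```

Both centered effects sum to zero over the sample, mirroring the zero-mean assumption on each shape function, while the piecewise effect flattens at large `x` in a way that no single slope can reproduce.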
Below we show how to generate the main effect plot for a GAM, using the keyword "global_effect_plot".

.. jupyter-input::

    exp.model_interpret(model="GAM", show="global_effect_plot", uni_feature="MedInc",
                        original_scale=True, figsize=(5, 4))

.. figure:: ../../auto_examples/3_models/images/sphx_glr_plot_1_gam_reg_001.png
   :target: ../../auto_examples/3_models/plot_1_gam_reg.html
   :align: left

The figure above shows the estimated shape function for the feature `MedInc`, together with a histogram of the feature at the bottom. The estimated shape function is piecewise linear because the spline order is set to 1. It rises steeply for low values of `MedInc` (less than 9.2) and is relatively flat for larger values. This suggests that the predicted median house price increases sharply with income in lower-income areas and levels off once `MedInc` is large. `MedInc` is therefore an important predictor of the median house price, especially for low-income areas.

Feature Importance
""""""""""""""""""""""""""""""""

The keyword "global_fi" corresponds to the feature importance plot. The importance of the :math:`j`-th feature is calculated as the variance of :math:`\hat{f}_j(x_j)` on the training data. Feature importances are always non-negative, and they are normalized so that they sum to 1.

.. jupyter-input::

    exp.model_interpret(model="GAM", show="global_fi", figsize=(5, 4))

.. figure:: ../../auto_examples/3_models/images/sphx_glr_plot_1_gam_reg_002.png
   :target: ../../auto_examples/3_models/plot_1_gam_reg.html
   :align: left

Due to space limitations, only the 10 most important features are displayed. To get the results for all features, set `return_data=True`; a data frame containing the importance of every feature is then returned.
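
The variance-based importance described in this section can be sketched in a few lines of plain Python. This is only an illustration of the calculation, not PiML's implementation: the function name `feature_importance` is hypothetical, the effect values are invented, and the use of the population variance (rather than the sample variance) is an assumption.

```python
from statistics import pvariance


def feature_importance(effects_by_feature):
    """Return var(f_j) / sum_k var(f_k) for each feature.

    `effects_by_feature` maps a feature name to the list of its estimated
    effect values on the training samples.
    """
    # Population variance is an assumption made for this sketch.
    variances = {name: pvariance(vals) for name, vals in effects_by_feature.items()}
    total = sum(variances.values())
    if total == 0:
        # All effects are constant, so no feature carries any importance.
        return {name: 0.0 for name in variances}
    return {name: v / total for name, v in variances.items()}


# Hypothetical effect values on four training samples (not real model output).
effects = {
    "MedInc": [-0.4, -0.1, 0.2, 0.3],
    "Latitude": [-0.6, 0.5, 0.4, -0.3],
    "AveOccup": [0.05, -0.05, 0.0, 0.0],
}

importance = feature_importance(effects)
for name, score in sorted(importance.items(), key=lambda kv: -kv[1]):
    print(f"{name:10s} {score:.3f}")
```

Because variances are non-negative and are divided by their sum, the resulting scores are non-negative and add up to 1, matching the two properties stated above.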
From the feature importance plot, we can see that the most important features are `Latitude` and `Longitude`, followed by `MedInc` and `AveOccup`.

Local Interpretation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The local interpretation plot shows how the model makes its prediction for a particular data point and which features have the strongest influence on that prediction. Similar to GLM, we use "local_fi" as the keyword for GAM's local interpretation.

.. jupyter-input::

    exp.model_interpret(model="GAM", show="local_fi", sample_id=0, original_scale=True, figsize=(5, 4))

.. figure:: ../../auto_examples/3_models/images/sphx_glr_plot_1_gam_reg_003.png
   :target: ../../auto_examples/3_models/plot_1_gam_reg.html
   :align: left

In this plot, the predicted value is 0.3804, while the actual response is 1.0. The bars represent the estimated effect values :math:`\hat{f}_j(x_j)` of the chosen sample. The feature `Latitude` makes the largest positive contribution to the final prediction, while `MedInc`, `Longitude`, and `AveOccup` all make negative contributions. Note that this plot shows only the 10 features with the largest contributions. To get the full results, set the parameter `return_data` to True.

Example
^^^^^^^^^^^^^^^^^^^^^

.. topic:: Example 1: California Housing

  The example below demonstrates how to use PiML with its high-code APIs for the California Housing dataset from the UCI repository, which consists of 20,640 samples and 9 features, fetched by `sklearn.datasets`. The response variable `MedHouseVal` (Median Home Value) is continuous, so this is a regression problem.

  * :ref:`sphx_glr_auto_examples_3_models_plot_1_gam_reg.py`

.. topic:: Example 2: CoCircles

  The example below demonstrates how to use PiML with the CoCircles data, a simulated binary classification dataset with two features that exactly follows the GAM assumptions.

  * :ref:`sphx_glr_auto_examples_3_models_plot_1_gam_cls.py`