Generalized Linear Models (Formula)

This notebook illustrates how to use R-style formulas to fit Generalized Linear Models.

To begin, we load the Star98 dataset, construct a formula, and pre-process the data:

[1]:
import statsmodels.api as sm
import statsmodels.formula.api as smf

star98 = sm.datasets.star98.load_pandas().data
formula = "SUCCESS ~ LOWINC + PERASIAN + PERBLACK + PERHISP + PCTCHRT + \
           PCTYRRND + PERMINTE*AVYRSEXP*AVSALK + PERSPENK*PTRATIO*PCTAF"
dta = star98[
    [
        "NABOVE",
        "NBELOW",
        "LOWINC",
        "PERASIAN",
        "PERBLACK",
        "PERHISP",
        "PCTCHRT",
        "PCTYRRND",
        "PERMINTE",
        "AVYRSEXP",
        "AVSALK",
        "PERSPENK",
        "PTRATIO",
        "PCTAF",
    ]
].copy()
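# Binomial response: the proportion NABOVE / (NABOVE + NBELOW).
# dta.pop("NBELOW") returns the column and also removes it from dta.
endog = dta["NABOVE"] / (dta["NABOVE"] + dta.pop("NBELOW"))
del dta["NABOVE"]  # drop the remaining raw count; only SUCCESS and the predictors stay
dta["SUCCESS"] = endog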
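In R-style (patsy) formula syntax, `a*b*c` is shorthand for the main effects `a + b + c` plus every interaction among them (`a:b`, `a:c`, `b:c`, `a:b:c`). The helper `expand_star` below is hypothetical and only sketches that expansion. It is not patsy's implementation. It does reproduce the term order shown in the fit summary, and it lets us count the slope terms the formula above generates.

```python
def expand_star(factors):
    """Expand a product term such as a*b*c into main effects and interactions.

    Hypothetical illustration only, not patsy's code. Each new factor is added
    on its own, then crossed with every term built so far.
    """
    terms = []
    for f in factors:
        # the list comprehension sees the old `terms`, before this reassignment
        terms = terms + [f] + [t + ":" + f for t in terms]
    return terms


main_effects = ["LOWINC", "PERASIAN", "PERBLACK", "PERHISP", "PCTCHRT", "PCTYRRND"]
terms = (
    main_effects
    + expand_star(["PERMINTE", "AVYRSEXP", "AVSALK"])
    + expand_star(["PERSPENK", "PTRATIO", "PCTAF"])
)
print(len(terms))  # 20
```

The count of 20 matches `Df Model` in the summary. Adding the intercept gives 21 parameters, and 303 − 21 = 282 matches `Df Residuals`.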

Then, we fit the GLM:

[2]:
mod1 = smf.glm(formula=formula, data=dta, family=sm.families.Binomial()).fit()
print(mod1.summary())
                 Generalized Linear Model Regression Results
==============================================================================
Dep. Variable:                SUCCESS   No. Observations:                  303
Model:                            GLM   Df Residuals:                      282
Model Family:                Binomial   Df Model:                           20
Link Function:                  Logit   Scale:                          1.0000
Method:                          IRLS   Log-Likelihood:                -127.33
Date:                Thu, 03 Oct 2024   Deviance:                       8.5477
Time:                        15:44:55   Pearson chi2:                     8.48
No. Iterations:                     4   Pseudo R-squ. (CS):             0.1115
Covariance Type:            nonrobust
============================================================================================
                               coef    std err          z      P>|z|      [0.025      0.975]
--------------------------------------------------------------------------------------------
Intercept                    0.4037     25.036      0.016      0.987     -48.665      49.472
LOWINC                      -0.0204      0.010     -1.982      0.048      -0.041      -0.000
PERASIAN                     0.0159      0.017      0.910      0.363      -0.018       0.050
PERBLACK                    -0.0198      0.020     -1.004      0.316      -0.058       0.019
PERHISP                     -0.0096      0.010     -0.951      0.341      -0.029       0.010
PCTCHRT                     -0.0022      0.022     -0.103      0.918      -0.045       0.040
PCTYRRND                    -0.0022      0.006     -0.348      0.728      -0.014       0.010
PERMINTE                     0.1068      0.787      0.136      0.892      -1.436       1.650
AVYRSEXP                    -0.0411      1.176     -0.035      0.972      -2.346       2.264
PERMINTE:AVYRSEXP           -0.0031      0.054     -0.057      0.954      -0.108       0.102
AVSALK                       0.0131      0.295      0.044      0.965      -0.566       0.592
PERMINTE:AVSALK             -0.0019      0.013     -0.145      0.885      -0.028       0.024
AVYRSEXP:AVSALK              0.0008      0.020      0.038      0.970      -0.039       0.041
PERMINTE:AVYRSEXP:AVSALK  5.978e-05      0.001      0.068      0.946      -0.002       0.002
PERSPENK                    -0.3097      4.233     -0.073      0.942      -8.606       7.987
PTRATIO                      0.0096      0.919      0.010      0.992      -1.792       1.811
PERSPENK:PTRATIO             0.0066      0.206      0.032      0.974      -0.397       0.410
PCTAF                       -0.0143      0.474     -0.030      0.976      -0.944       0.916
PERSPENK:PCTAF               0.0105      0.098      0.107      0.915      -0.182       0.203
PTRATIO:PCTAF               -0.0001      0.022     -0.005      0.996      -0.044       0.044
PERSPENK:PTRATIO:PCTAF      -0.0002      0.005     -0.051      0.959      -0.010       0.009
============================================================================================

Finally, we define a function that performs a custom data transformation inside the formula framework:

[3]:
def double_it(x):
    return 2 * x


formula = "SUCCESS ~ double_it(LOWINC) + PERASIAN + PERBLACK + PERHISP + PCTCHRT + \
           PCTYRRND + PERMINTE*AVYRSEXP*AVSALK + PERSPENK*PTRATIO*PCTAF"
mod2 = smf.glm(formula=formula, data=dta, family=sm.families.Binomial()).fit()
print(mod2.summary())
                 Generalized Linear Model Regression Results
==============================================================================
Dep. Variable:                SUCCESS   No. Observations:                  303
Model:                            GLM   Df Residuals:                      282
Model Family:                Binomial   Df Model:                           20
Link Function:                  Logit   Scale:                          1.0000
Method:                          IRLS   Log-Likelihood:                -127.33
Date:                Thu, 03 Oct 2024   Deviance:                       8.5477
Time:                        15:44:55   Pearson chi2:                     8.48
No. Iterations:                     4   Pseudo R-squ. (CS):             0.1115
Covariance Type:            nonrobust
============================================================================================
                               coef    std err          z      P>|z|      [0.025      0.975]
--------------------------------------------------------------------------------------------
Intercept                    0.4037     25.036      0.016      0.987     -48.665      49.472
double_it(LOWINC)           -0.0102      0.005     -1.982      0.048      -0.020      -0.000
PERASIAN                     0.0159      0.017      0.910      0.363      -0.018       0.050
PERBLACK                    -0.0198      0.020     -1.004      0.316      -0.058       0.019
PERHISP                     -0.0096      0.010     -0.951      0.341      -0.029       0.010
PCTCHRT                     -0.0022      0.022     -0.103      0.918      -0.045       0.040
PCTYRRND                    -0.0022      0.006     -0.348      0.728      -0.014       0.010
PERMINTE                     0.1068      0.787      0.136      0.892      -1.436       1.650
AVYRSEXP                    -0.0411      1.176     -0.035      0.972      -2.346       2.264
PERMINTE:AVYRSEXP           -0.0031      0.054     -0.057      0.954      -0.108       0.102
AVSALK                       0.0131      0.295      0.044      0.965      -0.566       0.592
PERMINTE:AVSALK             -0.0019      0.013     -0.145      0.885      -0.028       0.024
AVYRSEXP:AVSALK              0.0008      0.020      0.038      0.970      -0.039       0.041
PERMINTE:AVYRSEXP:AVSALK  5.978e-05      0.001      0.068      0.946      -0.002       0.002
PERSPENK                    -0.3097      4.233     -0.073      0.942      -8.606       7.987
PTRATIO                      0.0096      0.919      0.010      0.992      -1.792       1.811
PERSPENK:PTRATIO             0.0066      0.206      0.032      0.974      -0.397       0.410
PCTAF                       -0.0143      0.474     -0.030      0.976      -0.944       0.916
PERSPENK:PCTAF               0.0105      0.098      0.107      0.915      -0.182       0.203
PTRATIO:PCTAF               -0.0001      0.022     -0.005      0.996      -0.044       0.044
PERSPENK:PTRATIO:PCTAF      -0.0002      0.005     -0.051      0.959      -0.010       0.009
============================================================================================

As expected, the coefficient on double_it(LOWINC) in the second model is half the size of the LOWINC coefficient in the first model:
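The reason is that the covariate enters a GLM only through the linear predictor η, so rescaling a column is absorbed entirely by its coefficient. Below is a short derivation, writing x for LOWINC, β₁ for its coefficient, and β̃ for the parameters of the model that uses 2x:

$$
\eta_i = \beta_0 + \beta_1 x_i + \cdots = \beta_0 + \tilde\beta_1\,(2x_i) + \cdots
\;\Longrightarrow\; \tilde\beta_1 = \tfrac{1}{2}\,\beta_1
$$

$$
\tilde X = X D,\quad D = \operatorname{diag}(1, 2, 1, \dots, 1)
\;\Longrightarrow\;
\widehat{\operatorname{Cov}}(\tilde\beta) = \left(D X^\top W X D\right)^{-1} = D^{-1}\,\widehat{\operatorname{Cov}}(\beta)\,D^{-1}
$$

The fitted means are unchanged, so the IRLS weights W are the same in both fits. Consequently:

- The standard error of β̃₁ is half that of β₁ (0.010 vs 0.005 in the two tables).
- The z statistic is identical (−1.982 in both).
- The log-likelihood and deviance are identical (−127.33 and 8.5477 in both).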

[4]:
print(mod1.params["LOWINC"])
print(mod2.params["double_it(LOWINC)"] * 2)
-0.02039598715475645
-0.020395987154756174
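The summaries report `Method: IRLS` with a `Logit` link. To see the same rescaling effect outside statsmodels, below is a minimal iteratively reweighted least squares (IRLS) sketch for a binomial GLM with the canonical logit link. It makes several assumptions:

- It uses NumPy only.
- The function name `irls_logit` is hypothetical.
- The data are synthetic proportions, not Star98.
- It is not statsmodels' implementation and has no safeguards such as step halving.

```python
import numpy as np


def irls_logit(X, y, max_iter=50, tol=1e-10):
    """Fit a binomial GLM with a logit link by IRLS (minimal sketch, no safeguards)."""
    beta = np.zeros(X.shape[1])
    for _ in range(max_iter):
        eta = X @ beta                       # linear predictor
        mu = 1.0 / (1.0 + np.exp(-eta))      # inverse logit gives the mean
        w = mu * (1.0 - mu)                  # working weights for the canonical link
        z = eta + (y - mu) / w               # working response
        XtW = X.T * w                        # X^T W without building diag(w)
        beta_new = np.linalg.solve(XtW @ X, XtW @ z)
        converged = np.max(np.abs(beta_new - beta)) < tol
        beta = beta_new
        if converged:
            break
    return beta


# Synthetic success proportions (an assumption, standing in for SUCCESS)
rng = np.random.default_rng(0)
n = 300
x = rng.normal(size=n)
p = 1.0 / (1.0 + np.exp(-(0.4 - 0.8 * x)))
y = rng.binomial(50, p) / 50.0

ones = np.ones(n)
beta_x = irls_logit(np.column_stack([ones, x]), y)
beta_2x = irls_logit(np.column_stack([ones, 2.0 * x]), y)  # analogous to double_it

print(beta_x[1], 2.0 * beta_2x[1])  # agree up to floating-point rounding
```

Both fits start from zero. Each IRLS step on the doubled column is therefore the original step with that coordinate halved, so the intercepts coincide and the slopes differ by exactly a factor of two.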

Last update: Oct 03, 2024