当前位置: 首页 > news >正文

机器学习——逻辑回归(LogisticRegression)的核心参数:以约会数据集为例


理解 LogisticRegression 的核心参数:以约会数据集为例

逻辑回归(Logistic Regression)是机器学习中一种基础且重要的分类算法,特别适用于解决二分类和多分类问题。本文将基于 sklearn.linear_model.LogisticRegression 的用法,结合一个典型的约会数据集,通过代码实践,详解其核心参数的作用与调优技巧。

下方链接下载文件

作者《机器学习实战》 Peter Harrington 的配套代码

machinelearninginaction/Ch02/datingTestSet2.txt at master · pbharrin/machinelearninginactionhttps://github.com/pbharrin/machinelearninginaction/blob/master/Ch02/datingTestSet2.txt(完整代码在底部)

一、模型背景与代码示例

我们使用《机器学习实战》中的数据集 datingTestSet2.txt

其数据包含三列特征和一个表示喜欢程度的标签(1:不喜欢,2:魅力一般,3:非常喜欢)。

通过如下代码训练逻辑回归模型:

from sklearn.linear_model import LogisticRegression# 初始化逻辑回归模型
model = LogisticRegression()
model.fit(X_train, y_train)

接下来我们围绕 LogisticRegression 的主要参数逐个进行解析。


二、核心参数详解

1️⃣ C: 正则化强度的倒数(默认值:1.0)

  • 含义:控制正则化项的权重。较小的 C 值表示更强的正则化,会限制模型复杂度,有助于防止过拟合。

  • 实例:在本文代码中使用了 C=0.01,代表较强的正则化。

model = LogisticRegression(C=0.01)

建议:

  • 若模型过拟合(训练集准确高,测试集低)→ 减小 C

  • 若模型欠拟合(整体准确率都低)→ 增大 C


2️⃣ penalty: 正则化方式(默认值:'l2')

  • 可选值:'l1', 'l2', 'elasticnet', 'none'

  • 'l1':可产生稀疏模型(特征选择)

  • 'l2':默认值,更适合大多数线性问题

  • 'elasticnet':结合 l1l2

  • 'none':不使用正则化(风险较大)

model = LogisticRegression(penalty='l2')

❗ 注意:不同 solver 对支持的 penalty 有限制,例如 'liblinear' 支持 'l1''l2',而 'saga' 才支持 'elasticnet'


3️⃣ solver: 优化算法(默认值:'lbfgs')

  • 可选值:

    • 'liblinear':适用于小数据集,支持 'l1''l2'

    • 'lbfgs':适合多分类(支持 'l2'),速度快,默认值

    • 'newton-cg''sag''saga':适合大数据

model = LogisticRegression(solver='lbfgs')

✅ 实际建议:

  • 小数据集(如本文案例) → liblinear

  • 多分类任务 → lbfgssaga

  • 稀疏特征(如文本) → saga


4️⃣ multi_class: 多分类策略(默认:'auto')

  • 'ovr'(一对其余,One-vs-Rest):训练多个二分类器,速度快,解释性强

  • 'multinomial':直接优化多分类损失函数,预测效果通常更优

  • 'auto':自动选择(liblinearovr,其他 → multinomial

model = LogisticRegression(multi_class='multinomial')

5️⃣ max_iter: 最大迭代次数(默认值:100)

  • 当模型无法收敛时,可以调大该值,如设置为 1000

  • 若出现如下报错:ConvergenceWarning: lbfgs failed to converge → 增大 max_iter

model = LogisticRegression(max_iter=1000)

6️⃣ class_weight: 类别权重(默认值:None)

  • 用于处理类别不平衡问题,如设为 'balanced' 会自动按样本数调整权重

  • 或自定义字典,例如 {1:1, 2:2, 3:3}

model = LogisticRegression(class_weight='balanced')

7️⃣ random_state: 随机种子(可重复结果)

  • 设置后模型行为可复现,例如 random_state=42

  • 在划分训练/测试集、优化器初始化中有用

model = LogisticRegression(random_state=42)

除了LogisticRegression的参数,还有:

train_test_split的参数:

参数名类型说明
X, y数组或矩阵特征矩阵 X 和标签向量 y,支持 NumPy、Pandas、List 等
test_sizefloat 或 int测试集占比(如 0.25)或测试集样本数(如 100)
train_sizefloat 或 int训练集占比或样本数,默认自动补足(1 - test_size)
random_stateint随机种子,用于保证划分可复现。设为固定值(如 42)结果不会变
shufflebool是否在划分前打乱数据(默认 True,一般都要打乱)
stratifyarray-like 或 None分层抽样依据(常设为 y),用于保持标签比例一致(分类任务推荐)

三、系数和截距

print(model.coef_)       # 每个类别的特征系数(权重)
print(model.intercept_)  # 每个类别的偏置(截距)

对于三分类模型(标签为 1、2、3),会输出三组线性决策函数(即分割面):

y = w1*x1 + w2*x2 + w3*x3 + b

如输出结果如下:

分割线1: y = -0.1234x1 + 0.2345x2 - 0.5678x3 + 1.2345
分割线2: y = ...

四、总结:调参建议

问题建议参数
模型过拟合减小 C,增加正则化强度
模型欠拟合增大 C,尝试 multinomial
类别不平衡使用 class_weight='balanced'
收敛慢或警告增加 max_iter,或更换 solver
特征太多,想降维使用 penalty='l1', solver='liblinear'

附:完整模型构建代码

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_splitdata = np.loadtxt('datingTestSet2.txt')
X = data[:, :-1]
y = data[:, -1]# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=100)# 训练逻辑回归模型
model = LogisticRegression(C=0.01, max_iter=200, multi_class='auto', solver='lbfgs')
model.fit(X_train, y_train)print("训练集准确率:", model.score(X_train, y_train))
print("测试集准确率:", model.score(X_test, y_test))
print("权重系数:", model.coef_)
print("截距:", model.intercept_)# 自变量系数和截距
a = model.coef_
b = model.intercept_
print(f"分割线1:y = {a[0][0]:.4f}x1 + {a[0][1]:.4f}x2 + {a[0][2]:.4f}x3 + {b[0]:.4f}")
print(f"分割线2:y = {a[1][0]:.4f}x1 + {a[1][1]:.4f}x2 + {a[1][2]:.4f}x3 + {b[1]:.4f}")
print(f"分割线3:y = {a[2][0]:.4f}x1 + {a[2][1]:.4f}x2 + {a[2][2]:.4f}x3 + {b[2]:.4f}")

分割线中的系数四舍五入了。

http://www.xdnf.cn/news/1225675.html

相关文章:

  • Linux中Docker Swarm介绍和使用
  • Leetcode 10 java
  • linux81 shell通配符:[list],‘‘ ``““
  • python文件操作:读取文件内容read
  • 噪声对比估计(NCE):原理、演进与跨领域应用
  • 【深度学习①】 | Numpy数组篇
  • C#线程同步(二)锁
  • 国产开源大模型崛起:使用Kimi K2/Qwen2/GLM-4.5搭建编程助手
  • Go语言中的盲点:竞态检测和互斥锁的错觉
  • ctfshow_web签到题
  • 从内部保护你的网络
  • 江协科技STM32 12-2 BKP备份寄存器RTC实时时钟
  • TwinCAT3编程入门2
  • 从 0 到 1 认识 Spring MVC:核心思想与基本用法(下)
  • 自动化框架pytest
  • 【Kubernetes 指南】基础入门——Kubernetes 集群(二)
  • 雷达微多普勒特征代表运动中“事物”的运动部件。
  • Ubuntu 开启wifi 5G 热点
  • p5.js 3D模型(model)入门指南
  • ubuntu 镜像克隆
  • hadoop.yarn 带时间的LRU 延迟删除
  • Ubuntu-Server-24.04-LTS版本操作系统如何关闭自动更新,并移除不必要的内核
  • C#常见的转义字符
  • Vue3 setup、ref和reactive函数
  • Vue 详情模块 1
  • C++对象访问有访问权限是不是在ide里有效
  • 解决MySQL不能编译存储过程的问题
  • 《Java 程序设计》核心知识点梳理与深入探究
  • SpringMVC全局异常处理+拦截器使用+参数校验
  • 2025 腾讯广告算法大赛 Baseline 项目解析