当前位置: 首页 > news >正文

鸢尾花分类(KNN)

1. 加载数据集

iris = load_iris()
X = iris.data  # 特征数据 (150个样本 x 4个特征)
y = iris.target  # 目标变量 (3种类别)
feature_names = iris.feature_names  # 特征名称
target_names = iris.target_names  # 类别名称# 将数据转换为DataFrame便于分析
df = pd.DataFrame(X, columns=feature_names)
df['species'] = y
df['species'] = df['species'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})

2. 数据探索

print("数据集维度:", X.shape)
print("\n特征示例:")
print(df.head())
print("\n统计摘要:")
print(df.describe())
print("\n类别分布:")
print(df['species'].value_counts())

运行结果:

数据集维度: (150, 4)特征示例:sepal length (cm)  sepal width (cm)  ...  petal width (cm)  species
0                5.1               3.5  ...               0.2   setosa
1                4.9               3.0  ...               0.2   setosa
2                4.7               3.2  ...               0.2   setosa
3                4.6               3.1  ...               0.2   setosa
4                5.0               3.6  ...               0.2   setosa[5 rows x 5 columns]统计摘要:sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
count         150.000000        150.000000         150.000000        150.000000
mean            5.843333          3.057333           3.758000          1.199333
std             0.828066          0.435866           1.765298          0.762238
min             4.300000          2.000000           1.000000          0.100000
25%             5.100000          2.800000           1.600000          0.300000
50%             5.800000          3.000000           4.350000          1.300000
75%             6.400000          3.300000           5.100000          1.800000
max             7.900000          4.400000           6.900000          2.500000类别分布:
species
setosa        50
versicolor    50
virginica     50
Name: count, dtype: int64

3. 数据可视化

# 可视化分析
plt.figure(figsize=(12, 8))# 使用matplotlib的rcParams设置字体,否则图片中中文可能会乱码
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

3.1 特征分布直方图

# 特征分布直方图
for i, feature in enumerate(feature_names):plt.subplot(2, 2, i+1)sns.histplot(data=df, x=feature, hue='species', kde=True)plt.title(f'{feature} 分布')
plt.tight_layout()
plt.show()

运行结果:
在这里插入图片描述

3.2 特征关系散点图

# 特征关系散点图
sns.pairplot(df, hue='species', palette='viridis')
plt.suptitle('特征关系散点图矩阵', y=1.02)
plt.show()

运行结果:
在这里插入图片描述

4. 数据预处理

# 划分训练集和测试集 (80%训练, 20%测试)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y
)# 特征标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

5. 模型训练

使用K近邻算法(KNN):一个新样本的类别(或值)由其周围最相似的“邻居”的类别(或值)决定。

# 使用K近邻算法
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train, y_train)

6. 模型评估

y_pred = knn.predict(X_test)print("\n测试集准确率: {:.2f}%".format(accuracy_score(y_test, y_pred) * 100))
print("\n分类报告:")
print(classification_report(y_test, y_pred, target_names=target_names))# 混淆矩阵
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',xticklabels=target_names,yticklabels=target_names)
plt.xlabel('预测值')
plt.ylabel('真实值')
plt.title('混淆矩阵')
plt.show()

运行结果:

测试集准确率: 93.33%分类报告:precision    recall  f1-score   supportsetosa       1.00      1.00      1.00        10versicolor       0.83      1.00      0.91        10virginica       1.00      0.80      0.89        10accuracy                           0.93        30macro avg       0.94      0.93      0.93        30
weighted avg       0.94      0.93      0.93        30

在这里插入图片描述

7. 决策树模型对比

dt = DecisionTreeClassifier(max_depth=3, random_state=42)
dt.fit(X_train, y_train)
dt_pred = dt.predict(X_test)
print("\n决策树准确率: {:.2f}%".format(accuracy_score(y_test, dt_pred) * 100))# 可视化决策树
plt.figure(figsize=(15, 10))
plot_tree(dt, feature_names=feature_names,class_names=target_names, filled=True)
plt.title('鸢尾花分类决策树')
plt.show()

运行结果:

决策树准确率: 96.67%

在这里插入图片描述

8. 新样本预测

new_sample = np.array([[5.1, 3.5, 1.4, 0.2]])  # 新样本数据
new_sample_scaled = scaler.transform(new_sample)  # 标准化
prediction = knn.predict(new_sample_scaled)
print("\n新样本预测结果:", target_names[prediction][0])

运行结果:

新样本预测结果: setosa

9. 完整代码

# 导入必要的库
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.tree import DecisionTreeClassifier, plot_tree# 加载数据集
iris = load_iris()
X = iris.data  # 特征数据 (150个样本 x 4个特征)
y = iris.target  # 目标变量 (3种类别)
feature_names = iris.feature_names  # 特征名称
target_names = iris.target_names  # 类别名称# 将数据转换为DataFrame便于分析
df = pd.DataFrame(X, columns=feature_names)
df['species'] = y
df['species'] = df['species'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})# 数据探索
print("数据集维度:", X.shape)
print("\n特征示例:")
print(df.head())
print("\n统计摘要:")
print(df.describe())
print("\n类别分布:")
print(df['species'].value_counts())# 可视化分析
plt.figure(figsize=(12, 8))# 使用matplotlib的rcParams设置字体,否则图片中中文可能会乱码
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 特征分布直方图
for i, feature in enumerate(feature_names):plt.subplot(2, 2, i+1)sns.histplot(data=df, x=feature, hue='species', kde=True)plt.title(f'{feature} 分布')
plt.tight_layout()
plt.show()# 特征关系散点图
sns.pairplot(df, hue='species', palette='viridis')
plt.suptitle('特征关系散点图矩阵', y=1.02)
plt.show()# 数据预处理
# 划分训练集和测试集 (80%训练, 20%测试)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y
)# 特征标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 模型训练(使用K近邻算法)
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train, y_train)# 模型评估
y_pred = knn.predict(X_test)print("\n测试集准确率: {:.2f}%".format(accuracy_score(y_test, y_pred) * 100))
print("\n分类报告:")
print(classification_report(y_test, y_pred, target_names=target_names))# 混淆矩阵
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',xticklabels=target_names,yticklabels=target_names)
plt.xlabel('预测值')
plt.ylabel('真实值')
plt.title('混淆矩阵')
plt.show()# 使用决策树模型进行对比(可选)
dt = DecisionTreeClassifier(max_depth=3, random_state=42)
dt.fit(X_train, y_train)
dt_pred = dt.predict(X_test)
print("\n决策树准确率: {:.2f}%".format(accuracy_score(y_test, dt_pred) * 100))# 可视化决策树
plt.figure(figsize=(15, 10))
plot_tree(dt, feature_names=feature_names,class_names=target_names, filled=True)
plt.title('鸢尾花分类决策树')
plt.show()# 进行新样本预测(示例)
new_sample = np.array([[5.1, 3.5, 1.4, 0.2]])  # 新样本数据
new_sample_scaled = scaler.transform(new_sample)  # 标准化
prediction = knn.predict(new_sample_scaled)
print("\n新样本预测结果:", target_names[prediction][0])
http://www.xdnf.cn/news/949429.html

相关文章:

  • 【AI News | 20250609】每日AI进展
  • 测试微信模版消息推送
  • 开源:FTP同步工具
  • 【粤语克隆】粤语声音,一秒克隆:如何用AI为岭南文化按下快进键
  • composer init
  • LeetCode - 647. 回文子串
  • 具身智能之人形机器人核心零部件介绍
  • 教程:PyCharm 中搭建多级隔离的 Poetry 环境(从 Anaconda 到项目专属.venv)
  • 重启Eureka集群中的节点,对已经注册的服务有什么影响
  • 深入理解JavaScript设计模式之单例模式
  • AirPosture | 通过 AirPods 矫正坐姿
  • 安科瑞户储ADL200N-CT:即插即用破解家庭光伏安装困局
  • HBase学习:通俗易懂的实例解析
  • K8S认证|CKS题库+答案| 10. Trivy 扫描镜像安全漏洞
  • Java中HashMap底层原理深度解析:从数据结构到红黑树优化
  • 人工智能 - 在Dify、Coze、n8n、FastGPT和RAGFlow之间做出技术选型
  • Excel处理控件Aspose.Cells教程:在Excel 文件中创建、操作和渲染时间线
  • 国内外UI自动化测试工具全景分析:国产创新与国际领先工具对比
  • Rougamo.Fody 实现一个AOP日志
  • UI框架-通知组件
  • TMC2226超静音步进电机驱动控制模块
  • 高抗扰度汽车光耦合器的特性
  • 渗透实战PortSwigger Labs指南:自定义标签XSS和SVG XSS利用
  • sshd代码修改banner
  • 开发一套外卖系统软件需要多少钱?
  • 简单介绍C++中 string与wstring
  • 动手学深度学习13.3. 目标检测和边界框-笔记练习(PyTorch)
  • 神经网络学习-神经网络简介【Transformer、pytorch、Attention介绍与区别】
  • 盲盒一番赏小程序:引领盲盒新潮流
  • [免费]微信小程序问卷调查系统(SpringBoot后端+Vue管理端)【论文+源码+SQL脚本】