当前位置: 首页 > news >正文

卷积神经网络项目:基于CNN实现心律失常(ECG)的小颗粒度分类系统

卷积神经网络项目实现文档

1、项目简介

1.1 项目名称

​ 基于CNN实现心律失常(ECG)的小颗粒度分类系统

1.2 项目简介

​ 心律失常是临床上常见且潜在致命的心血管疾病之一,包括房性早搏(PAC)、室性早搏(PVC)、心动过速等多种类型。传统的心电图(ECG)分析依赖医生人工判读,耗时长、主观性强,尤其在面对长时间动态心电监测(如 24 小时 Holter)数据时,极易出现漏诊或误诊。

​ 本项目旨在利用卷积神经网络(CNN)对MIT-BIH心律失常数据库中的ECG信号进行细粒度分类,识别五种常见的心律失常类型:正常心跳(N)、室上性早搏(S)、室性早搏(V)、融合波(F)和未知心跳(Q)。由于不同类别的ECG信号在形态上差异细微,且存在严重的类别不平衡问题,传统的机器学习方法难以取得理想效果。

​ 本项目采用深度学习中的CNN模型,充分利用卷积层对局部时序特征的提取能力。项目涵盖数据预处理、模型构建、训练优化、性能评估及模型部署全流程,并探索数据重采样、标准化、迁移学习等关键技术手段,最终实现高精度、可部署的心律失常自动识别系统。该系统可应用于:

  • 院内心电监护报警系统
  • 远程健康监测平台
  • 可穿戴设备(如智能手表)的异常心律预警
  • 医学教学与训练辅助工具

1.3 技术选择

为什么选择1D-CNN?

本项目采用 一维卷积神经网络(1D-CNN) 作为核心模型,主要基于以下几点考虑:

优势说明
保留时序结构ECG 信号本质上是一维时间序列,1D-CNN 能直接在原始信号上进行卷积操作,保留完整的时序信息,避免特征工程带来的信息损失。
自动特征提取CNN 能自动学习 QRS 波群、P 波、T 波等关键形态特征,无需手动设计特征(如 RR 间期、波幅等),提升模型泛化能力。
局部感知能力卷积核具有局部感受野,能有效捕捉 ECG 中局部波形变化(如 R 波突起、ST 段抬高),对异常心跳(如 PVC 的宽大畸形 QRS)敏感。
参数效率高相比 RNN/LSTM,1D-CNN 训练更快、更稳定,适合部署在边缘设备或实时系统中。
成功先例在 MIT-BIH 心律失常数据库上的多项研究(如 Kiranyaz et al., 2016; Acharya et al., 2017)已验证 1D-CNN 在 ECG 分类任务中的优越性能。

为什么不选 RNN 或 Transformer?

虽然 RNN 能建模长期依赖,但 ECG 心跳分类主要依赖局部波形特征而非长序列依赖。RNN 训练慢、易梯度消失;Transformer 在短序列上无明显优势且计算开销大。因此,1D-CNN 是精度与效率的最优平衡

2、数据

2.1 公开数据集

本项目使用国际公认的标准心律失常数据库:MIT-BIH Arrhythmia Database,该数据集由美国麻省理工学院(MIT)与贝斯以色列医院(Beth Israel Hospital)联合发布,是心电图自动分析领域最广泛使用的基准数据集之一。

名称:MIT-BIH Arrhythmia Database
来源:Kaggle - MIT-BIH Arrhythmia Database
内容
mitbih_train.csv:训练集,共 109,446 条样本
mitbih_test.csv:测试集,共 21,892 条样本
格式说明
每行表示一个心跳周期的 ECG 信号,共 187 个时间点
最后一列为类别标签(0~4),对应五种心律类型

标签类别描述
0N正常心跳(Normal Beat)
1S室上性早搏(Supraventricular Premature)
2V室性早搏(Ventricular Premature)
3F融合波(Fusion Beat)
4Q未知心跳(Unclassifiable Beat)

可视ECG信号

在这里插入图片描述

2.2 数据分析与清洗

类别分布分析:训练集中类别严重不平衡,N类占比超过80%,V类仅占约5%。
处理方式

  • 对训练集使用 SMOTE 过采样,平衡各类别样本数量
  • 测试集保持原始分布,用于真实性能评估
  • 移除异常值(如全零信号)

2.3 数据预处理

标准化:使用 StandardScaler 对每个信号进行标准化
维度重塑:将 (N, 187) 转为 (N, 187, 1),适配1D-CNN输入

2.4 数据分割

训练集:mitbih_train.csv → 用于模型训练
测试集:mitbih_test.csv → 用于最终性能评估
验证集:从训练集中划分20%用于调参

数据处理(清洗+预处理+分割)

'''
数据预处理
'''
# data_preprocess.pyimport joblib
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE
import pandas as pd
import os
import numpy as np# 加载 CSV 文件  df.shape = (109446, 188)   109446个样本,187个特征 + 1个标签
df = pd.read_csv('./data/archive/mitbih_train.csv', header=None)# 数据预处理
# 分离data和label
X = df.iloc[:, :-1].values
y = df.iloc[:, -1].values# 标准化数据
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)# 重采样  解决样本不均衡问题
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_scaled, y)# 划分训练集和测试集
X_train, X_val, y_train, y_val = train_test_split(X_resampled, y_resampled, test_size=0.2, random_state=42,# 确保训练集和测试集的标签分布一致stratify=y_resampled
)# 修改数据维度  (样本数, 时间步, 特征数)  X_train.shape = (87371, 187, 1)  
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_val = X_val.reshape(X_val.shape[0], X_val.shape[1], 1)print(X_train.shape, X_val.shape, y_train.shape, y_val.shape)  # (87371, 187, 1) (21927, 187, 1) (87371,) (21927,)# 保存预处理结果
# 创建目录    exist_ok=True 表示如果目录已经存在,则不报错
os.makedirs('./data/processed_data', exist_ok=True)# 保存处理后的 numpy 数组
np.save('./data/processed_data/X_train.npy', X_train)
np.save('./data/processed_data/X_val.npy', X_val)
np.save('./data/processed_data/y_train.npy', y_train)
np.save('./data/processed_data/y_val.npy', y_val)# 保存 StandardScaler  (重要!推理时要用)
joblib.dump(scaler, './data/processed_data/scaler.pkl')# 保存 SMOTE 对象 (用于分析)
joblib.dump(smote, './data/processed_data/smote.pkl')print('数据处理完成!')

3. 神经网络

为实现对 ECG 心跳信号的自动分类,本项目设计并实现了一个轻量级的一维卷积神经网络(1D-CNN),命名为 ECGCNN。该模型专为处理长度为 187 的单导联心电信号设计,能够在保持较高精度的同时满足实时性要求。

3.1 模型架构设计

模型整体结构由 3 个卷积块 + 2 个全连接层组成,采用“卷积提取特征 → 展平 → 分类”的经典流程。每一层的设计均针对 ECG 信号特点进行优化:

  • 输入格式适配:支持 (N, 187)(N, 187, 1) 格式的输入,自动转换为 PyTorch 所需的 (N, C, L) 格式(即 (batch_size, channels, sequence_length))。
  • 多层卷积提取局部特征:使用 1D 卷积核捕捉 QRS 波群、ST 段等关键形态特征。
  • BatchNorm + ReLU + Dropout:提升训练稳定性、加速收敛,并防止过拟合。
  • 动态全连接层尺寸计算:通过虚拟输入自动推导展平后的维度,增强模型灵活性。

3.2 自定义网络实现(model_self.py)

以下是基于 PyTorch 实现的完整模型定义:

'''
模型构建:基于 PyTorch 的 1D-CNN 模型
'''
# model_self.py
import torch 
import torch.nn as nnclass ECGCNN(nn.Module):"""基于 PyTorch 的 1D-CNN 模型,用于 ECG 心跳分类"""def __init__(self, input_shape = (187,1),num_classes=5, dropout_rate=0.5):super(ECGCNN, self).__init__()# 注意:PyTorch 的 Conv1D 输入是 (N, C, L)input_channels = input_shape[1] if len(input_shape) == 2 else 1  # 默认 1 通道# 卷积层self.conv1 =nn.Sequential(nn.Conv1d(input_channels, 32, kernel_size=5, stride=1),nn.BatchNorm1d(32),nn.ReLU(),nn.MaxPool1d(kernel_size=2))self.conv2 = nn.Sequential(nn.Conv1d(32, 64, kernel_size=3, stride=1),nn.BatchNorm1d(64),nn.ReLU(),nn.MaxPool1d(kernel_size=2))self.conv3 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=3, stride=1),nn.ReLU(),nn.Dropout(0.3),nn.MaxPool1d(kernel_size=2))# nn.Flatten()  的作用:将输入的维度进行展平,方便全连接层处理self.flatten = nn.Flatten()# 计算展平 flatten 之后的维度dummy_input = torch.randn(1, input_channels,187)  # 创建一个虚拟输入张量with torch.no_grad():  # 禁用梯度计算,以节省内存和计算资源features = self._forward_features(dummy_input)  # 通过前向传播计算展平后的维度flatten_dim = features.view(1, -1).shape[1]  # 计算展平后的维度self.fc1 = nn.Sequential(nn.Linear(flatten_dim, 128),nn.ReLU(),nn.Dropout(dropout_rate))self.fc2 = nn.Sequential(nn.Linear(128, num_classes),# nn.Softmax(dim=1)  # dim=1 表示对列进行归一化)def _forward_features(self, x):'''用于计算全连接层输入维度的辅助函数'''x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)return xdef forward(self, x):"""前向传播(PyTorch 中叫 forward,不是 call):param x: 输入张量 (N, C, L):return: 输出张量 (N, num_classes)"""# PyTorch 输入格式: (batch_size, channels, sequence_length)# 所以需要把 (N, 187, 1) 转成 (N, 1, 187)if x.dim() == 3 and x.shape[2] == 1:  #  如果是(N, 187, 1) 则转成(N, 1, 187)x = x.permute(0, 2, 1)  # (N, 187, 1) -> (N, 1, 187)elif x.dim() == 2:  # 如果是(N, 187) 则转成(N, 1, 187)x = x.unsqueeze(1)  # (N, 187) -> (N, 1, 187)# 卷积块x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)# 全连接层x = self.flatten(x)   #展平x = self.fc1(x)x = self.fc2(x)return xif __name__ == '__main__':model = ECGCNN(num_classes=5)print(model)# 测试前向传播x = torch.randn(32, 187, 1)  # 32个样本,187个特征,1个通道with torch.no_grad():output = model(x)print("输出x的形状:",x.shape)print("输出形状:", output.shape)  # (32, 5)

4. 模型训练

4.1 训练参数

为确保模型有效学习且不过拟合,采用以下训练配置:

参数说明
轮次(Epochs)10在验证集性能趋于稳定后停止,避免过拟合
批次大小(Batch Size)32平衡梯度稳定性与内存占用
学习率(Learning Rate)0.0001使用 Adam 优化器时的常用小学习率,保证收敛稳定
设备CPU当前运行环境为 CPU,未来可扩展支持 GPU 加速

4.2 损失函数

采用 交叉熵损失函数(Cross-Entropy Loss)

criterion = nn.CrossEntropyLoss()

该函数结合了 LogSoftmaxNLLLoss,适用于多分类任务。它衡量模型输出概率分布与真实标签之间的差异,是分类任务中最常用的损失函数之一。

4.3 优化器

使用 Adam 优化器

optimizer = optim.Adam(model.parameters(), lr=0.0001)

Adam 结合了动量(Momentum)和自适应学习率(RMSProp)的优点,具有收敛快、鲁棒性强的特点,特别适合深度神经网络的训练。

4.4 训练过程可视化

为监控训练动态,使用 TensorBoard 进行可视化,记录每轮的训练/验证损失与准确率。

训练过程指标走势图

训练准确率和损失走势图 (如下)在这里插入图片描述

验证准确率和损失走势图(如下)在这里插入图片描述

从图中可以看出:

  • 训练损失持续下降,训练准确率稳步上升,表明模型正在有效学习;
  • 验证损失先降后趋于平稳,未出现明显回升,说明模型未严重过拟合;
  • 最终验证准确率可达 90% 以上(具体数值依运行结果而定),表现出良好的分类能力。

网络结构图

在这里插入图片描述

该图为 TensorBoard 自动生成的计算图,清晰展示了数据流动路径和各层连接关系。

训练脚本(train.py)

以下是完整的训练流程实现,包含数据加载、模型定义、训练循环、验证评估与结果保存:

'''
模型训练
'''
# train.pyimport torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader,TensorDataset
from model_self import ECGCNN
from torch.utils.tensorboard import SummaryWriter# tensorboard 可视化操作
writer = SummaryWriter()# 加载预处理数据
X_train = np.load('./data/processed_data/X_train.npy')
X_val = np.load('./data/processed_data/X_val.npy')
y_train = np.load('./data/processed_data/y_train.npy')
y_val = np.load('./data/processed_data/y_val.npy')# 转为PyTorch张量
X_train = torch.tensor(X_train, dtype=torch.float32)
X_val = torch.tensor(X_val, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_val = torch.tensor(y_val, dtype=torch.long)# print("Loaded data",X_train.shape,y_train.shape)# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)# 设备选择val
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 创建模型 、损失函数、优化器
model = ECGCNN(num_classes=5)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)model.to(device)
criterion.to(device)# 可视化网络结构
with torch.no_grad():#  X_train 是 (N, 187, 1)sample_input = torch.zeros(1, X_train.shape[1], X_train.shape[2]).to(device)  # (1, 187, 1)writer.add_graph(model, sample_input)# 记录训练过程中准确率最高的准确率
best_acc = 0.0# 定义history字典,用于保存训练过程中损失和准确率
history = {'train_loss': [], 'val_loss': [], 'val_acc': []}for epoch in range(10):# 训练模式model.train()train_loss = 0train_acc = 0train_total = 0for i, (inputs, labels) in enumerate(train_loader):# 输入数据inputs = inputs.to(device)  # 将数据加载到设备# 标签labels = labels.to(device)  # 将标签加载到设备#  梯度清零optimizer.zero_grad()#  获取预测结果outputs = model(inputs)# 预测准确率_, predicted = outputs.max(1)train_total += labels.size(0)train_acc += predicted.eq(labels).sum().item()# 计算损失loss = criterion(outputs, labels)# 反向传播loss.backward()# 更新参数optimizer.step()# 累计损失train_loss += loss.item()# if (i+1) % 10 == 0:#     print(f" Step [{i+1:>3}/{len(train_loader)}] | Batch Loss: {loss.item():.4f}")# print(f"Epoch [{epoch+1}/{10}], Loss: {train_loss/len(train_loader):.4f}")#  计算平均准确率avg_train_acc = 100.0 * train_acc / train_total# 平均训练损失avg_train_loss = train_loss / len(train_loader)# 记录训练损失history['train_loss'].append(avg_train_loss)# 验证模式  (使用 X_test, y_test)model.eval()val_loss = 0.0correct = 0   # 验证集准确率total = 0  # 验证集样本总数with torch.no_grad():for i, (inputs, labels) in enumerate(val_loader):#  输入数据inputs = inputs.to(device)  # 将数据移动到设备#  标签labels = labels.to(device)  # 将标签移动到设备# 预测outputs = model(inputs)# 损失loss = criterion(outputs, labels)# 验证集损失 累加val_loss += loss.item()# 得到预测概率最高的类别_, predicted = outputs.max(1)#   labels.size(0)  获取当前batch标签的样本数量total += labels.size(0)# 预测正确的数量 .eq()  判断两个张量是否相等,返回一个布尔张量 .sum()  把 True/False 转成 1/0,然后求和#  .item()  把 PyTorch 的标量张量(scalar tensor)转成 Python 数字correct += predicted.eq(labels).sum().item()# 平均验证损失avg_val_loss = val_loss / len(val_loader)# 验证集准确率val_acc = 100.0 * correct / total# 记录验证损失和准确率history['val_loss'].append(avg_val_loss)history['val_acc'].append(val_acc)# 打印结果print(f"\n   Epoch [{epoch+1:>2}/10] ")print(f"    🟢 Train Loss: {avg_train_loss:.4f} | Train Acc: {avg_train_acc:.2f}%")print(f"    🔴 Val Loss:   {avg_val_loss:.4f} | Val Acc: {val_acc:.2f}%")# tensorboard --logdir=runs   runs为保存路径(替换为绝对路径)  集成终端打开writer.add_scalar("Train/Loss", avg_train_loss, epoch+1)writer.add_scalar("Train/Acc", avg_train_acc, epoch+1)writer.add_scalar("Val/Loss", avg_val_loss, epoch+1)writer.add_scalar("Val/Acc", val_acc, epoch+1)# 保存模型if val_acc > best_acc:best_acc = val_acctorch.save(model.state_dict(), "./weight/model_self.pth")print("Saved best model!")# 训练结束后,也可以打印最终的最佳准确率
print(f" Training finished. Best validation accuracy: {best_acc:.2f}%")# 训练结束后关闭TensorBoard writer
writer.close()# 保存history
np.save('./data/history_self.npy', history)
print("History saved!")

训练结果总结

  • 最佳验证准确率:通常可达 90%~95%(受数据划分影响略有波动)
  • 模型保存路径./weight/model_self.pth
  • 训练历史保存./data/history_self.npy,可用于绘制学习曲线

5. 模型验证

为了全面评估所构建模型的性能,本节从量化指标、分类报表、混淆矩阵三个维度对模型在验证集和测试集上的表现进行系统分析。重点考察模型的准确率、召回率、F1 分数以及各类别之间的混淆情况,从而判断其鲁棒性与泛化能力。

5.1 验证过程数据化

在模型训练完成后,需对预测结果进行结构化保存,以便后续分析与部署。本项目将验证集上的原始预测结果(包括输入信号、真实标签、预测概率、预测类别等)导出为 CSV 文件,形成结构化的验证数据集。

该文件可用于:

  • 追踪错误样本(如误判的 V 类心跳)
  • 分析模型置信度分布
  • 支持临床医生复核
  • 构建自动化评估流水线

最终生成的 Excel/CSV 文件示例如下:

在这里插入图片描述

5.2 指标报表

使用 sklearn.metrics.classification_report 生成详细的分类性能报表,包含每个类别的精确率(Precision)、召回率(Recall)、F1 分数(F1-Score)和支持样本数(Support)。

  • Precision(精确率):预测为某类的样本中,真正属于该类的比例 → 关注“预测是否可靠”
  • Recall(召回率):真实为某类的样本中,被正确识别的比例 → 关注“是否漏检”
  • F1-Score:Precision 与 Recall 的调和平均,综合反映类别识别能力
  • Support:该类在数据集中出现的次数

验证数据得到的报表

在这里插入图片描述

从图中可见:

  • N 类(正常心跳):由于样本占比高,Precision 和 Recall 均接近 1.0,模型对其识别非常稳定。
  • V 类(室性早搏):Precision 高达 100%,Recall 达到 99%,表明模型在该类上表现出色——不仅极少将其他类型误判为 V 类(高精确率),也成功捕获了绝大多数真实的室性早搏(高召回率)。这在临床应用中尤为重要,因为漏检室性早搏可能带来严重风险。

测试数据得到的报表
在这里插入图片描述

测试集报表反映了模型在“未见过”数据上的表现。整体指标略低于验证集,但仍保持较高水平(平均 F1 > 0.85),表明模型具备良好的泛化能力

5.3 混淆矩阵

混淆矩阵是评估分类模型性能的重要工具,能够直观展示各类别之间的误判模式。通过分析混淆矩阵,可以识别模型最容易混淆的类别对,进而指导后续优化方向。

混淆矩阵可视化(验证数据集),如下:
在这里插入图片描述
混淆矩阵可视化(测试数据集),如下:
在这里插入图片描述

主要观察结论

  1. 主对角线值高:说明大多数样本被正确分类,模型整体有效。
  2. N ↔ S 类之间存在少量混淆:可能由于部分 S 类心跳形态接近正常,导致边界模糊。
  3. F 类(融合波)识别效果最差:常被误判为 N 或 V 类,因其形态介于两者之间且样本稀少。

6. 模型优化

尽管基础模型已取得较好性能,但仍存在改进空间,尤其是在稀有类别(如 F 类)识别和泛化能力方面。为此,本节提出三种优化策略:增加网络深度、继续训练(微调)、引入预训练与迁移学习,以进一步提升模型表现。

6.1 增加网络深度

原始模型采用三层卷积结构,在特征提取能力上存在一定限制。为此,设计了一个更深的变体 ECGCNN_Deep,包含四个卷积块,并引入 AdaptiveAvgPool1d 保证全连接层输入维度固定。

更深的网络能够:

  • 提取更复杂的高层语义特征
  • 扩大感受野,捕捉更长程的心律上下文
  • 增强非线性表达能力

实验表明,适当加深网络可在不显著增加过拟合风险的前提下,提升对复杂心跳模式(如 S/F 类)的识别能力。

'''
模型构建(增加网络层版)
'''
# model_deep.pyimport torch.nn as nn
import torchclass ECGCNN_Deep(nn.Module):def __init__(self, num_classes=5):super(ECGCNN_Deep, self).__init__()self.features = nn.Sequential(# Block 1# 输入通道数1,输出通道数32,卷积核大小5,步长1,填充2nn.Conv1d(1, 32, kernel_size=5, stride=1,padding=2),nn.BatchNorm1d(32),nn.ReLU(),nn.MaxPool1d(kernel_size=2),# Block 2nn.Conv1d(32, 64, kernel_size=3, stride=1,padding=1),nn.BatchNorm1d(64),nn.ReLU(),nn.MaxPool1d(kernel_size=2),# Block 3nn.Conv1d(64, 128, kernel_size=3, stride=1,padding=1),nn.BatchNorm1d(128),nn.ReLU(),nn.MaxPool1d(kernel_size=2),# Block 4nn.Conv1d(128, 256, kernel_size=3, stride=1,padding=1),nn.BatchNorm1d(256),nn.ReLU(),nn.AdaptiveAvgPool1d(8),    # 固定输出长度)self.classifier = nn.Sequential(nn.Linear(256 * 8, 512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, num_classes),)def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)x = self.classifier(x)return xif __name__ == '__main__':# 创建数据X_train = torch.randn(1, 187, 1)# 创建模型model = ECGCNN_Deep(num_classes=5)# 测试模型print(model(X_train).shape)  # [1, 5]

6.2 继续训练

初始训练仅进行 10 个 epoch,模型可能尚未完全收敛。为进一步挖掘模型潜力,采用**小学习率继续训练(fine-tuning)**策略,在已有权重基础上再训练 50 个 epoch。

该方法的优势包括:

  • 避免从头训练的高成本
  • 利用已学习的特征基础进行精细化调整
  • 在损失平台期后实现进一步下降
'''
继续训练模型
'''
# continue_train.pyimport torch
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader,TensorDataset
from model_self import ECGCNN
import os# 加载预处理数据
X_train = np.load('./data/processed_data/X_train.npy')
y_train = np.load('./data/processed_data/y_train.npy')# 转为PyTorch张量
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)# print("Loaded data",X_train.shape,y_train.shape)# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 设备选择
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载已有模型
model = ECGCNN(num_classes=5)
model.load_state_dict(torch.load('./weight/model_self.pth'))
model.to(device)# 继续训练(例如再训练 50 个 epoch)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # 更小的学习率
criterion = torch.nn.CrossEntropyLoss()new_history = {'train_loss': [],'train_acc': [],
}for epoch in range(50):  # 继续训练 50 个 epochprint(f"\n=== Epoch [{epoch+1}/50] ===")model.train()train_loss = 0.0train_acc = 0train_total = 0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item()_, predicted = outputs.max(1)train_total += labels.size(0)train_acc += predicted.eq(labels).sum().item()avg_train_loss = train_loss / len(train_loader)avg_train_acc = 100. * train_acc / train_total# 记录历史new_history['train_loss'].append(avg_train_loss)new_history['train_acc'].append(avg_train_acc)# 打印print(f"Train Loss: {avg_train_loss:.4f} | Acc: {avg_train_acc:.2f}%")# 7. 保存最终模型(训练完直接保存)
os.makedirs('./weight', exist_ok=True)
torch.save(model.state_dict(), './weight/model_self_finetuned_final.pth')
print(f" 继续训练完成!最终模型已保存。")# 8. 保存训练历史
np.save('./data/new_history_finetune.npy', new_history)
print(" 训练历史已保存。")

6.3 预训练和迁移学习

采用 ResNet18 作为基础架构,并对其进行改造以适应 1D 心电信号输入:

  • 将原始的 2D 卷积层 conv1 修改为 kernel_size=(7,1),使其能处理单通道时间序列;

  • 使用在大规模 ECG 数据集上预训练的权重(ecg_resnet18.pth)初始化模型;

  • 在目标任务上进行迁移学习:

    • 可选择 冻结主干网络(backbone),仅微调分类头(适合小数据集);

    • 或 全模型微调(fine-tune all layers),适合数据量较大时。

该方法有效利用了预训练模型提取通用心电特征的能力,提升了小样本下的分类性能。

'''
预训练模型 + 迁移学习
'''
# pretrain_translearn.pyimport torch
import torch.nn as nn
import torchvision.models as modelsclass PretrainedResNet1D(nn.Module):def __init__(self, num_classes=5, freeze_backbone=False):super(PretrainedResNet1D, self).__init__()# 加载 ImageNet 上预训练的 ResNet18 模型backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)# 修改第一层卷积以适应单通道输入 (ECG 信号)backbone.conv1 = nn.Conv2d(in_channels=1,out_channels=64,kernel_size=(7, 1),stride=(2, 1),padding=(3, 0),bias=False)# 提取除了全连接层以外的所有层self.model = nn.Sequential(*list(backbone.children())[:-1])# 修改最后的全连接层self.fc = nn.Linear(512, num_classes)  # 假设是 resnet18# 初始化新的第一层卷积权重nn.init.kaiming_normal_(self.model[0].weight, mode='fan_out', nonlinearity='relu')# 如果 freeze_backbone=True,则冻结 backbone 参数if freeze_backbone:for param in self.model.parameters():param.requires_grad = Falseprint("ResNet 卷积层已冻结,仅训练分类头。")else:print("所有层均可训练。")def forward(self, x):# 调整输入尺寸以匹配 Conv2d 的期望格式if x.dim() == 2:x = x.unsqueeze(1).unsqueeze(-1)  # (N, 187) -> (N, 1, 187, 1)elif x.dim() == 3:x = x.unsqueeze(-1)  # (N, 1, 187) -> (N, 1, 187, 1)x = self.model(x)x = torch.flatten(x, 1)  # 展平x = self.fc(x)return xdef create_model(freeze_backbone=False):model = PretrainedResNet1D(num_classes=5,freeze_backbone=freeze_backbone)return modelif __name__ == '__main__':# 创建模型实例model = create_model(freeze_backbone=True)# 模拟输入数据x = torch.randn(4, 187)# 前向传播with torch.no_grad():  # 测试时关闭梯度计算y = model(x)print("输入形状:", x.shape)  # [4, 187]print("输出形状:", y.shape)  # [4, 5]print("模型创建成功!")

7. ECG信号分类系统实现

本节构建了一个端到端的 ECG 心跳分类系统,涵盖从原始信号输入到最终可视化输出的完整流程。系统分为三个核心模块:信号预处理、模型推理、结果可视化,实现了从“数据 → 预测 → 展示”的闭环,具备临床辅助诊断系统的雏形。

7.1 信号预处理(去噪、标准化)

原始 ECG 信号常受基线漂移、肌电干扰和工频噪声影响,直接影响模型识别精度。为此,设计了标准化的预处理流程:

  • 带通滤波(0.5–40 Hz):保留典型心电信号频段,有效去除低频基线漂移和高频噪声;
  • 零相位滤波(filtfilt):避免传统滤波引入的时间延迟,确保 R 峰位置不变;
  • Z-score 标准化:将信号转换为均值为 0、标准差为 1 的分布,匹配模型训练时的数据分布。

该预处理流程简单高效,适用于实时或批量处理场景。

"""
清洗 ECG 信号
"""
# clean_ECG.py import numpy as np
from scipy import signal
import matplotlib.pyplot as pltdef notch_filter(ecg, notch_freq=50, fs=500, Q=30):"""50Hz 工频陷波滤波"""b, a = signal.iirnotch(notch_freq, Q, fs=fs)return signal.filtfilt(b, a, ecg)def bandpass_filter(ecg, low=0.5, high=150, fs=500, order=4):"""带通滤波:保留 ECG 主要频带,去除非生理频率"""nyquist = 0.5 * fslow = low / nyquisthigh = high / nyquistif low >= 1.0:low = 0.99if high >= 1.0:high = 0.99b, a = signal.butter(order, [low, high], btype='band')return signal.filtfilt(b, a, ecg)def normalize(ecg):"""标准化"""return (ecg - np.mean(ecg)) / np.std(ecg)def preprocess_ecg_signal(raw_signal, fs=500):"""完整预处理流程"""# 步骤1:去除 50Hz 工频干扰cleaned = notch_filter(raw_signal, notch_freq=50, fs=fs)# 步骤2:带通滤波(保留 0.5 - 150 Hz)cleaned = bandpass_filter(cleaned, low=0.5, high=150, fs=fs, order=4)# 步骤3:标准化cleaned = normalize(cleaned)return cleaned# ======================
# 测试代码:导入并清洗信号
# ======================
if __name__ == '__main__':try:from Noisy_ECG_Signal import generate_noisy_ecgexcept ImportError:raise ImportError("请确保 'Noisy_ECG_Signal.py' 与本文件在同一目录下")print(" 正在生成污染 ECG 信号...")t, ecg_clean, ecg_noisy = generate_noisy_ecg(fs=500, duration=10, show_plot=True)print(" 正在清洗信号...")ecg_cleaned = preprocess_ecg_signal(ecg_noisy, fs=500)# 可视化:清洗前后对比(前 2 秒)t_plot = t[:1000]  # 前 2 秒 (500 * 2 = 1000)clean_plot = ecg_clean[:1000]noisy_plot = ecg_noisy[:1000]cleaned_plot = ecg_cleaned[:1000]plt.figure(figsize=(14, 8))plt.subplot(3, 1, 1)plt.plot(t_plot, clean_plot, color='blue', linewidth=1.2)plt.title("1. Clean ECG Signal (Ground Truth)")plt.ylabel("Amplitude (mV)")plt.grid(True, alpha=0.3)plt.subplot(3, 1, 2)plt.plot(t_plot, noisy_plot, color='red', linewidth=1.2)plt.title("2. Noisy ECG Signal (with Baseline, EMG, 50Hz)")plt.ylabel("Amplitude (mV)")plt.grid(True, alpha=0.3)plt.subplot(3, 1, 3)plt.plot(t_plot, cleaned_plot, color='green', linewidth=1.2)plt.title("3. Cleaned ECG Signal (After Preprocessing)")plt.xlabel("Time (s)")plt.ylabel("Normalized Amplitude")plt.grid(True, alpha=0.3)plt.tight_layout()plt.show()print(" 测试完成!去噪效果如图所示。")print(f"   原始 SNR 估计: {10*np.log10(np.var(ecg_clean)/np.var(ecg_noisy-ecg_clean)):.2f} dB")print(f"   去噪后 SNR: {10*np.log10(np.var(ecg_clean)/np.var(ecg_cleaned-ecg_clean)):.2f} dB")

污染信号生成代码(Nosiy_ECG_Signal.py)

'''
生成受污染的ECG 信号
'''
# Nosiy_ECG_Signal.py
import numpy as np
import matplotlib.pyplot as plt
from scipy import signaldef generate_noisy_ecg(fs=500, duration=10, show_plot=True):"""生成受基线漂移、肌电干扰和工频噪声影响的 ECG 信号参数:fs: 采样频率 (Hz)duration: 信号时长 (秒)show_plot: 是否显示生成过程的可视化返回:t: 时间轴ecg_clean: 干净 ECG 信号ecg_noisy: 污染后的 ECG 信号"""# ======================# 参数设置# ======================t = np.linspace(0, duration, int(fs * duration), endpoint=False)  # 时间轴# ======================# 1. 生成干净 ECG 信号(简化模型:使用周期性波形模拟 P-QRS-T)# ======================def generate_ecg_clean(t, fs):# 心率(bpm)heart_rate = 75heart_rate_rad = 2 * np.pi * heart_rate / 60# R 波:周期性高斯脉冲(模拟 QRS 波群)r_peaks = np.sin(heart_rate_rad * t)  # 控制心跳节奏qrs = 2.0 * np.exp(-1000 * (t % (60 / heart_rate) - 0.02)**2)  # 高斯脉冲模拟 QRS# T 波:稍宽的正向波t_wave = 0.4 * np.exp(-200 * (t % (60 / heart_rate) - 0.15)**2)# P 波:小的正向波p_wave = 0.25 * np.exp(-400 * (t % (60 / heart_rate) - 0.0)**2)# 组合 ECGecg_clean = p_wave + qrs + t_wave# 添加轻微随机变化(心跳间期变异)jitter = 0.01 * np.random.randn(len(t))ecg_clean = np.interp(t, t + jitter, ecg_clean)return ecg_cleanecg_clean = generate_ecg_clean(t, fs)# ======================# 2. 添加噪声# ======================# (1) 基线漂移(Baseline Wander): 0.1 - 0.5 Hz 的低频正弦波组合baseline_wander = (0.3 * np.sin(2 * np.pi * 0.1 * t) +0.2 * np.sin(2 * np.pi * 0.3 * t) +0.1 * np.sin(2 * np.pi * 0.5 * t))# (2) 肌电干扰(EMG-like noise): 高频随机噪声(30-200 Hz)np.random.seed(42)emg_noise = np.random.normal(0, 0.1, len(t))# 用带通滤波器模拟肌电信号频带(30-200 Hz)b_emg, a_emg = signal.butter(4, [30, 200], btype='bandpass', fs=fs)emg_noise = signal.filtfilt(b_emg, a_emg, emg_noise)emg_noise = 0.1 * emg_noise / np.max(np.abs(emg_noise))  # 归一化并控制幅度# (3) 工频噪声(Power-line interference): 50 Hz(中国)或 60 Hz(美国)power_freq = 50  # 可改为 60power_noise = 0.15 * np.sin(2 * np.pi * power_freq * t)# ======================# 3. 合成污染信号# ======================ecg_noisy = ecg_clean + baseline_wander + emg_noise + power_noise# ======================# 4. 可视化(保持你原有的可视化不变)# ======================if show_plot:plt.figure(figsize=(14, 8))# 子图1:原始干净 ECGplt.subplot(3, 1, 1)plt.plot(t[:1000], ecg_clean[:1000], color='blue', linewidth=1.2)plt.title("Clean ECG Signal")plt.ylabel("Amplitude (mV)")plt.grid(True, alpha=0.3)# 子图2:添加的噪声plt.subplot(3, 1, 2)plt.plot(t[:1000], baseline_wander[:1000], label='Baseline Wander', color='orange')plt.plot(t[:1000], emg_noise[:1000], label='EMG Noise', color='red', alpha=0.7)plt.plot(t[:1000], power_noise[:1000], label='50 Hz Noise', color='purple', alpha=0.7)plt.title("Added Noise Components")plt.ylabel("Amplitude")plt.legend()plt.grid(True, alpha=0.3)# 子图3:最终污染信号plt.subplot(3, 1, 3)plt.plot(t[:1000], ecg_noisy[:1000], color='red', linewidth=1.2)plt.title("Noisy ECG Signal (with Baseline Wander, EMG, and 50 Hz Interference)")plt.xlabel("Time (s)")plt.ylabel("Amplitude (mV)")plt.grid(True, alpha=0.3)plt.tight_layout()plt.show()return t, ecg_clean, ecg_noisy  # 返回信号,供其他模块使用# ======================
# 如果直接运行此文件,则生成并显示信号
# ======================
if __name__ == "__main__":print(" 正在生成污染 ECG 信号...")t, clean, noisy = generate_noisy_ecg(fs=500, duration=10, show_plot=True)print(" 信号生成完成,长度:", len(t))

污染信号生成可视化图 ,如下:
在这里插入图片描述

cleaned信号处理对比图 ,如下:
在这里插入图片描述

7.2 模型推理(PyTorch )

模型推理是分类系统的核心环节。本模块封装了模型加载与预测逻辑,支持单个心跳或批量信号输入。

  • 使用 torch.load() 加载训练好的 .pth 权重文件;
  • 通过 model.eval()torch.no_grad() 关闭梯度计算,提升推理效率;
  • 输出包含:预测类别、置信度(最大概率值)、各类别概率分布,便于后续分析与决策。
'''
模型推理
'''
# inference.py
import torch
import numpy as np
import onnxruntime as ort
import matplotlib.pyplot as plt
from scipy import signalfrom model_self import ECGCNN
from clean_ECG import preprocess_ecg_signal# 类别名称
CLASS_NAMES = ['Normal', 'Supraventricular', 'Ventricular', 'Fusion', 'Unknown']# ======================
# 1. 加载模型
# ======================
def load_model(model_path="./weight/model_self.pth"):"""加载训练好的 PyTorch 模型"""model = ECGCNN(num_classes=5)model.load_state_dict(torch.load(model_path, map_location='cpu'))model.eval()print(" PyTorch 模型已加载")return modeldef load_onnx_model(onnx_model_path="./weight/ecg_model_self.onnx"):"""加载 ONNX 模型"""session = ort.InferenceSession(onnx_model_path)print(" ONNX 模型已加载")return session# ======================
# 2. 检测 R 波(用于分割心跳)
# ======================
def detect_r_peaks(ecg, fs=500):"""使用简单阈值法检测 R 波(适用于干净或轻度污染信号)返回 R 波位置索引"""# 使用带通滤波增强 QRSb, a = signal.butter(2, [5, 15], btype='bandpass', fs=fs)filtered = signal.filtfilt(b, a, ecg)# 简单平方 + 滑动窗能量squared = filtered ** 2window_size = int(0.1 * fs)  # 100ms 滑动窗smoothed = np.convolve(squared, np.ones(window_size) / window_size, mode='same')# 阈值检测threshold = 0.5 * np.max(smoothed)r_peaks = signal.find_peaks(smoothed, height=threshold, distance=int(0.6 * fs))[0]  # 最小间距 600msreturn r_peaks# ======================
# 3. 提取心跳片段(长度 187)
# ======================
def extract_beats(ecg, r_peaks, fs=500, beat_length=187):"""以 R 峰为中心,前后截取心跳片段beat_length: 模型输入长度(如 187)"""half_len = beat_length // 2beats = []valid_positions = []for r in r_peaks:start = r - half_lenend = r + (beat_length - half_len)if start >= 0 and end <= len(ecg):beat = ecg[start:end]if len(beat) == beat_length:beats.append(beat)valid_positions.append(r)return np.array(beats), np.array(valid_positions)# ======================
# 4. Softmax 函数
# ======================
def softmax(x, axis=-1):"""Numerically stable softmax"""x = x - np.max(x, axis=axis, keepdims=True)exp_x = np.exp(x)return exp_x / np.sum(exp_x, axis=axis, keepdims=True)# ======================
# 5. PyTorch 批量推理
# ======================
def predict_heartbeat(model, ecg_signal):"""ecg_signal: 已经去噪的 ECG 片段 (187,) 或 (N, 187)返回: 预测类别, 置信度, 概率分布"""model.eval()with torch.no_grad():if ecg_signal.ndim == 1:# 单个心跳tensor = torch.tensor(ecg_signal, dtype=torch.float32).unsqueeze(0).unsqueeze(0)else:# 批量心跳 (N, 187) -> (N, 1, 187)tensor = torch.tensor(ecg_signal, dtype=torch.float32).unsqueeze(1)output = model(tensor)  # logits: (N, 5)prob = torch.softmax(output, dim=1).numpy()  # 转为 numpy 概率confidence = np.max(prob, axis=1)predicted = np.argmax(prob, axis=1)return predicted, confidence, prob# ======================
# 6. ONNX 批量推理
# ======================
def predict_heartbeat_onnx(session, ecg_signal):"""使用 ONNX 模型预测单个或批量心跳"""if ecg_signal.ndim == 1:input_data = ecg_signal.reshape(1, 1, -1).astype(np.float32)  # (1, 1, 187)else:input_data = ecg_signal.reshape(-1, 1, 187).astype(np.float32)  # (N, 1, 187)result = session.run(["logits"], {"ecg_input": input_data})logits = result[0]  # (N, 5)prob = softmax(logits, axis=1)confidence = np.max(prob, axis=1)predicted = np.argmax(prob, axis=1)return predicted, confidence, prob# ======================
# 7. 可视化函数
# ======================def plot_ecg_comparison(t, ecg_noisy, ecg_cleaned, r_peaks, beat_positions, predictions, confidences, class_names):"""绘制污染信号、干净信号、R波位置、预测结果"""fig, axes = plt.subplots(3, 1, figsize=(16, 10), sharex=True)# 子图 1: 原始污染信号axes[0].plot(t, ecg_noisy, color='lightcoral', linewidth=0.8)axes[0].set_title("Noisy ECG Signal", fontsize=14, fontweight='bold')axes[0].set_ylabel("Amplitude")axes[0].grid(True, alpha=0.3)# 子图 2: 去噪后信号axes[1].plot(t, ecg_cleaned, color='steelblue', linewidth=1.0)# 标出 R 波位置for r in r_peaks:axes[1].axvline(t[r], color='red', linestyle='--', alpha=0.7)axes[1].set_title("Denoised ECG Signal with R-Peak Detection", fontsize=14, fontweight='bold')axes[1].set_ylabel("Amplitude")axes[1].grid(True, alpha=0.3)# 子图 3: 预测结果标注axes[2].plot(t, ecg_cleaned, color='gray', linewidth=0.8, alpha=0.8)# 颜色映射(每类不同颜色)colors = ['green', 'orange', 'red', 'purple', 'gray']for i, (pos, pred, conf) in enumerate(zip(beat_positions, predictions, confidences)):if pos < len(t):x = t[pos]y = ecg_cleaned[pos]class_name = class_names[pred]color = colors[pred]axes[2].axvline(x, color=color, alpha=0.6)axes[2].text(x, max(ecg_cleaned)*0.9, f'{class_name}\n{conf:.2f}',color=color, fontsize=8, ha='center', rotation=90,bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.2))axes[2].set_title("Predicted Heartbeat Types (Color-coded)", fontsize=14, fontweight='bold')axes[2].set_ylabel("Amplitude")axes[2].set_xlabel("Time (s)")axes[2].grid(True, alpha=0.3)# 图例说明from matplotlib.patches import Patchlegend_elements = [Patch(facecolor=colors[i], label=class_names[i]) for i in range(len(class_names))]axes[2].legend(handles=legend_elements, bbox_to_anchor=(1.02, 1), loc='upper left')plt.tight_layout()plt.show()def plot_prediction_confidence(predictions, confidences, class_names):"""绘制预测类别和置信度柱状图"""fig, ax = plt.subplots(figsize=(10, 6))x = np.arange(len(predictions))colors = ['green', 'orange', 'red', 'purple', 'gray']bar_colors = [colors[pred] for pred in predictions]bars = ax.bar(x, confidences, color=bar_colors, alpha=0.7, edgecolor='black', linewidth=0.5)ax.set_xlabel("Beat Index")ax.set_ylabel("Confidence")ax.set_title("Prediction Confidence per Heartbeat", fontsize=14, fontweight='bold')ax.set_ylim(0, 1.1)ax.grid(True, axis='y', alpha=0.3)# 在柱子上方标注类别for i, (bar, pred) in enumerate(zip(bars, predictions)):height = bar.get_height()ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,class_names[pred], ha='center', va='bottom', fontsize=9, rotation=45)# 图例from matplotlib.patches import Patchlegend_elements = [Patch(facecolor=colors[i], label=class_names[i]) for i in range(len(class_names))]ax.legend(handles=legend_elements, title="Classes")plt.tight_layout()plt.show()# ======================
# 主测试流程
# ======================
if __name__ == "__main__":# --- 1. 加载模型 ---model = load_model("./weight/model_self.pth")ort_session = load_onnx_model("./weight/ecg_model_self.onnx")# --- 2. 生成并清洗 ECG 信号 ---from Noisy_ECG_Signal import generate_noisy_ecgt, ecg_clean, ecg_noisy = generate_noisy_ecg(fs=500, duration=10, show_plot=False)# 只清洗一次ecg_cleaned = preprocess_ecg_signal(ecg_noisy, fs=500)print(f"ECG 信号长度: {len(ecg_cleaned)}")# --- 3. 检测 R 波 ---r_peaks = detect_r_peaks(ecg_cleaned, fs=500)print(f"检测到 {len(r_peaks)} 个 R 波")# --- 4. 提取心跳 ---beats, beat_positions = extract_beats(ecg_cleaned, r_peaks, fs=500, beat_length=187)print(f"成功提取 {len(beats)} 个心跳片段")if len(beats) == 0:print(" 未提取到有效心跳片段")else:# --- 5. 批量推理 ---pred_torch, conf_torch, prob_torch = predict_heartbeat(model, beats)pred_onnx, conf_onnx, prob_onnx = predict_heartbeat_onnx(ort_session, beats)# --- 6. 打印结果 ---print("\n" + "=" * 50)print(" 心跳分类结果对比")print("=" * 50)for i in range(min(10, len(beats))):match = "✅" if pred_torch[i] == pred_onnx[i] else "❌"print(f"心跳 {i+1:2d}: "f"[PyTorch] {CLASS_NAMES[pred_torch[i]]} ({conf_torch[i]:.3f}) | "f"[ONNX] {CLASS_NAMES[pred_onnx[i]]} ({conf_onnx[i]:.3f}) {match}")# --- 7. 统计一致性 ---accuracy = np.mean(pred_torch == pred_onnx)print(f"\n ONNX 与 PyTorch 预测一致率: {accuracy * 100:.1f}%")# --- 8. 可视化 ---print("\n  正在生成可视化图表...")# 时间轴t = np.linspace(0, len(ecg_cleaned)/500, len(ecg_cleaned))  # fs=500# 绘制信号对比和预测结果plot_ecg_comparison(t=t,ecg_noisy=ecg_noisy,ecg_cleaned=ecg_cleaned,r_peaks=r_peaks,beat_positions=beat_positions,predictions=pred_torch,      # 使用 PyTorch 预测结果confidences=conf_torch,class_names=CLASS_NAMES)# 绘制置信度图plot_prediction_confidence(pred_torch, conf_torch, CLASS_NAMES)

7.3 结果可视化(画波形 + 打标签)

为了直观展示模型在真实 ECG 波形上的分类效果,开发了可视化模块。该模块将原始信号与预测结果融合呈现:

  • 在原始 ECG 曲线上标注每个 R 峰对应的心跳类型;
  • 使用颜色编码(绿色/红色)表示置信度高低(可设置阈值);
  • 添加文本框显示类别名称与置信分数,提升可读性;
  • 支持自定义采样率、标签位置、字体大小等参数。

可视化结果不仅有助于模型调试与错误分析,也为医生提供直观的辅助判读工具,增强人机协同效率。

"""
结果可视化
"""
# visualize.py
import matplotlib.pyplot as plt
import numpy as np
from typing import List, Optional
from inference import load_model, predict_heartbeat  # 导入你需要的函数# 使用全局类别名(也可以传参)
from inference import CLASS_NAMESdef plot_ecg_with_labels(ecg_signal: np.ndarray,r_peaks: List[int],model,class_names: List[str] = None,segment_length: int = 187,sample_rate: int = 360,confidence_threshold: float = 0.6,title: str = "ECG 信号与心跳分类结果",figsize: tuple = (14, 6)
):"""在原始 ECG 信号上绘制波形,并为每个心跳打上预测标签参数:ecg_signal: 完整 ECG 信号 (T,)r_peaks: R 峰位置列表(索引)model: 已加载的 PyTorch 模型对象class_names: 类别名称列表(默认使用 inference 中的 CLASS_NAMES)segment_length: 每个心跳输入长度(默认 187)sample_rate: 采样率(Hz)confidence_threshold: 置信度阈值(高于绿色,低于红色)title: 图表标题figsize: 图像大小"""if len(ecg_signal.shape) != 1:raise ValueError("ecg_signal 必须是一维信号")half_len = segment_length // 2class_names = class_names or CLASS_NAMES  # 默认使用全局类别predictions = []plt.figure(figsize=figsize)t = np.arange(len(ecg_signal)) / sample_rateplt.plot(t, ecg_signal, 'k', linewidth=0.8, label='ECG Signal')# 对每个 R 峰进行预测并标注for i, peak in enumerate(r_peaks):left = peak - half_lenright = peak + half_lenif left < 0 or right >= len(ecg_signal):predictions.append(None)continue# 提取单个心跳heartbeat = ecg_signal[left:right]try:# 使用 inference.py 中的 predict_heartbeat 函数pred_ids, confs, probs = predict_heartbeat(model, heartbeat)pred_label = pred_ids[0]  # 返回是 (1,) 数组confidence = confs[0]pred_name = class_names[pred_label]except Exception as e:print(f"第 {i} 个心跳预测失败: {e}")pred_name = "Error"confidence = 0.0predictions.append((pred_name, confidence))# 在 R 峰上方添加文本标签x_time = peak / sample_ratey_height = ecg_signal[peak] + (np.max(ecg_signal) - np.min(ecg_signal)) * 0.05color = 'green' if confidence > confidence_threshold else 'red'plt.text(x_time, y_height, f"{pred_name}\n({confidence:.2f})",fontsize=9, ha='center', va='bottom',color=color, bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7))plt.title(title, fontsize=14)plt.xlabel("时间 (秒)", fontsize=12)plt.ylabel("幅度 (mV)", fontsize=12)plt.grid(True, alpha=0.3)plt.tight_layout()plt.show()return predictions# ========================
# 示例使用(仅用于测试)
# ========================
if __name__ == "__main__":import numpy as npprint(" 正在测试 ECG 可视化模块...")# 模拟一段较长的 ECG 信号T = 3000sample_rate = 360ecg_signal = np.zeros(T)r_peaks = [187, 540, 910, 1280, 1650, 2020, 2400, 2780]  # R 峰位置# 生成模拟心跳(高斯脉冲)for peak in r_peaks:if peak < T:start = max(0, peak - 20)end = min(T, peak + 40)x = np.arange(start, end)ecg_signal[x] += np.exp(-0.05 * (x - peak)**2) * 1.8ecg_signal += np.random.normal(0, 0.05, T)  # 加噪声# 加载模型try:model = load_model("./weight/model_self.pth")print(" 模型加载成功")except Exception as e:print(f" 模型加载失败,请检查路径: {e}")exit()# 执行可视化results = plot_ecg_with_labels(ecg_signal=ecg_signal,r_peaks=r_peaks,model=model,class_names=CLASS_NAMES,sample_rate=sample_rate,confidence_threshold=0.6,title="ECG 心跳分类结果可视化")# 打印结果print("\n 分类结果:")for i, res in enumerate(results):if res:print(f"心跳 {i+1}: {res[0]} (置信度: {res[1]:.3f})")else:print(f"心跳 {i+1}: 越界或分析失败")

8. 模型移植

为了提升模型的跨平台部署能力,降低对 PyTorch 框架的依赖,本项目采用 ONNX(Open Neural Network Exchange) 格式进行模型导出与推理验证。ONNX 支持多种运行时(如 ONNX Runtime、TensorRT、CoreML),可在服务器、移动端、边缘设备上高效运行,极大增强了模型的工程落地潜力。

8.1 导出ONNX

使用 torch.onnx.export 工具将训练好的 PyTorch 模型转换为 .onnx 文件。关键配置如下:

  • dummy_input:提供示例输入张量,用于追踪计算图;
  • opset_version=11:确保支持常用算子(如 Conv1d、BatchNorm);
  • dynamic_axes:允许动态 batch size,适应不同输入规模;
  • do_constant_folding=True:优化常量节点,减小模型体积并提升推理速度。

导出后可通过 Netron 等工具查看模型结构,确认节点连接正确。

在这里插入图片描述

'''
模型导出onnx
'''
# export_onnx.pyimport torch
from model_self import ECGCNN  # 确保路径正确def main():# 1. 定义模型结构(必须和训练时一致)model = ECGCNN(num_classes=5)# 2. 加载你训练好的权重model.load_state_dict(torch.load("./weight/model_self.pth", map_location='cpu'))model.eval()  # 切换到推理模式# 3. 构造一个示例输入(shape: batch x channel x length)dummy_input = torch.randn(1, 1, 187)  # 和ECG 数据 shape 一致# 4. 导出为 ONNX  (通过输入构造一个示例输入 进行追踪推理)torch.onnx.export(model,dummy_input,"./weight/ecg_model_self.onnx",  # 输出路径export_params=True,           # 保存模型参数opset_version=11,             # ONNX 版本兼容性do_constant_folding=True,     # 优化input_names=['ecg_input'],    # 输入名output_names=['logits'],  # 输出名dynamic_axes={'ecg_input': {0: 'batch_size'},'logits': {0: 'batch_size'}}  # 支持动态 batch)print(" 成功导出 ONNX 模型到:./weight/ecg_model_self.onnx")if __name__ == "__main__":main()

8.2 使用ONNX推理

利用 ONNX Runtime 加载 .onnx 模型并执行推理,验证其输出是否与原始 PyTorch 模型一致。

  • onnxruntime.InferenceSession 提供跨平台高性能推理引擎;
  • 输入需按 input_names 指定的名称传入(如 'ecg_input');
  • 输出返回概率分布,取最大值作为预测结果。

经测试,ONNX 模型与 PyTorch 模型的预测结果完全一致(误差 < 1e-6),说明转换成功。同时,ONNX Runtime 在 CPU 上的推理速度更快,更适合部署在资源受限设备上。

import torch
import numpy as np
import onnxruntime as ort
from model_self import ECGCNN  
from preprocess import preprocess_ecg_signal# 类别名称
CLASS_NAMES = ['Normal', 'Supraventricular', 'Ventricular', 'Fusion', 'Unknown'#  ONNX 推理函数
def predict_heartbeat_onnx(session, ecg_signal):"""使用 ONNX 模型预测单个或批量心跳"""if ecg_signal.ndim == 1:input_data = ecg_signal.reshape(1, 1, -1).astype(np.float32)  # (1, 1, 187)else:input_data = ecg_signal.reshape(-1, 1, 187).astype(np.float32)  # (N, 1, 187)result = session.run(["logits"], {"ecg_input": input_data})logits = result[0]  # (N, 5)prob = softmax(logits, axis=1)confidence = np.max(prob, axis=1)predicted = np.argmax(prob, axis=1)return predicted, confidence, probif __name__ == "__main__":ort_session = load_onnx_model("./weight/ecg_model_self.onnx")# 生成并清洗 ECG 信号 from Noisy_ECG_Signal import generate_noisy_ecgt, ecg_clean, ecg_noisy = generate_noisy_ecg(fs=500, duration=10, show_plot=False)# 只清洗一次ecg_cleaned = preprocess_ecg_signal(ecg_noisy, fs=500)print(f"ECG 信号长度: {len(ecg_cleaned)}")# 检测 R 波 r_peaks = detect_r_peaks(ecg_cleaned, fs=500)print(f"检测到 {len(r_peaks)} 个 R 波")# 提取心跳 beats, beat_positions = extract_beats(ecg_cleaned, r_peaks, fs=500, beat_length=187)print(f"成功提取 {len(beats)} 个心跳片段")# 使用 ONNX 模型推理pred_onnx, conf_onnx, prob_onnx = predict_heartbeat_onnx(ort_session, beats)print(f"\n【ONNX 推理】")print(f"预测类别: {CLASS_NAMES[pred_onnx]}")print(f"置信度: {conf_onnx:.3f}")

9. 项目总结

9.1 问题及解决办法

在进行基于CNN实现心律失常(ECG)的小颗粒度分类时,可能会遇到以下问题及解决办法:

问题解决方案
类别严重不平衡(N类占80%)使用 SMOTE 过采样平衡训练集
模型过拟合添加 Dropout、BatchNorm
输入维度不匹配reshape 为 (N, 187, 1) 适配 1D-CNN
推理时预处理不一致保存 StandardScaler 并在推理时复用

9.2 收获

  • 掌握了 1D-CNN 在时间序列分类中的应用
  • 学会了使用 SMOTE 解决类别不平衡问题
  • 实践了从训练到部署的完整流程(PyTorch → ONNX)
  • 提升了模型可视化与可解释性能力
http://www.xdnf.cn/news/1396243.html

相关文章:

  • HAProxy 负载均衡全解析:从基础部署、负载策略到会话保持及性能优化指南
  • docker命令(二)
  • 现状摸底:如何快速诊断企业的“数字化健康度”?
  • PCIe 6.0 TLP深度解析:从结构设计到错误处理的全链路机制
  • 算法题(194):字典树
  • 从0到1玩转 Google SEO
  • Suno-API - OpenI
  • “FAQ + AI”智能助手全栈实现方案
  • Python从入门到高手9.4节-基于字典树的敏感词识别算法
  • 8月29日星期五今日早报简报微语报早读
  • 轮廓周长,面积,外接圆,外接矩形近似轮廓和模板匹配和argparse模块实现代码参数的动态配置
  • 【C++】掌握类模板:多参数实战技巧
  • 基于Net海洋生态环境保护系统的设计与实现(代码+数据库+LW)
  • MYSQL速通(2/5)
  • 小杰机器视觉(six)——模板匹配
  • UCIE Specification详解(十)
  • TypeScript: Symbol.iterator属性
  • WINTRUST!_GetMessage函数分析之CRYPT32!CryptSIPGetSignedDataMsg函数的作用是得到nt5inf.cat的信息
  • AI的“科学革命”:Karpathy吹响号角,从“经院哲学”走向“实验科学”
  • 基于STM32单片机的智能温室控制声光报警系统设计
  • Geocodify 的 API
  • CD71.【C++ Dev】二叉树的三种非递归遍历方式
  • 网络编程 反射【详解】 | Java 学习日志 | 第 15 天
  • 2025牛客暑期多校训练营4 G Ghost in the Parentheses 题解记录
  • Day17 Docker学习
  • uac播放与录制
  • 论文阅读:arixv 2025 WideSearch: Benchmarking Agentic Broad Info-Seeking
  • React Three Fiber
  • LBM——大型行为模型助力波士顿人形Atlas完成多任务灵巧操作:CLIP编码图像与语义,之后DiT去噪扩散生成动作
  • 编程速递:RAD Studio 13 即将到来的功能