当前位置: 首页 > news >正文

第TR5周:Transformer实战:文本分类

  •    🍨 本文为🔗365天深度学习训练营中的学习记录博客
  •    🍖 原作者:K同学啊

1.准备工作

1.1.加载数据

import torch
import torch.nn as nn
import torchvision
import os,PIL,warnings
import pandas as pd

warnings.filterwarnings('ignore')

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

 #加载自定义中文数据
train_data=pd.read_csv('train.csv',sep='\t',header=None)
train_data.head()

 

#构建数据集迭代器
def coustom_data_iter(texts,labels):
    for x,y in zip(texts,labels):
        yield x,y

train_iter=coustom_data_iter(train_data[0].values[:],train_data[1].values[:])

2.数据预处理

2.1.构建词典

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba

tokenizer=jieba.lcut

def yield_tokens(data_iter):
    for text,_ in data_iter:
        yield tokenizer(text)

vocab=build_vocab_from_iterator(yield_tokens(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])

label_name=list(set(train_data[1].values[:]))
print(label_name)

text_pipeline=lambda x:vocab(tokenizer(x))
label_pipeline=lambda x: label_name.index(x)

print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

2.2.生成数据批次和迭代器 

from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list,text_list,offsets=[],[],[0]

    for (_text,_label) in batch:
        #标签列表
        label_list.append(label_pipeline(_label))

        #文本列表
        processed_text=torch.tensor(text_pipeline(_text),dtype=torch.int64)
        text_list.append(processed_text)

        #偏移量
        offsets.append(processed_text.size(0))

    label_list=torch.tensor(label_list,dtype=torch.int64)
    text_list=torch.cat(text_list)
    offsets=torch.tensor(offsets[:-1]).cumsum(dim=0)  #返回维度dim中输入元素的累计和

    return text_list.to(device),label_list.to(device),offsets.to(device)

2.3.构建数据集

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

BATCH_SIZE=4

train_iter=coustom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset=to_map_style_dataset(train_iter)

split_train_,split_valid_=random_split(train_dataset,
                                       [int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])
train_dataloader=DataLoader(split_train_,batch_size=BATCH_SIZE,
                            shuffle=True,collate_fn=collate_batch)
valid_dataloader=DataLoader(split_valid_,batch_size=BATCH_SIZE,
                            shuffle=True,collate_fn=collate_batch)

3.模型构建

3.1.定义位置编码函数

import math,os,torch

class PositionalEncoding(nn.Module):
    def __init__(self,embed_dim,max_len=500):
        super(PositionalEncoding,self).__init__()

        pe=torch.zeros(max_len,embed_dim)
        position=torch.arange(0,max_len,dtype=torch.float).unsqueeze(1)

        div_term=torch.exp(torch.arange(0,embed_dim,2).float()*(-math.log(100.0)/embed_dim))
        pe[:,0::2]=torch.sin(position*div_term)
        pe[:,1::2]=torch.cos(position*div_term)
        pe=pe.unsqueeze(0).transpose(0,1)

        self.register_buffer('pe',pe)

    def forward(self,x):
        x=x+self.pe[:x.size(0)]
        return x

3.2.定义Transformer模型

from tempfile import TemporaryDirectory
from typing import Tuple
from torch import nn,Tensor
from torch.nn import TransformerEncoder,TransformerEncoderLayer
from torch.utils.data import dataset

class TransformerModel(nn.Module):
    def __init__(self,vocab_size,embed_dim,num_class,nhead=8,d_hid=256,nlayers=12,dropout=0.1):
        super().__init__()
        self.embedding=nn.EmbeddingBag(vocab_size,
                                       embed_dim,
                                       sparse=False)
        self.pos_encoder=PositionalEncoding(embed_dim)

        #定义编码器层
        encoder_layers=TransformerEncoderLayer(embed_dim,nhead,d_hid,dropout)
        self.transformer_encoder=TransformerEncoder(encoder_layers,nlayers)
        self.embed_dim=embed_dim
        self.linear=nn.Linear(embed_dim*4,num_class)

    def forward(self,src,offsets,src_mask=None):
        src=self.embedding(src,offsets)
        src=self.pos_encoder(src)
        output=self.transformer_encoder(src,src_mask)

        output=output.view(4,embed_dim*4)
        output=self.linear(output)

        return output
 

3.3.初始化模型

vocab_size=len(vocab)
embed_dim=64
num_class=len(label_name)

model=TransformerModel(vocab_size,
                       embed_dim,
                       num_class).to(device)

3.4.定义训练函数

import time

def train(dataloader):
    model.train()
    total_acc,train_loss,total_count=0,0,0
    log_interval=300
    start_time=time.time()

    for idx,(text,label,offsets) in enumerate(dataloader):
        predicted_label=model(text,offsets)
        optimizer.zero_grad()

        loss=criterion(predicted_label,label)
        loss.backward()
        optimizer.step()

        total_acc+=(predicted_label.argmax(1)==label).sum().item()
        train_loss+=loss.item()
        total_count+=label.size(0)

        if idx%log_interval==0 and idx>0:
            elapsed=time.time() - start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches | train_acc {:4.3f} train_loss {:4.5f}'.format(epoch,idx,len(dataloader),total_acc/total_count,train_loss/total_count))
            total_acc,train_loss,total_count=0,0,0
            start_time=time.time()

3.5.定义评估函数

def evaluate(dataloader):
    model.eval()
    total_acc,train_loss,total_count=0,0,0

    with torch.no_grad():
        for idx,(text,label,offsets) in enumerate(dataloader):
            predicted_label=model(text,offsets)

            loss=criterion(predicted_label,label)

            total_acc+=(predicted_label.argmax(1)==label).sum().item()
            train_loss+=loss.item()
            total_count+=label.size(0)

    return total_acc/total_count,train_loss/total_count

4.训练模型

4.1.模型训练

EPOCH=10

criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=1e-2)

for epoch in range(1,EPOCH+1):
    epoch_start_time=time.time()
    train(train_dataloader)
    val_acc,val_loss=evaluate(valid_dataloader)

    lr=optimizer.state_dict()['param_groups'][0]['lr']

    print('-'*68)
    print('| epoch {:1d} | time: {:4.2f}s | valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,time.time()-epoch_start_time,val_acc,val_loss,lr))

 

4.2.模型评估

test_acc,test_loss=evaluate(valid_dataloader)
print('模型准确率为:{:5.4f}'.format(test_acc))

5.总结

使用 jieba 分词构建词汇表并将文本转换为数值序列,再通过嵌入层(nn.EmbeddingBag)结合自定义的位置编码(PositionalEncoding)将输入嵌入空间映射,随后利用多层 TransformerEncoder 编码文本语义特征,最后通过线性层进行多类别分类,训练过程使用交叉熵损失函数和 SGD 优化器,在模型评估阶段输出准确率和损失值;其中 Transformer 是一种基于注意力机制的深度学习结构

http://www.xdnf.cn/news/141139.html

相关文章:

  • 图像识别系统 - Ubuntu部署指南(香橙派开发板测试)-学习记录1
  • MySQL 详解之函数:数据处理与计算的利器
  • HOW - 如何模拟实现 gpt 展示答案的交互效果
  • form表单提交前设置请求头request header及文件下载
  • 线程怎么创建?Java 四种方式一网打尽
  • uniapp 仿企微左边公司切换页
  • FreeRTOS
  • 斗鱼娱乐电玩平台源码搭建实录
  • 短视频矩阵系统可视化剪辑功能开发,支持OEM
  • QT 连接数据库操作(15)
  • Pandas 数据导出:如何将 DataFrame 追加到 Excel 的不同工作表
  • 银发科技:AI健康小屋如何破解老龄化困局
  • MYSQL之数据类型
  • 【MySQL】3分钟解决MySQL深度分页问题
  • git 命令集
  • 【Web应用服务器_Tomcat】一、Tomcat基础与核心功能详解
  • 如何配置Spark
  • Spring-Framework源码环境搭建
  • 7.10 GitHub Sentinel CLI开发实战:Python构建企业级监控工具的5大核心技巧
  • JMeter添加HTTP请求默认值元件的作用详解
  • 百度打响第一枪!通用超级智能体时代,真的来了
  • 常用第三方库:flutter_boost混合开发
  • Android Kotlin 依赖注入全解:Koin appModule 配置与多 ViewModel 数据共享实战指南
  • 解决视频处理中的 HEVC 解码错误:Could not find ref with POC xxx【已解决】
  • 创建型设计模式之:简单工厂模式、工厂方法模式、抽象工厂模式、建造者模式和原型模式
  • 【QQMusic项目复习笔记——音乐管理模块详解】第四章
  • 1.10软考系统架构设计师:优秀架构设计师 - 练习题附答案及超详细解析
  • 时序数据库IoTDB在航空航天领域的解决方案
  • BiliNote:开源的AI视频笔记生成工具,让知识提取与分享更高效——跨平台自动生成结构化笔记,实现从视频到Markdown的智能转化
  • PT report_timing详解