
PyTorch LSTM Text Generation

1. Environment Setup and Imports

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from collections import Counter
import string
import re

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Check for a GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

2. Data Preparation (Using Simplified Text Data)

# Use a Shakespeare excerpt as an example (can be replaced with WikiText-2 or another dataset)
sample_text = """
To be, or not to be, that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles
And by opposing end them. To die—to sleep,
No more; and by a sleep to say we end
The heart-ache and the thousand natural shocks
That flesh is heir to: 'tis a consummation
Devoutly to be wish'd. To die, to sleep;
To sleep, perchance to dream—ay, there's the rub:
For in that sleep of death what dreams may come,
When we have shuffled off this mortal coil,
Must give us pause—there's the respect
That makes calamity of so long life.
"""class TextDataset(Dataset):def __init__(self, text, seq_length=40):"""文本数据集类Args:text: 输入文本seq_length: 序列长度"""self.seq_length = seq_lengthself.text = self.preprocess_text(text)# 创建字符到索引的映射self.chars = sorted(list(set(self.text)))self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}self.vocab_size = len(self.chars)print(f"文本长度: {len(self.text)}")print(f"词汇表大小: {self.vocab_size}")print(f"示例字符: {self.chars[:20]}")# 准备训练数据self.prepare_data()def preprocess_text(self, text):"""预处理文本"""# 转换为小写并保留基本标点text = text.lower().strip()# 移除多余空格text = re.sub(r'\s+', ' ', text)return textdef prepare_data(self):"""准备输入输出序列"""self.inputs = []self.targets = []for i in range(len(self.text) - self.seq_length):input_seq = self.text[i:i + self.seq_length]target_seq = self.text[i + 1:i + self.seq_length + 1]self.inputs.append([self.char_to_idx[ch] for ch in input_seq])self.targets.append([self.char_to_idx[ch] for ch in target_seq])def __len__(self):return len(self.inputs)def __getitem__(self, idx):return (torch.tensor(self.inputs[idx], dtype=torch.long),torch.tensor(self.targets[idx], dtype=torch.long))

3. LSTM Model Definition

class LSTMGenerator(nn.Module):
    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256,
                 num_layers=2, dropout=0.2):
        """LSTM text generation model.

        Args:
            vocab_size: vocabulary size
            embedding_dim: embedding dimension
            hidden_dim: hidden state dimension
            num_layers: number of LSTM layers
            dropout: dropout rate
        """
        super(LSTMGenerator, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # LSTM layers
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers,
                            batch_first=True,
                            dropout=dropout if num_layers > 1 else 0)

        # Dropout layer
        self.dropout = nn.Dropout(dropout)

        # Output projection
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Forward pass.

        Args:
            x: input sequence [batch_size, seq_length]
            hidden: hidden state
        """
        batch_size = x.size(0)

        # Embedding
        embedded = self.embedding(x)  # [batch_size, seq_length, embedding_dim]
        embedded = self.dropout(embedded)

        # LSTM
        if hidden is None:
            hidden = self.init_hidden(batch_size, x.device)
        lstm_out, hidden = self.lstm(embedded, hidden)
        lstm_out = self.dropout(lstm_out)

        # Output layer
        output = self.fc(lstm_out)  # [batch_size, seq_length, vocab_size]
        return output, hidden

    def init_hidden(self, batch_size, device):
        """Initialize hidden and cell states."""
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        return (h0, c0)
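Before training, a quick shape check with a dummy batch can catch wiring mistakes early (a minimal sketch; the vocabulary size, batch size, and sequence length below are arbitrary assumptions):

# Dummy batch: 8 sequences of 40 character indices over a 50-symbol vocabulary
dummy_model = LSTMGenerator(vocab_size=50).to(device)
dummy_x = torch.randint(0, 50, (8, 40)).to(device)
logits, (h, c) = dummy_model(dummy_x)
print(logits.shape)  # torch.Size([8, 40, 50]) -- one distribution per position
print(h.shape)       # torch.Size([2, 8, 256]) -- num_layers x batch x hidden_dim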

4. Training Function

def train_model(model, dataset, epochs=100, batch_size=64, lr=0.001):
    """Train the model."""
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=0.5, patience=5)

    model.train()
    losses = []

    for epoch in range(epochs):
        epoch_loss = 0
        batch_count = 0

        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            optimizer.zero_grad()
            output, _ = model(inputs)

            # Compute loss
            loss = criterion(output.reshape(-1, model.vocab_size),
                             targets.reshape(-1))

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)  # gradient clipping
            optimizer.step()

            epoch_loss += loss.item()
            batch_count += 1

        avg_loss = epoch_loss / batch_count
        losses.append(avg_loss)
        scheduler.step(avg_loss)

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}')

    return losses

5. Text Generation Function

def generate_text(model, dataset, seed_text, length=200, temperature=1.0):
    """Generate text.

    Args:
        model: trained model
        dataset: dataset (for the character mappings)
        seed_text: seed text
        length: number of characters to generate
        temperature: temperature parameter (controls randomness)
    """
    model.eval()

    # Preprocess the seed text
    seed_text = seed_text.lower()

    # Convert to indices
    input_seq = [dataset.char_to_idx.get(ch, 0) for ch in seed_text]
    generated_text = seed_text

    with torch.no_grad():
        for _ in range(length):
            # Prepare the input window
            if len(input_seq) > dataset.seq_length:
                input_seq = input_seq[-dataset.seq_length:]
            x = torch.tensor([input_seq], dtype=torch.long).to(device)

            # Predict the next character
            output, _ = model(x)
            output = output[0, -1, :] / temperature

            # Sample
            probabilities = F.softmax(output, dim=0)
            next_idx = torch.multinomial(probabilities, 1).item()

            # Append to the sequence
            input_seq.append(next_idx)
            generated_text += dataset.idx_to_char[next_idx]

    return generated_text

6. Main Training Pipeline

def main():
    # Create the dataset
    dataset = TextDataset(sample_text, seq_length=40)

    # Create the model
    model = LSTMGenerator(
        vocab_size=dataset.vocab_size,
        embedding_dim=128,
        hidden_dim=256,
        num_layers=2,
        dropout=0.2
    ).to(device)

    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Train the model
    print("\nStarting training...")
    losses = train_model(model, dataset, epochs=100, batch_size=32, lr=0.001)

    # Text generation examples
    print("\nGenerated text samples:")
    print("-" * 50)

    # Generate with different temperature settings
    temperatures = [0.5, 0.8, 1.0, 1.2]
    seed_texts = ["to be", "the heart", "sleep"]

    for seed in seed_texts:
        print(f"\nSeed text: '{seed}'")
        for temp in temperatures:
            generated = generate_text(model, dataset, seed, length=100, temperature=temp)
            print(f"Temperature {temp}: {generated}")
            print()

    # Plot the loss curve
    import matplotlib.pyplot as plt
    plt.figure(figsize=(10, 5))
    plt.plot(losses)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.show()

    return model, dataset

if __name__ == "__main__":
    model, dataset = main()

7. Advanced Feature: Conditional Text Generation

class ConditionalLSTM(nn.Module):
    """Conditional LSTM generator (e.g., conditioned on sentiment or style)."""
    def __init__(self, vocab_size, num_conditions, embedding_dim=128,
                 hidden_dim=256, condition_dim=32):
        super(ConditionalLSTM, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.condition_embedding = nn.Embedding(num_conditions, condition_dim)

        # The LSTM input combines the text embedding and the condition embedding
        self.lstm = nn.LSTM(
            embedding_dim + condition_dim,
            hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0.2
        )

        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, condition):
        # Look up embeddings
        text_embedded = self.embedding(x)
        cond_embedded = self.condition_embedding(condition)

        # Expand the condition embedding to match the sequence length
        cond_embedded = cond_embedded.unsqueeze(1).expand(-1, text_embedded.size(1), -1)

        # Concatenate the embeddings
        combined = torch.cat([text_embedded, cond_embedded], dim=-1)

        # LSTM and output layer
        lstm_out, _ = self.lstm(combined)
        output = self.fc(lstm_out)

        return output
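A minimal usage sketch for the conditional model (the two-style setup, tensor sizes, and label meanings below are illustrative assumptions, not from the original):

# Hypothetical conditions: 0 = "formal", 1 = "casual"
cond_model = ConditionalLSTM(vocab_size=50, num_conditions=2).to(device)
x = torch.randint(0, 50, (4, 40)).to(device)       # 4 sequences of 40 character indices
condition = torch.tensor([0, 0, 1, 1]).to(device)  # one style id per sequence
logits = cond_model(x, condition)
print(logits.shape)  # torch.Size([4, 40, 50])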

8. Evaluation and Visualization

def evaluate_model(model, dataset, num_samples=5):
    """Evaluate generation quality."""
    model.eval()

    # Compute perplexity
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
    criterion = nn.CrossEntropyLoss()

    total_loss = 0
    total_count = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            output, _ = model(inputs)
            loss = criterion(output.reshape(-1, model.vocab_size),
                             targets.reshape(-1))
            total_loss += loss.item() * inputs.size(0)
            total_count += inputs.size(0)

    perplexity = np.exp(total_loss / total_count)
    print(f"Perplexity: {perplexity:.2f}")

    # Assess the diversity of generated samples
    generated_samples = []
    for _ in range(num_samples):
        seed = random.choice(["to ", "the ", "and "])
        text = generate_text(model, dataset, seed, length=100, temperature=0.8)
        generated_samples.append(text)

    # Count unique n-grams
    def get_ngrams(text, n):
        return set([text[i:i+n] for i in range(len(text)-n+1)])

    all_bigrams = set()
    all_trigrams = set()
    for text in generated_samples:
        all_bigrams.update(get_ngrams(text, 2))
        all_trigrams.update(get_ngrams(text, 3))

    print(f"Unique 2-grams: {len(all_bigrams)}")
    print(f"Unique 3-grams: {len(all_trigrams)}")

    return perplexity, generated_samples

Usage Notes

  1. Data preparation: the code uses a simplified Shakespeare excerpt, which can be replaced with:

    • WikiText-2/WikiText-103
    • Penn Treebank
    • any plain text file
  2. Model configuration

    • Adjust embedding_dim and hidden_dim to control model capacity
    • Increase num_layers for a deeper model
    • Adjust temperature to control generation randomness
  3. Training techniques

    • Use gradient clipping to prevent exploding gradients
    • Use a learning-rate scheduler to adapt the learning rate
    • Apply moderate dropout to prevent overfitting
  4. Generation strategies (see the Top-k sketch after this list)

    • Temperature sampling: controls how peaked the output distribution is
    • Top-k sampling: sample only from the k most probable characters
    • Beam search: generate several candidate sequences and keep the best
  5. Use pretrained models: e.g., GPT-2, BERT

  6. Add an attention mechanism: improves long-sequence modeling

  7. Implement a GAN architecture: adversarial training to improve generation quality

  8. Multi-task learning: train multiple text generation tasks jointly
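As a concrete sketch of the Top-k strategy from item 4 (the sample_top_k helper is introduced here for illustration and is not part of the original code), sampling can be restricted to the k most probable characters:

def sample_top_k(logits, k=10, temperature=0.8):
    """Sample the next character index from the k highest-scoring logits."""
    logits = logits / temperature
    top_values, top_indices = torch.topk(logits, k)      # keep only the k best scores
    probabilities = F.softmax(top_values, dim=-1)        # renormalize over those k
    choice = torch.multinomial(probabilities, 1).item()  # position within the top-k
    return top_indices[choice].item()                    # map back to a vocabulary index

In generate_text, replacing the temperature-scaling, softmax, and multinomial lines with next_idx = sample_top_k(output[0, -1, :], k=10, temperature=temperature) (applied to the raw logits) would switch the generator to this strategy.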

Core Features

Character-level LSTM model: works with arbitrary text data
Temperature-controlled sampling: tunes the randomness of the generated text
Conditional generation: extensible to condition on attributes (sentiment, style)
Complete training pipeline: includes optimizer and learning-rate scheduler

Technical Highlights

Gradient clipping: prevents exploding gradients
Dropout regularization: prevents overfitting
Perplexity evaluation: quantifies generation quality (see the formula after this list)
Diversity analysis: n-gram statistics
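For reference, the perplexity reported by evaluate_model is the exponentiated mean per-character cross-entropy; lower is better, and a model that guessed uniformly over a vocabulary of V characters would score PPL = V:

\mathrm{PPL} = \exp\!\left(\frac{1}{N}\sum_{i=1}^{N} -\log p(x_i \mid x_{<i})\right)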

Extensibility

Supports swapping in public datasets such as WikiText-2 or Penn Treebank
Can integrate attention mechanisms or Transformer architectures
Supports advanced generation strategies such as beam search and Top-k sampling
