人工智能学习:基于seq2seq模型架构实现翻译
一、数据集介绍
-
data/eng-fra-v2.txt 是我们案例使用的数据集
-
左半部分是英文,右半部分是法文
i am from brazil . je viens du bresil . i am from france . je viens de france . i am from russia . je viens de russie . i am frying fish . je fais frire du poisson . i am not kidding . je ne blague pas . i am on duty now . maintenant je suis en service . i am on duty now . je suis actuellement en service . i am only joking . je ne fais que blaguer . i am out of time . je suis a court de temps . i am out of work . je suis au chomage . i am out of work . je suis sans travail . i am paid weekly . je suis payee a la semaine . i am pretty sure . je suis relativement sur . i am truly sorry . je suis vraiment desole . i am truly sorry . je suis vraiment desolee .
二、案例实现步骤
基于GRU的seq2seq模型架构实现翻译的过程:
- 第一步: 导入工具包和工具函数
- 第二步: 对持久化文件中的数据进行处理, 以满足模型训练要求
- 第三步: 构建基于GRU的编码器和解码器
- 第四步: 构建模型训练函数, 并进行训练
- 第五步: 构建模型评估函数, 并进行测试以及Attention效果分析
1、导入工具包和工具函数
# 用于正则表达式
import re# 用于构建网络结构和函数的torch工具包
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader# torch中预定义的优化方法工具包
import torch.optim as optim
import time# 用于随机生成数据
import random
import numpy as np
import matplotlib.pyplot as plt# 设备选择, 我们可以选择在cuda或者cpu上运行你的代码
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 起始标志 SOS->Start Of Sequence
SOS_token = 0
# 结束标志 EOS->End Of Sequence
EOS_token = 1
# 最大句子长度不能超过10个(包含标点),用于设置每个句子样本的中间语义张量c长度都为10。
MAX_LENGTH = 10
# 数据文件路径
data_path = "./data/eng-fra-v2.txt"# 文本清洗工具函数
def normalizeString(s: str):"""字符串规范化函数, 参数s代表传入的字符串"""s = s.lower().strip()# 在.!?前加一个空格, 即用 “空格 + 原标点” 替换原标点。# \1 代表 捕获的标点符号,即 ., !, ? 之一。s = re.sub(r"([.!?])", r" \1", s)# 用一个空格替换原标点,意味着 标点符号被完全去掉,只留下空格。# s = re.sub(r"([.!?])", r" ", s)# 使用正则表达式将字符串中 不是 至少1个小写字母和正常标点的都替换成空格s = re.sub(r"[^a-z.!?]+", r" ", s)return s
2、数据预处理
对持久化文件中数据进行处理, 以满足模型训练要求
1、清洗文本和构建文本字典
-
清洗文本和构建文本字典思路分析
Python# my_getdata() 清洗文本构建字典思路分析 # 1 按行读文件 open().read().strip().split(\n) my_lines # 2 按行清洗文本 构建语言对 my_pairs[] tmppair[] # 2-1格式 [['英文', '法文'], ['英文', '法文'], ['英文', '法文'], ['英文', '法文']....] # 2-2调用清洗文本工具函数normalizeString(s) # 3 遍历语言对 构建英语单词字典 法语单词字典 my_pairs->pair->pair[0].split(' ') pair[1].split(' ')->word # 3-1 english_word2index english_word_n french_word2index french_word_n # 其中 english_word2index = {0: "SOS", 1: "EOS"} english_word_n=2 # 3-2 english_index2word french_index2word # 4 返回数据的7个结果 # english_word2index, english_index2word, english_word_n, # french_word2index, french_index2word, french_word_n, my_pairs
-
代码实现
Python
def my_getdata():
# 1 按行读文件 open().read().strip().split(\n)
with open(data_path, "r", encoding="utf-8") as f:
my_lines = f.read().strip().split("\n")
print("my_lines--->", len(my_lines))
# 2 按行清洗文本 构建语言对 my_pairs
# 格式 [['英文句子', '法文句子'], ['英文句子', '法文句子'], ['英文句子', '法文句子'], ... ]
tmp_pair, my_pairs = [], []
for l in my_lines:
for s in l.split("\t"):
tmp_pair.append(normalizeString(s))
my_pairs.append(tmp_pair)
# 清空tmp_pair, 存储下一个句子的英语和法语的句子对
tmp_pair = []
# my_pairs = [[normalizeString(s) for s in l.split('\t')] for l in my_lines]
# print('my_pairs--->', my_pairs)
print("len(my_pairs)--->", len(my_pairs))
# 打印前4条数据
print(my_pairs[:4])
# 打印第8000条的英文 法文数据
print("my_pairs[8000][0]--->", my_pairs[8000][0])
print("my_pairs[8000][1]--->", my_pairs[8000][1])
# 3 遍历语言对 构建英语单词字典 法语单词字典
# 3-1 english_word2index english_word_n french_word2index french_word_n
# SOS->Start Of Sequence
# EOS->End Of Sequence
english_word2index = {"SOS": 0, "EOS": 1}
# 第三个单词的下标值从2开始
english_word_n = 2
french_word2index = {"SOS": 0, "EOS": 1}
french_word_n = 2
# 遍历语言对 获取英语单词字典 法语单词字典
# {单词1:下标1, 单词2:下标2, ...}
for pair in my_pairs:
for word in pair[0].split(" "):
if word not in english_word2index:
english_word2index[word] = english_word_n
# 更新下一个单词的下标值
english_word_n += 1
for word in pair[1].split(" "):
if word not in french_word2index:
french_word2index[word] = french_word_n
french_word_n += 1
# 3-2 english_index2word french_index2word
# # {下标1:单词1, 下标2:单词2, ...}
english_index2word = {v: k for k, v in english_word2index.items()}
french_index2word = {v: k for k, v in french_word2index.items()}
print("len(english_word2index)-->", len(english_word2index))
print("len(french_word2index)-->", len(french_word2index))
print("english_word_n--->", english_word_n, "french_word_n-->", french_word_n)
return (
english_word2index,
english_index2word,
english_word_n,
french_word2index,
french_index2word,
french_word_n,
my_pairs,
)
if __name__ == "__main__":
# 获取英语单词字典 法语单词字典 语言对列表my_pairs
(
english_word2index,
english_index2word,
english_word_n,
french_word2index,
french_index2word,
french_word_n,
my_pairs,
) = my_getdata()
输出结果:
Python
my_lines---> 63594
len(pairs)---> 63594
[['i m .', 'j ai ans .'], ['i m ok .', 'je vais bien .'], ['i m ok .', 'ca va .'], ['i m fat .', 'je suis gras .']]
my_pairs[8000][0]---> they re in the science lab .
my_pairs[8000][1]---> elles sont dans le laboratoire de sciences .
len(english_word2index)--> 2803
len(french_word2index)--> 4345
english_word_n---> 2803 french_word_n--> 4345
2、构建数据源对象
Python
# 原始数据 -> 数据源MyPairsDataset --> 数据迭代器DataLoader
# 构造数据源 MyPairsDataset,把语料xy 文本数值化 再转成tensor_x tensor_y
# 1 __init__(self, my_pairs)函数 设置self.my_pairs 条目数self.sample_len
# 2 __len__(self)函数 获取样本条数
# 3 __getitem__(self, index)函数 获取第几条样本数据
# 按索引 获取数据样本 x y
# 样本x 文本数值化 word2id x.append(EOS_token)
# 样本y 文本数值化 word2id y.append(EOS_token)
# 返回tensor_x, tensor_y
class MyPairsDataset(Dataset):
def __init__(self, my_pairs, english_word2index, french_word2index):
# 样本x
self.my_pairs = my_pairs
self.english_word2index = english_word2index
self.french_word2index = french_word2index
# 样本条目数
self.sample_len = len(my_pairs)
# 获取样本条数
def __len__(self):
return self.sample_len
# 获取第几条 样本数据
def __getitem__(self, index):
# 对index异常值进行修正 [0, self.sample_len-1]
index = min(max(index, 0), self.sample_len - 1)
# 按索引获取 数据样本 x y
x = self.my_pairs[index][0] # 英文句子
y = self.my_pairs[index][1] # 法文句子
# 样本x 文本数值化
x = [self.english_word2index[word] for word in x.split(" ")]
x.append(EOS_token)
tensor_x = torch.tensor(x, dtype=torch.long, device=device)
# print('tensor_x.shape===>', tensor_x.shape, tensor_x)
# 样本y 文本数值化
y = [self.french_word2index[word] for word in y.split(" ")]
y.append(EOS_token)
tensor_y = torch.tensor(y, dtype=torch.long, device=device)
# 注意 tensor_x tensor_y都是一维数组,通过DataLoader拿出的数据是二维数据
# print('tensor_y.shape===>', tensor_y.shape, tensor_y)
# 返回结果
return tensor_x, tensor_y
3、构建数据加载器
Python
def dm_test_MyPairsDataset():
# 1 调用my_getdata函数获取数据
(
english_word2index,
english_index2word,
english_word_n,
french_word2index,
french_index2word,
french_word_n,
my_pairs,
) = my_getdata()
# 2 实例化dataset对象
mypairsdataset = MyPairsDataset(my_pairs, english_word2index, french_word2index)
# 3 实例化dataloader
mydataloader = DataLoader(dataset=mypairsdataset, batch_size=1, shuffle=True)
for i, (x, y) in enumerate(mydataloader):