
The Road to Large Models (Day 1)

I have been fully immersed in studying large language models lately. I am still a beginner, but I have come a long way compared to before. Learning computer science really does demand strong self-motivation, patience, and good search skills. For sharing this kind of knowledge, Zhihu honestly does a much better job than CSDN.

Getting started is the hardest part; what matters is persistence.

This post works through Karpathy's minbpe project, annotating the code with line-by-line comments. It trains fine on a laptop, no rented server needed.

 base.py

import unicodedata


def get_stats(ids, counts=None):
    # ids is a list of integers; counts is a dict that is updated in place if given
    # given a sequence, return how often each adjacent pair occurs
    # Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):  # zip pairs each element with the next one, producing the adjacent pairs
        counts[pair] = counts.get(pair, 0) + 1  # get returns the value for the key, or the default 0 if missing
    return counts


def merge(ids, pair, idx):
    # replace every consecutive occurrence of pair in ids with the new integer idx
    newids = []
    i = 0
    while i < len(ids):
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


def replace_control_characters(s: str) -> str:
    # replace control characters in the string with their escaped form:
    # they are invisible themselves but mess up how the text is displayed
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != "C":  # Unicode category of the character, e.g. Lu = uppercase letter, Cc = control character
            chars.append(ch)
        else:
            chars.append(f"\\u{ord(ch):04x}")  # ord gives the code point, rendered as an escape sequence
    return "".join(chars)


# e.g. 'h' in UTF-8 is 0x68, i.e. 104
def render_token(t: bytes) -> str:
    s = t.decode('utf-8', errors='replace')  # undecodable bytes become the replacement character
    s = replace_control_characters(s)
    return s


# the base Tokenizer class
class Tokenizer:
    # base class for tokenizers

    def __init__(self):
        self.merges = {}  # (int, int) -> int, which token pairs merge into which new token
        self.pattern = ""  # regex pattern used to split text
        self.special_tokens = {}  # special tokens
        self.vocab = self._build_vocab()  # build the vocabulary

    def train(self, text, vocab_size, verbose=False):
        raise NotImplementedError

    def encode(self, text):
        raise NotImplementedError

    def decode(self, ids):
        raise NotImplementedError

    def _build_vocab(self):
        vocab = {idx: bytes([idx]) for idx in range(256)}  # initialize all 256 single-byte tokens
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")  # special tokens are stored as UTF-8 bytes
        return vocab

    def save(self, file_prefix):
        model_file = file_prefix + ".model"
        with open(model_file, 'w') as f:
            f.write("minbpe v1\n")  # version
            f.write(f"{self.pattern}\n")  # split pattern
            f.write(f"{len(self.special_tokens)}\n")  # number of special tokens
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")  # each special token
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")  # merge pairs
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                s = render_token(token)  # turn the bytes into a readable string
                if idx in inverted_merges:
                    # if it is a merged token, show where the merge came from
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # otherwise it is a raw token (one of the first 256)
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        # inverse of save(), only for the model file
        # load the model parameters
        assert model_file.endswith(".model")
        merges = {}  # merge rules
        special_tokens = {}  # special tokens
        idx = 256  # new tokens start numbering from 256
        with open(model_file, 'r', encoding="utf-8") as f:
            version = f.readline().strip()
            assert version == "minbpe v1"
            self.pattern = f.readline().strip()  # pattern
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()
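
To make get_stats and merge concrete, here is a small sketch (not part of the original post) of a single BPE merge step using the two helpers above:

from base import get_stats, merge

ids = [1, 2, 3, 1, 2]
stats = get_stats(ids)                # {(1, 2): 2, (2, 3): 1, (3, 1): 1}
top_pair = max(stats, key=stats.get)  # (1, 2), the most frequent adjacent pair
new_ids = merge(ids, top_pair, 256)   # replace every (1, 2) with the new token 256
print(top_pair, new_ids)              # (1, 2) [256, 3, 256]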

regex_tokenizer.py

import regex as re
from base import Tokenizer, get_stats, merge  # prefix the module with "." to import it relative to the current directory

GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""


class RegexTokenizer(Tokenizer):

    def __init__(self, pattern=None):
        super().__init__()
        self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
        self.compiled_pattern = re.compile(self.pattern)  # compile the regex once to speed up later use
        self.special_tokens = {}
        self.inverse_special_tokens = {}  # inverse lookup table, used when decoding to find the string for a special token id

    def train(self, text, vocab_size, verbose=False):
        # take raw text and learn a vocabulary of vocab_size tokens with the BPE algorithm
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        text_chunks = re.findall(self.compiled_pattern, text)
        ids = [list(ch.encode("utf-8")) for ch in text_chunks]  # UTF-8 encode each chunk;
        # list() turns the bytes into a list of integers
        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}
        # vocab[65] = b'A'
        # vocab[200] = b'\xc8'
        for i in range(num_merges):
            stats = {}
            for chunk_ids in ids:
                get_stats(chunk_ids, stats)  # see base.py; chunk_ids is a list of integers
            pair = max(stats, key=stats.get)  # most frequent pair; key tells max to compare the counts,
            # i.e. pick the key whose value is largest
            idx = 256 + i
            ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]  # see base.py
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]  # everything builds on bytes 0-255, so the vocab stores bytes
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
        self.merges = merges
        self.vocab = vocab

    def register_special_tokens(self, special_tokens):
        # special_tokens is a dictionary of str -> int
        # example: {"<|endoftext|>": 100257}
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}

    def decode(self, ids):
        # convert token ids back into a string
        part_bytes = []
        for idx in ids:
            if idx in self.vocab:
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
            else:
                raise ValueError(f"invalid token id: {idx}")
        text_bytes = b"".join(part_bytes)  # concatenate all the bytes
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def _encode_chunk(self, text_bytes):
        # BPE-encode one UTF-8 byte chunk and return the list of token ids,
        # using the merge rules learned during training
        ids = list(text_bytes)
        while len(ids) >= 2:
            stats = get_stats(ids)  # count adjacent pairs
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # among the current pairs, pick the one with the lowest merge index;
            # pairs not in merges get inf
            if pair not in self.merges:
                break
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids

    def encode_ordinary(self, text):
        # encode ordinary text only, ignoring special tokens
        text_chunks = re.findall(self.compiled_pattern, text)
        ids = []
        for chunk in text_chunks:
            chunk_bytes = chunk.encode("utf-8")
            chunk_ids = self._encode_chunk(chunk_bytes)
            ids.extend(chunk_ids)  # unlike append, extend adds the elements of the list
        return ids

    def encode(self, text, allowed_special="none_raise"):
        special = None
        if allowed_special == "all":
            special = self.special_tokens
        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            assert all(token not in text for token in self.special_tokens)  # make sure no special token appears in the text
        elif isinstance(allowed_special, set):
            special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}  # user-specified subset of special tokens
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood")
        if not special:
            # no special tokens to handle: fall back to the ordinary encoding
            # (otherwise the empty capture group below would split on every position)
            return self.encode_ordinary(text)
        special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
        # re.escape escapes regex metacharacters; the special tokens are joined with | and wrapped in a capturing group
        special_chunks = re.split(special_pattern, text)
        ids = []
        for part in special_chunks:
            if part in special:
                ids.append(special[part])
            else:
                ids.extend(self.encode_ordinary(part))
        return ids
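
As a quick sanity check, here is a usage sketch (not from the original post) that trains a tiny RegexTokenizer on a made-up string, registers a special token, and round-trips a piece of text; the vocab_size of 270 and the <|endoftext|> id of 270 are arbitrary choices for illustration:

from regex_tokenizer import RegexTokenizer

tokenizer = RegexTokenizer()
# tiny made-up corpus, just enough repetition for the 14 merges (270 - 256)
tokenizer.train("minbpe is a minimal byte pair encoding tokenizer. " * 20, vocab_size=270)

tokenizer.register_special_tokens({"<|endoftext|>": 270})
ids = tokenizer.encode("a minimal tokenizer<|endoftext|>", allowed_special="all")
print(ids)
print(tokenizer.decode(ids))  # should print the input string back unchanged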

train.py

import os
import time

from regex_tokenizer import RegexTokenizer

# open the text file and read the training data
text = open("dataset/taylorswift.txt", "r", encoding="utf-8").read()

# create the output directory for models
os.makedirs("models", exist_ok=True)

# start timing
t0 = time.time()

# train only the RegexTokenizer
tokenizer = RegexTokenizer()
tokenizer.train(text, 512, verbose=True)

# save the model
prefix = os.path.join("models", "regex")
tokenizer.save(prefix)

# stop timing
t1 = time.time()
print(f"Training took {t1 - t0:.2f} seconds")

