LLM Fine-Tuning Example 4: Llama-Factory DPO

  • 1. Preparing the Reinforcement Learning Data
  • 2. Configuring the Training Files
  • 3. Model Prediction

1. Preparing the Reinforcement Learning Data

Raw data source: https://nijianmo.github.io/amazon/index.html

Step 1: Load the video game metadata

import codecs
import json
import re
from random import shuffle

# Step 1: load the video game metadata.
# Key: product ID (asin); value: product title.
games = {}
with codecs.open('./data/src_data/meta_Video_Games.json', mode='r', encoding='utf-8') as fin:
    for line in fin:
        tmp_info = json.loads(line.strip())
        # asin - ID of the product
        # title - name of the product
        games[tmp_info["asin"]] = tmp_info["title"]
        if len(games) % 10000 == 0:
            print(f'Length of games: {len(games)}')

Step 2: Load the user ratings

# Key: reviewer ID; value: list of (title, rating, review date) tuples.
user_reviews = {}
with codecs.open('./data/src_data/Video_Games_5.json', mode='r', encoding='utf-8') as fin:
    for line in fin:
        tmp_info = json.loads(line.strip())
        # reviewerID - ID of the reviewer
        reviewer_id = tmp_info["reviewerID"]
        # reviewTime is split into month, day, year and rebuilt as YYYY-MM-DD.
        # Month and day are zero-padded so that string order matches date order.
        time_info = re.split(', | ', tmp_info["reviewTime"])
        review_time = time_info[2] + '-' + time_info[0].zfill(2) + '-' + time_info[1].zfill(2)
        # asin - ID of the product
        product_id = tmp_info["asin"]
        # overall - rating of the product
        rating = tmp_info["overall"]
        # Keep only reviews whose product has metadata from step 1.
        if product_id in games:
            product_title = games[product_id]
            user_reviews.setdefault(reviewer_id, []).append((product_title, rating, review_time))
        if len(user_reviews) % 10000 == 0:
            print(f'Length of user_reviews: {len(user_reviews)}')

user_reviews_sorted = {}
for k, v in user_reviews.items():
    # Remove duplicate reviews first.
    v = list(set(v))
    # Sort by review date, oldest first, to form the user's review history.
    v_sorted = sorted(v, key=lambda x: x[2])
    # Keep only users with 7 or more reviews.
    if len(v) >= 7:
        user_reviews_sorted[k] = v_sorted
print(f'Length of user_reviews_sorted: {len(user_reviews_sorted)}')
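
Step 2 uses the rebuilt date string as a sort key, and string order only matches chronological order when month and day are zero-padded. The sketch below illustrates this, assuming `reviewTime` values look like `"10 8, 2014"` (month, day, year); `normalize_review_time` is a hypothetical helper, not part of the original script.

```python
import re

def normalize_review_time(review_time: str) -> str:
    """Turn 'MM D, YYYY' into a zero-padded 'YYYY-MM-DD' string."""
    month, day, year = re.split(', | ', review_time)
    return f"{year}-{month.zfill(2)}-{day.zfill(2)}"

# Assumed sample values; real records may differ.
raw_dates = ["10 13, 2014", "10 8, 2014", "2 1, 2015"]
normalized = sorted(normalize_review_time(d) for d in raw_dates)
print(normalized)

# Without padding, '2014-10-13' sorts before '2014-10-8' as a string,
# even though October 8 comes first on the calendar.
unpadded_is_misordered = "2014-10-13" < "2014-10-8"
print(unpadded_is_misordered)
```

Parsing into `datetime.date` objects would work as well; zero-padding just keeps the tuples as plain strings, as in the script above.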

Step 3: Generate the training data

# All samples
samples = []
# Instruction
instruction = "You are an assistant working on Video Games recommendations. Given the user's history of Video Games they have shopped, which includes the \"Title\" of the Video Games and the \"Rating\" the user rate (the Rating value is like or dislike), please decide whether the user likes to shop the target Video Games by outputting the order of their titles."
for k, v in user_reviews_sorted.items():
    sample_input = "User shopped Video Games histories (Title and Rating): \n"
    # Everything except the last two reviews serves as the history.
    for vv in v[0: -2]:
        # A rating above 3.0 counts as like; 3.0 or below counts as dislike.
        rating = 'like' if vv[1] > 3.0 else 'dislike'
        sample_input += "<Title: {}, Rating: {}>\n".format(vv[0], rating)
    sample_input += "Based on the Video Games histories, please sort the following two Video Games titles. The one in the front is what the user like and should be recommended to user: \n"
    # The last two reviews are the targets to predict.
    sample_input += "<Title: " + v[-2][0] + '>\n'
    sample_input += "<Title: " + v[-1][0] + '>\n'
    # Keep the sample only if exactly one target is a like and the other a dislike.
    if (v[-1][1] > 3.0 and v[-2][1] <= 3.0) or (v[-1][1] <= 3.0 and v[-2][1] > 3.0):
        if v[-1][1] > v[-2][1]:
            option1 = v[-1][0]  # like
            option2 = v[-2][0]  # dislike
        else:
            option1 = v[-2][0]  # like
            option2 = v[-1][0]  # dislike
        # chosen puts the liked title first.
        chosen = "<Title: " + option1 + '>\n' + "<Title: " + option2 + '>'
        # rejected puts the disliked title first.
        rejected = "<Title: " + option2 + '>\n' + "<Title: " + option1 + '>'
        sample = {"instruction": instruction, "input": sample_input, "chosen": chosen, "rejected": rejected}
        samples.append(sample)
        if len(samples) % 10000 == 0:
            print(f'Length of samples: {len(samples)}')
print(f'Length of samples: {len(samples)}')
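
To make the chosen/rejected construction easier to follow, here is the pairing rule isolated on a toy history. This is a sketch only: `build_preference_pair` and the sample titles are made up for illustration, while the 3.0 threshold and the output format follow the script above.

```python
def build_preference_pair(history, like_threshold=3.0):
    """Return (chosen, rejected) for the last two reviews in `history`,
    or None when they are not exactly one like and one dislike."""
    (title_a, rating_a, _), (title_b, rating_b, _) = history[-2], history[-1]
    a_liked = rating_a > like_threshold
    b_liked = rating_b > like_threshold
    if a_liked == b_liked:
        return None  # both liked or both disliked: no preference signal
    liked, disliked = (title_a, title_b) if a_liked else (title_b, title_a)
    chosen = f"<Title: {liked}>\n<Title: {disliked}>"
    rejected = f"<Title: {disliked}>\n<Title: {liked}>"
    return chosen, rejected

# Toy (title, rating, date) tuples, sorted by date as in step 2.
toy_history = [
    ("Game A", 5.0, "2014-01-02"),
    ("Game B", 2.0, "2014-03-15"),
    ("Game C", 4.0, "2014-06-30"),
]
pair = build_preference_pair(toy_history)
print(pair)
```

Users whose last two reviews are both likes or both dislikes produce no sample, so the final sample count is smaller than the number of users kept in step 2.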

Step 4: Split into train and test sets and save the samples

# Shuffle first
shuffle(samples)
split = int(len(samples) * 0.8)
train = samples[:split]
test = samples[split:]
print(f'Total samples: {len(samples)}, train samples: {len(train)}, test samples: {len(test)}')
with open("./data/processed/rlhf_train.json", "w", encoding='utf-8') as save_file:
    json.dump(train, save_file, indent=4)
with open("./data/processed/rlhf_test.json", "w", encoding='utf-8') as save_file:
    json.dump(test, save_file, indent=4)
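
The training config refers to the dataset by the name `amazon_video_games`, so LLaMA-Factory needs an entry under that name in its `dataset_info.json`. The sketch below builds such an entry, assuming the preference-dataset layout described in LLaMA-Factory's data README (`ranking: true` plus a `columns` mapping). The `file_name` value and the location of the registry file are assumptions that depend on your setup, so check them against the version you have installed.

```python
import json

def make_dataset_info_entry(file_name):
    """Build a dataset_info.json entry for an alpaca-style preference dataset.

    The column names on the right match the keys written in step 3.
    """
    return {
        "file_name": file_name,
        "ranking": True,  # marks the dataset as pairwise preference data
        "columns": {
            "prompt": "instruction",
            "query": "input",
            "chosen": "chosen",
            "rejected": "rejected",
        },
    }

# "rlhf_train.json" is assumed to have been copied into LLaMA-Factory's data directory.
entry = {"amazon_video_games": make_dataset_info_entry("rlhf_train.json")}
print(json.dumps(entry, indent=2))
```

The printed object would then be merged into the existing `dataset_info.json` rather than replacing it.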

2. Configuring the Training Files

rlhf_train.yaml

### model
model_name_or_path: /ZhipuAI/glm-4-9b-chat

### method
stage: dpo
do_train: true
finetuning_type: lora
lora_target: all
lora_rank: 16
lora_alpha: 32
pref_beta: 0.1
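pref_loss: orpo  # ORPO loss; LLaMA-Factory's example DPO config lists sigmoid (standard DPO), orpo, and simpo as choices

### dataset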
dataset: amazon_video_games
template: glm4
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: ./saves/amazon_video_games_orpo
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-6
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
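
With `max_samples: 1000` and `val_size: 0.1`, it is worth checking how many optimizer steps the run takes compared with `save_steps` and `eval_steps`. The rough estimate below assumes a single GPU and that each epoch takes ceil(train samples / effective batch size) steps; the exact rounding may differ between trainer versions. `estimate_total_steps` is a hypothetical helper, not a LLaMA-Factory function.

```python
import math

def estimate_total_steps(max_samples, val_size, per_device_batch,
                         grad_accum, epochs, num_devices=1):
    """Approximate the number of optimizer steps for the whole run."""
    train_samples = max_samples - round(max_samples * val_size)
    effective_batch = per_device_batch * grad_accum * num_devices
    steps_per_epoch = math.ceil(train_samples / effective_batch)
    return int(steps_per_epoch * epochs)

# Values taken from rlhf_train.yaml; num_devices=1 is an assumption.
total = estimate_total_steps(max_samples=1000, val_size=0.1,
                             per_device_batch=1, grad_accum=8, epochs=3.0)
print(f"estimated steps: {total}, reaches step 500: {total >= 500}")
```

Under these assumptions the estimate stays below 500, which would mean no intermediate checkpoint or step-based evaluation is triggered before training ends. Lowering `save_steps` and `eval_steps`, or raising `max_samples`, would change that.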

rlhf_inference.yaml

model_name_or_path: /ZhipuAI/glm-4-9b-chat
adapter_name_or_path: ./saves/amazon_video_games_orpo
template: glm4
finetuning_type: lora

3. Model Prediction

import json
from openai import OpenAI
from tqdm import tqdm

# Connect to the model server
client = OpenAI(
    api_key="EMPTY",
    # Change this to the address of your model server
    base_url="http://10.114.16.65:8000/v1/",
)

# Load the test data
test_file_path = "./data/processed/rlhf_test.json"
with open(test_file_path, "r", encoding='utf-8') as test_file:
    test_data = json.load(test_file)
print(len(test_data))


def report_accuracy(labels, predictions):
    """Print exact-match accuracy of predictions against labels."""
    correct = 0
    wrong = 0
    for l, p in zip(labels, predictions):
        if l.strip() == p.strip():
            correct += 1
        else:
            wrong += 1
    print(f'Total samples: {len(labels)}, correct: {correct}, wrong: {wrong}, accuracy: {correct / len(labels)}')


# Run prediction
labels = []
predictions = []
for each_test in tqdm(test_data):
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": each_test["instruction"]},
            {"role": "user", "content": each_test["input"]},
        ],
        model="glm4",
    )
    predictions.append(chat_completion.choices[0].message.content)
    labels.append(each_test["chosen"])
    # Report running accuracy every 100 samples.
    if len(labels) % 100 == 0:
        report_accuracy(labels, predictions)

assert len(predictions) == len(labels)
report_accuracy(labels, predictions)
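
The evaluation above counts a prediction as correct only when it matches the label string exactly, so any extra text in the model output counts as an error. A more lenient check, sketched here as an alternative and not as the original method, compares only the order of the `<Title: ...>` entries. `extract_titles` and `is_correct` are hypothetical helpers, and the regex assumes titles contain no `>` character.

```python
import re

# Captures the text between "<Title:" and the next ">".
TITLE_PATTERN = re.compile(r"<Title:\s*(.*?)>")

def extract_titles(text):
    """Return the titles found in `text`, in order of appearance."""
    return [title.strip() for title in TITLE_PATTERN.findall(text)]

def is_correct(label, prediction):
    """True when the prediction lists the same titles in the same order."""
    label_titles = extract_titles(label)
    return bool(label_titles) and label_titles == extract_titles(prediction)

label = "<Title: Game C>\n<Title: Game B>"
chatty_prediction = "Here is the order:\n<Title: Game C>\n<Title: Game B>"
reversed_prediction = "<Title: Game B>\n<Title: Game C>"

print(is_correct(label, chatty_prediction))    # same order despite extra text
print(is_correct(label, reversed_prediction))  # wrong order
```

Accuracy computed this way is not directly comparable with exact-match accuracy, so it is best reported alongside it.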