当前位置: 首页 > backend >正文

python使用transformer库推理

代码

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer# 1. load model
model_path = "/ssd3/models/Qwen2.5-0.5B-Instruct/"model = AutoModelForCausalLM.from_pretrained(model_path,device_map='cuda',torch_dtype=torch.float16,
)# 2. init tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Need to set the padding token to the eos token for generation
tokenizer.pad_token = tokenizer.eos_tokenprompts = ["你是谁",
]for prompt in prompts:messages = [{"role": "user", "content": prompt},]batch = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)# 3. tokenizemodel_inputs = tokenizer([batch], return_tensors="pt").to('cuda')# model_inputs = tokenizer([prompt], padding=True, truncation=True, return_tensors="pt").to('cuda')# 4. infergenerated_ids = model.generate(**model_inputs, max_new_tokens=16)generated_ids = [output_ids[len(input_ids) :]for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]# 5. detokenizeresponse = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)print(response)

debug信息

# debug model: 
Qwen2ForCausalLM((model): Qwen2Model((embed_tokens): Embedding(151936, 896)(layers): ModuleList((0-23): 24 x Qwen2DecoderLayer((self_attn): Qwen2Attention((q_proj): Linear(in_features=896, out_features=896, bias=True)(k_proj): Linear(in_features=896, out_features=128, bias=True)(v_proj): Linear(in_features=896, out_features=128, bias=True)(o_proj): Linear(in_features=896, out_features=896, bias=False))(mlp): Qwen2MLP((gate_proj): Linear(in_features=896, out_features=4864, bias=False)(up_proj): Linear(in_features=896, out_features=4864, bias=False)(down_proj): Linear(in_features=4864, out_features=896, bias=False)(act_fn): SiLU())(input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)(post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)))(norm): Qwen2RMSNorm((896,), eps=1e-06)(rotary_emb): Qwen2RotaryEmbedding())(lm_head): Linear(in_features=896, out_features=151936, bias=False)
)# debug batch: 
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
你是谁<|im_end|>
<|im_start|>assistant# debug model_inputs: 
{'input_ids': tensor([[151644,   8948,    198,   2610,    525,   1207,  16948,     11,   3465,553,  54364,  14817,     13,   1446,    525,    264,  10950,  17847,13, 151645,    198, 151644,    872,    198, 105043, 100165, 151645,198, 151644,  77091,    198]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}# debug generated_ids: 
[tensor([104198,     48,  16948,   3837, 102661,  99718, 102014, 104491],device='cuda:0')]['我是Qwen,阿里云推出的一种']

http://www.xdnf.cn/news/20146.html

相关文章:

  • Leetcode—721. 账户合并【中等】
  • Mattermost教程:用Docker搭建自己的开源Slack替代品 (团队聊天)
  • PyTorch训练循环详解:深入理解forward()、backward()和optimizer.step()
  • 光伏项目无人机踏勘--如何使用无人机自动航线规划APP
  • VMware替代 | ZStack生产级跨版本热升级等七大要素降低TCO50%
  • HDFS存储农业大数据的秘密是什么?高级大豆数据分析与可视化系统架构设计思路
  • OpenLayers常用控件 -- 章节五:鹰眼地图控件教程
  • 修改上次提交的Git提交日志
  • CodePerfAI体验:AI代码性能分析工具如何高效排查性能瓶颈、优化SQL执行耗时?
  • 《sklearn机器学习——聚类性能指标》调整兰德指数、基于互信息(mutual information)的得分
  • Mysql中模糊匹配常被忽略的坑
  • Netty从0到1系列之Netty整体架构、入门程序
  • Python迭代协议完全指南:从基础到高并发系统实现
  • 投资储能项目能赚多少钱?小程序帮你测算
  • Unity2022.3.41的TargetSdk更新到APILevel 35问题
  • Fairness, bias, and ethics|公平,偏见与伦理
  • 【科研绘图系列】R语言绘制论文合集图
  • 高等数学知识补充:三角函数
  • 脚本语言的大浪淘沙或百花争艳
  • JUnit入门:Java单元测试全解析
  • Boost搜索引擎 查找并去重(3)
  • 输入网址到网页显示的整个过程
  • 孙宇晨钱包被列入黑名单,WLFI代币价格暴跌引发中心化争议
  • Unix/Linux 平台通过 IP 地址获取接口名的 C++ 实现
  • 告别 “无效阅读”!2025 开学季超赞科技书单,带孩子解锁 AI、编程新技能
  • Docker部署PanSou 一款开源网盘搜索项目,集成前后端,一键部署
  • 基于单片机汽车防撞系统设计
  • validator列表校验
  • OCA、OCP、OCM傻傻分不清?Oracle认证就看这篇
  • 四六级学习资料管理系统的设计与实现(代码+数据库+LW)