【LLM】Exporting the Qwen2.5-VL model to ONNX
Overview: export the Qwen2.5-VL model to ONNX and run inference with onnxruntime.
Contents
- 1. Exporting the visual model
- 2. Exporting the language model (vlmodel)
1. Exporting the visual model
There are two ways to export the visual encoder: one takes the raw image tensor as input (`export_method1`), the other takes the pre-flattened patches, called `pathes` in the code (`export_method2`).
```python
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from transformers import AutoConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
import os

model_id = "/data1/chenjun/huf/Qwen2.5-VL-3B-Instruct"
devices = "cpu"
dtype = torch.float32
onnx_model = "output/qwen25vl/qwen2.5_visual.onnx"
os.makedirs(os.path.dirname(onnx_model), exist_ok=True)

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=dtype,
    attn_implementation="eager",
    device_map=devices,
)
visual = model.visual

image_height = 476
image_width = 812


def export_method1():
    # Input is the raw image; the patchify-and-flatten pre-processing is folded into forward.
    image = torch.randn(1, 3, image_height, image_width).to(dtype=dtype).to(devices)
    grid_thw = torch.tensor([[1, image_height // 14, image_width // 14]], dtype=torch.int64).to(devices)
    dynamic_axes = {'image': {2: 'image_height', 3: 'image_width'}}

    grid_t = 1
    merge_size = 2
    channel = 3
    temporal_patch_size = 2
    patch_size = 14
    grid_h, grid_w = image_height // patch_size, image_width // patch_size

    ori_forward = visual.forward  # save the original forward method

    def export_onnx(model):
        def temp(image, grid_thw):
            patches = image.repeat(temporal_patch_size, 1, 1, 1)
            patches = patches.reshape(grid_t, temporal_patch_size, channel,
                                      grid_h // merge_size, merge_size, patch_size,
                                      grid_w // merge_size, merge_size, patch_size)
            patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
            flatten_patches = patches.reshape(grid_t * grid_h * grid_w,
                                              channel * temporal_patch_size * patch_size * patch_size)
            feature = ori_forward(flatten_patches, grid_thw)
            return feature
        return temp

    visual.forward = export_onnx(model)
    feature = visual(image, grid_thw)
    torch.onnx.export(visual, (image, grid_thw),
                      onnx_model,
                      input_names=['image', 'grid_thw'],
                      output_names=['image_embeds'],
                      dynamic_axes=dynamic_axes,
                      do_constant_folding=True,
                      verbose=True,
                      opset_version=17)
    print(f"ONNX model exported to {onnx_model} successfully.")


def export_method2():
    # Input is the already-flattened patches of shape (grid_h * grid_w, 1176).
    pathes = torch.randn((image_height // 14) * (image_width // 14), 1176).to(dtype=dtype).to(devices)
    grid_thw = torch.tensor([[1, image_height // 14, image_width // 14]], dtype=torch.int64).to(devices)
    dynamic_axes = {'pathes': {0: 'token_len'}}  # key must match the input name below
    feature = visual(pathes, grid_thw)
    torch.onnx.export(visual, (pathes, grid_thw),
                      onnx_model,
                      input_names=['pathes', 'grid_thw'],
                      output_names=['image_embeds'],
                      dynamic_axes=dynamic_axes,
                      do_constant_folding=True,
                      verbose=True,
                      opset_version=17)
    print(f"ONNX model exported to {onnx_model} successfully.")


if __name__ == "__main__":
    export_method1()
    # export_method2()
```
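Once the export finishes, the graph can be sanity-checked against the PyTorch visual module with onnxruntime. The snippet below is a minimal sketch meant to be appended to the script above after running `export_method2()`; it assumes that export's input/output names (`pathes`, `grid_thw`, `image_embeds`) and reuses the `visual`, `onnx_model`, `image_height`, `image_width` variables already defined there. Adjust the names and shapes if you exported with `export_method1`.

```python
import numpy as np
import onnxruntime as ort

# Sanity-check the exported visual model with onnxruntime
# (assumes export_method2 naming: inputs "pathes"/"grid_thw", output "image_embeds").
sess = ort.InferenceSession(onnx_model, providers=["CPUExecutionProvider"])

grid_h, grid_w = image_height // 14, image_width // 14
pathes = torch.randn(grid_h * grid_w, 1176, dtype=dtype)
grid_thw = torch.tensor([[1, grid_h, grid_w]], dtype=torch.int64)

onnx_embeds = sess.run(["image_embeds"], {"pathes": pathes.numpy(),
                                          "grid_thw": grid_thw.numpy()})[0]
with torch.no_grad():
    torch_embeds = visual(pathes, grid_thw)

print("max abs diff:", np.abs(onnx_embeds - torch_embeds.numpy()).max())
```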
2. Exporting the language model (vlmodel)
This covers rewriting the forward pass, verifying it in PyTorch, exporting to ONNX, and running inference with onnxruntime.
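The script below imports `prepare_attention_mask` from `eagle.model.own_utils`, which is not shown in this post. Based on how it is called (and on the fact that its result is added directly to the attention scores), it is assumed to build an additive causal mask of shape `(1, 1, ids_len, history_len + ids_len)`. A minimal sketch under that assumption:

```python
import torch

def prepare_attention_mask(ids_len: int, history_len: int, device: str = "cpu") -> torch.Tensor:
    """Assumed behaviour: additive causal mask for `ids_len` new tokens attending to
    `history_len` cached tokens plus themselves. 0 where attention is allowed,
    a large negative value where it is masked out."""
    total_len = history_len + ids_len
    mask = torch.full((1, 1, ids_len, total_len), torch.finfo(torch.float32).min, device=device)
    # every new token sees all cached history tokens
    mask[..., :history_len] = 0.0
    # new token i additionally sees new tokens 0..i (causal part)
    causal = torch.tril(torch.ones(ids_len, ids_len, dtype=torch.bool, device=device))
    mask[..., history_len:] = mask[..., history_len:].masked_fill(causal, 0.0)
    return mask
```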
```python
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from transformers import AutoConfig, Qwen2VLForConditionalGeneration
from transformers.cache_utils import DynamicCache
from eagle.model.own_utils import prepare_attention_mask
import math
import os

from eagle.onnx.onnx_runner import OnnxRunner, Qwen2_5_VLModel, Qwen2_5_VisualModel

model_id = "/data1/chenjun/huf/Qwen2.5-VL-3B-Instruct"
devices = "cpu"
dtype = torch.float32
onnx_model = "output/qwen25vl_vl/qwen2.5_vl_vl.onnx"
os.makedirs(os.path.dirname(onnx_model), exist_ok=True)
visual_onnx_model = "output/qwen25vl/qwen2.5_visual.onnx"

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=dtype,
    attn_implementation="eager",
    device_map=devices,
)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def forward_new_onnx(self):
    # Rewrite the language-model forward so that past key/values are explicit
    # inputs and outputs (ONNX-friendly, no DynamicCache objects).
    self.variance_epsilon = float(1e-6)
    self.save_key = [None] * len(self.model.layers)
    self.save_value = [None] * len(self.model.layers)
    self.num_heads = self.config.num_attention_heads
    self.head_dim = self.config.hidden_size // self.num_heads
    self.num_layers = len(self.model.layers)

    def tmp(*inputs):
        # inputs = [past_key_0..N-1, past_value_0..N-1, inputs_embeds, position_ids, attention_mask]
        inputs_embeds = inputs[-3]
        position_ids = inputs[-2]
        attention_mask = inputs[-1]
        hidden_states = inputs_embeds
        cos_e, sin_e = self.model.rotary_emb(hidden_states, position_ids)
        mrope_section = self.model.layers[0].self_attn.rope_scaling["mrope_section"] * 2
        cos = torch.cat([m[i % 3] for i, m in enumerate(cos_e.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
        sin = torch.cat([m[i % 3] for i, m in enumerate(sin_e.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
        for i, layer in enumerate(self.model.layers):
            bsz, q_len, _ = hidden_states.size()
            hidden_states_norm = layer.input_layernorm.weight * (hidden_states / torch.sqrt(hidden_states.pow(2).mean(-1, keepdim=True) + self.variance_epsilon))
            q = layer.self_attn.q_proj(hidden_states_norm).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
            k = layer.self_attn.k_proj(hidden_states_norm).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
            v = layer.self_attn.v_proj(hidden_states_norm).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
            q = (q * cos) + (rotate_half(q) * sin)
            k = (k * cos) + (rotate_half(k) * sin)
            k = torch.cat([inputs[i], k], dim=2)
            v = torch.cat([inputs[i + self.num_layers], v], dim=2)
            self.save_key[i] = k
            self.save_value[i] = v
            k = repeat_kv(k, layer.self_attn.num_key_value_groups)
            v = repeat_kv(v, layer.self_attn.num_key_value_groups)
            attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(layer.self_attn.head_dim) + attention_mask
            attn = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
            attn_out = torch.matmul(attn, v)
            attn_out = layer.self_attn.o_proj(attn_out.transpose(1, 2).contiguous().view(1, -1, layer.self_attn.o_proj.in_features))
            hidden_states += attn_out
            residual = hidden_states
            hidden_states = layer.post_attention_layernorm.weight * (hidden_states / torch.sqrt(hidden_states.pow(2).mean(-1, keepdim=True) + self.variance_epsilon))
            hidden_states = layer.mlp.down_proj(layer.mlp.act_fn(layer.mlp.gate_proj(hidden_states)) * layer.mlp.up_proj(hidden_states))
            hidden_states += residual
        hidden_states = self.model.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        return *self.save_key, *self.save_value, torch.argmax(logits, dim=-1, keepdim=True)

    return tmp


model.model.forward = forward_new_onnx(model)


def torch_infer():
    processor = AutoProcessor.from_pretrained(model_id)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": "asserts/imgs/person.png"},
                {"type": "text", "text": "请详细描述图像"},  # "Please describe the image in detail"
            ],
        }
    ]
    # Preparation for inference
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(devices)
    input_ids = inputs.input_ids
    pixel_values = inputs.pixel_values
    image_grid_thw = inputs.image_grid_thw
    inputs_embeds = model.model.embed_tokens(inputs.input_ids)
    image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)

    ## Scatter the image embeddings into the token embeddings
    n_image_tokens = (input_ids == model.config.image_token_id).sum().item()
    n_image_features = image_embeds.shape[0]
    assert n_image_tokens == n_image_features, f"Expected {n_image_tokens} image tokens, but got {n_image_features} image features."
    mask = input_ids == model.config.image_token_id
    mask_unsqueezed = mask.unsqueeze(-1)
    mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
    image_mask = mask_expanded.to(inputs_embeds.device)
    image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
    inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

    ## Prepare the model inputs; all image tokens share a single position value
    num_heads = model.config.num_attention_heads
    head_dim = model.config.hidden_size // num_heads
    num_key_value_heads = model.config.num_key_value_heads
    num_layers = model.config.num_hidden_layers
    past_keys_A = torch.zeros((1, num_key_value_heads, 0, head_dim), dtype=torch.float32)
    past_values_A = torch.zeros((1, num_key_value_heads, 0, head_dim), dtype=torch.float32)
    input_feed = []
    for i in range(num_layers):
        input_feed.append(past_keys_A)
    for i in range(num_layers):
        input_feed.append(past_values_A)
    attention_mask = prepare_attention_mask(input_ids.shape[1], 0, devices)
    position_ids, rope_deltas = model.get_rope_index(input_ids, image_grid_thw)
    model.rope_deltas = rope_deltas
    input_feed.extend([inputs_embeds, position_ids, attention_mask])

    ## Language-model inference (prefill + greedy decode)
    outputs = model.model(*input_feed)
    tokens = [outputs[-1][:, -1]]
    new_tokens = 0
    while new_tokens < 128:
        input_ids = tokens[-1]
        inputs_embeds = model.model.embed_tokens(input_ids)
        history_len = outputs[0].shape[2]
        batch_size, seq_length, _ = inputs_embeds.shape
        position_ids = torch.arange(seq_length, device=inputs_embeds.device)
        delta = history_len + model.rope_deltas
        position_ids = position_ids.add(delta)
        position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
        attention_mask = prepare_attention_mask(input_ids.shape[1], history_len, devices)
        input_feed = []
        for i in range(num_layers * 2):
            input_feed.append(outputs[i])
        input_feed.extend([inputs_embeds, position_ids, attention_mask])
        outputs = model.model(*input_feed)
        next_token = outputs[-1][:, -1]
        new_tokens += 1
        tokens.append(next_token)
    tokens = torch.cat(tokens, dim=-1)
    output_text = processor.batch_decode(tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    print(output_text)


def export_onnx_and_demo():
    ## Export to ONNX
    num_heads = model.config.num_attention_heads
    head_dim = model.config.hidden_size // num_heads
    num_key_value_heads = model.config.num_key_value_heads
    num_layers = model.config.num_hidden_layers
    seq_len = 518
    past_keys_A = torch.zeros((1, num_key_value_heads, 0, head_dim), dtype=torch.float32)
    past_values_A = torch.zeros((1, num_key_value_heads, 0, head_dim), dtype=torch.float32)
    inputs_embeds = torch.randn(1, seq_len, model.config.hidden_size, dtype=dtype)
    position_ids = torch.zeros((3, 1, seq_len), dtype=torch.int64)
    attention_mask = torch.ones((1, 1, seq_len, seq_len), dtype=torch.int64)
    input_feed = []
    for i in range(num_layers):
        input_feed.append(past_keys_A)
    for i in range(num_layers):
        input_feed.append(past_values_A)
    input_feed.extend([inputs_embeds, position_ids, attention_mask])

    input_names = []
    out_names = []
    dynamic_axes = {}
    for i in range(num_layers):
        input_names.append(f"in_past_key_{i}")
        out_names.append(f"out_past_key_{i}")
        dynamic_axes[f"in_past_key_{i}"] = {2: 'seq_len'}
        dynamic_axes[f"out_past_key_{i}"] = {2: 'seq_len'}
    for i in range(num_layers):
        input_names.append(f"in_past_value_{i}")
        out_names.append(f"out_past_value_{i}")
        dynamic_axes[f"in_past_value_{i}"] = {2: 'seq_len'}
        dynamic_axes[f"out_past_value_{i}"] = {2: 'seq_len'}
    input_names.extend(['inputs_embeds', 'position_ids', 'attention_mask'])
    out_names.append('logits')
    dynamic_axes['inputs_embeds'] = {1: 'seq_len'}
    dynamic_axes['position_ids'] = {2: 'seq_len'}
    dynamic_axes['attention_mask'] = {2: 'seq_len', 3: 'history_len_plus_ids_len'}
    dynamic_axes['logits'] = {1: 'seq_len'}

    if not os.path.exists(onnx_model):
        print(f"Exporting ONNX model to {onnx_model}...")
        torch.onnx.export(model.model, tuple(input_feed),
                          onnx_model,
                          input_names=input_names,
                          output_names=out_names,
                          dynamic_axes=dynamic_axes,
                          do_constant_folding=True,
                          verbose=True,
                          opset_version=17)
        print(f"ONNX model exported to {onnx_model} successfully.")

    onnxmodel = Qwen2_5_VLModel(onnx_model)

    ## Prepare the input data
    processor = AutoProcessor.from_pretrained(model_id)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": "asserts/imgs/person.png"},
                {"type": "text", "text": "请详细描述图像"},  # "Please describe the image in detail"
            ],
        }
    ]
    # Preparation for inference
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(devices)
    input_ids = inputs.input_ids
    pixel_values = inputs.pixel_values
    image_grid_thw = inputs.image_grid_thw
    inputs_embeds = model.model.embed_tokens(inputs.input_ids)

    ## method 1: use the PyTorch visual model
    # image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
    ## method 2: use the ONNX visual model
    qwen2_5_visual = Qwen2_5_VisualModel(visual_onnx_model)
    image_embeds = qwen2_5_visual([pixel_values, image_grid_thw])

    ## Scatter the image embeddings into the token embeddings
    n_image_tokens = (input_ids == model.config.image_token_id).sum().item()
    n_image_features = image_embeds.shape[0]
    assert n_image_tokens == n_image_features, f"Expected {n_image_tokens} image tokens, but got {n_image_features} image features."
    mask = input_ids == model.config.image_token_id
    mask_unsqueezed = mask.unsqueeze(-1)
    mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
    image_mask = mask_expanded.to(inputs_embeds.device)
    image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
    inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

    ## Prepare the model inputs; all image tokens share a single position value
    num_heads = model.config.num_attention_heads
    head_dim = model.config.hidden_size // num_heads
    num_key_value_heads = model.config.num_key_value_heads
    num_layers = model.config.num_hidden_layers
    past_keys_A = torch.zeros((1, num_key_value_heads, 0, head_dim), dtype=torch.float32)
    past_values_A = torch.zeros((1, num_key_value_heads, 0, head_dim), dtype=torch.float32)
    input_feed = []
    for i in range(num_layers):
        input_feed.append(past_keys_A)
    for i in range(num_layers):
        input_feed.append(past_values_A)
    attention_mask = prepare_attention_mask(input_ids.shape[1], 0, devices)
    position_ids, rope_deltas = model.get_rope_index(input_ids, image_grid_thw)
    model.rope_deltas = rope_deltas
    input_feed.extend([inputs_embeds, position_ids, attention_mask])

    ## onnxruntime inference (prefill + greedy decode)
    outputs = onnxmodel(input_feed)
    tokens = [outputs[-1][:, -1]]
    new_tokens = 0
    while new_tokens < 128:
        input_ids = tokens[-1]
        inputs_embeds = model.model.embed_tokens(input_ids)
        history_len = outputs[0][0].shape[2]
        batch_size, seq_length, _ = inputs_embeds.shape
        position_ids = torch.arange(seq_length, device=inputs_embeds.device)
        delta = history_len + model.rope_deltas
        position_ids = position_ids.add(delta)
        position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
        attention_mask = prepare_attention_mask(input_ids.shape[1], history_len, devices)
        input_feed = []
        for i in range(num_layers * 2):
            input_feed.append(outputs[0][i])
        input_feed.extend([inputs_embeds, position_ids, attention_mask])
        outputs = onnxmodel(input_feed)
        next_token = outputs[-1][:, -1]
        new_tokens += 1
        tokens.append(next_token)
    tokens = torch.cat(tokens, dim=-1)
    output_text = processor.batch_decode(tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    print(output_text)


if __name__ == "__main__":
    # torch_infer()
    export_onnx_and_demo()
```
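`Qwen2_5_VLModel` and `Qwen2_5_VisualModel` come from `eagle.onnx.onnx_runner`, which is not reproduced in this post. Conceptually they are thin wrappers around `onnxruntime.InferenceSession` that map a positional `input_feed` list onto the exported input names and hand the results back as torch tensors. Below is a minimal sketch of such a runner under that assumption; the class name here is made up, and the real wrapper's output layout differs (the decode loop above reads the cached keys/values via `outputs[0][i]`).

```python
import onnxruntime as ort
import torch

class SimpleOnnxRunner:
    """Minimal onnxruntime wrapper: feeds a positional list of torch tensors in the
    graph's input order and returns torch tensors in the graph's output order.
    Sketch only; the eagle wrapper used above groups its outputs differently."""

    def __init__(self, onnx_path: str):
        self.session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
        self.input_names = [i.name for i in self.session.get_inputs()]
        self.output_names = [o.name for o in self.session.get_outputs()]

    def __call__(self, input_feed):
        # input_feed is the same positional list used for torch.onnx.export:
        # num_layers past keys, num_layers past values, inputs_embeds, position_ids, attention_mask
        feed = {name: t.detach().cpu().numpy() for name, t in zip(self.input_names, input_feed)}
        outputs = self.session.run(self.output_names, feed)
        return [torch.from_numpy(o) for o in outputs]
```

One practical note: onnxruntime requires the runtime input dtypes to match the dummy inputs used at export time (for example, `attention_mask` was traced as int64 above), so whatever mask `prepare_attention_mask` returns has to be cast accordingly.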