当前位置: 首页 > ops >正文

【TVM 教程】microTVM PyTorch 教程

Apache TVM 是一个深度的深度学习编译框架,适用于 CPU、GPU 和各种机器学习加速芯片。更多 TVM 中文文档可访问 →https://tvm.hyper.ai/

作者:Mehrdad Hessar

该教程展示了如何使用 PyTorch 模型进行 microTVM 主机驱动的 AOT 编译。此教程可以在使用 C 运行时(CRT)的 x86 CPU 上执行。

注意: 此教程仅在使用 CRT 的 x86 CPU 上运行,不支持在 Zephyr 上运行,因为该模型不适用于我们当前支持的 Zephyr 单板。

安装 microTVM Python 依赖项

TVM 不包含用于 Python 串行通信包,因此在使用 microTVM 之前我们必须先安装一个。我们还需要TFLite来加载模型。

pip install pyserial==3.5 tflite==2.1
import pathlib
import torch
import torchvision
from torchvision import transforms
import numpy as np
from PIL import Imageimport tvm
from tvm import relay
from tvm.contrib.download import download_testdata
from tvm.relay.backend import Executor
import tvm.micro.testing

加载预训练 PyTorch 模型

首先,从 torchvision 中加载预训练的 MobileNetV2 模型。然后,下载一张猫的图像并进行预处理,以便用作模型的输入。

model = torchvision.models.quantization.mobilenet_v2(weights="DEFAULT", quantize=True)
model = model.eval()input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))# 预处理图片并转换为张量
my_preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),]
)
img = my_preprocess(img)
img = np.expand_dims(img, 0)input_name = "input0"
shape_list = [(input_name, input_shape)]
relay_mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

输出:

/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torch/ao/quantization/utils.py:310: UserWarning: must run observer before calling calculate_qparams. Returning default values.warnings.warn(
Downloading: "https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth" to /workspace/.cache/torch/hub/checkpoints/mobilenet_v2_qnnpack_37f702c5.pth0%|          | 0.00/3.42M [00:00<?, ?B/s]61%|######    | 2.09M/3.42M [00:00<00:00, 11.6MB/s]
100%|##########| 3.42M/3.42M [00:00<00:00, 18.5MB/s]
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torch/_utils.py:314: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()device=storage.device,
/workspace/python/tvm/relay/frontend/pytorch_utils.py:47: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.return LooseVersion(torch_ver) > ver
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/setuptools/_distutils/version.py:346: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.other = LooseVersion(other)

定义目标、运行时与执行器

在本教程中,我们使用 AOT 主机驱动执行器。为了在 x86 机器上对嵌入式模拟环境编译模型,我们使用 C 运行时(CRT),并使用主机微型目标。使用该设置,TVM 为 C 运行时编译可以在 x86 CPU 机器上运行的模型,可以在物理微控制器上运行的相同流程。CRT 使用 src/runtime/crt/host/main.cc 中的 main()。要使用物理硬件,请将 board 替换为另一个物理微型目标,例如 nrf5340dk_nrf5340_cpuapp 或 mps2_an521,并将平台类型更改为 Zephyr。在《为 Arduino 上的 microTVM 训练视觉模型》和《microTVM TFLite 教程》中,可以看到 更多目标示例。

target = tvm.micro.testing.get_target(platform="crt", board=None)# 使用 C 运行时 (crt) 并通过设置 system-lib 为 True 打开静态链接
runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True})# 使用 AOT 执行器代替图或 vm 执行器。不要使用未包装的 API 或 C 风格调用
executor = Executor("aot")

编译模型

现在为目标编译模型:

with tvm.transform.PassContext(opt_level=3,config={"tir.disable_vectorize": True},
):module = tvm.relay.build(relay_mod, target=target, runtime=runtime, executor=executor, params=params)

创建 microTVM 项目

现在,我们已经将编译好的模型作为 IRModule 准备好,我们还需要创建一个固件项目,以便在 microTVM 中使用编译好的模型。为此,我们需要使用 Project API。

template_project_path = pathlib.Path(tvm.micro.get_microtvm_template_projects("crt"))
project_options = {"verbose": False, "workspace_size_bytes": 6 * 1024 * 1024}temp_dir = tvm.contrib.utils.tempdir() / "project"
project = tvm.micro.generate_project(str(template_project_path),module,temp_dir,project_options,
)

构建、烧录和执行模型

接下来,我们构建 microTVM项 目并进行烧录。烧录步骤特定于物理微控制器,如果通过主机的 main.cc 模拟微控制器,或者选择 Zephyr 模拟单板作为目标,则会跳过该步骤。

project.build()
project.flash()input_data = {input_name: tvm.nd.array(img.astype("float32"))}
with tvm.micro.Session(project.transport()) as session:aot_executor = tvm.runtime.executor.aot_executor.AotModule(session.create_aot_executor())aot_executor.set_input(**input_data)aot_executor.run()result = aot_executor.get_output(0).numpy()

查询 Synset 名称

查询在 1000 个类别 Synset 中的 top-1 的预测。

synset_url = ("https://raw.githubusercontent.com/Cadene/""pretrained-models.pytorch/master/data/""imagenet_synsets.txt"
)
synset_name = "imagenet_synsets.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:synsets = f.readlines()synsets = [x.strip() for x in synsets]
splits = [line.split(" ") for line in synsets]
key_to_classname = {spl[0]: " ".join(spl[1:]) for spl in splits}class_url = ("https://raw.githubusercontent.com/Cadene/""pretrained-models.pytorch/master/data/""imagenet_classes.txt"
)
class_path = download_testdata(class_url, "imagenet_classes.txt", module="data")
with open(class_path) as f:class_id_to_key = f.readlines()class_id_to_key = [x.strip() for x in class_id_to_key]# Get top-1 result for TVM
top1_tvm = np.argmax(result)
tvm_class_key = class_id_to_key[top1_tvm]# Convert input to PyTorch variable and get PyTorch result for comparison
with torch.no_grad():torch_img = torch.from_numpy(img)output = model(torch_img)# Get top-1 result for PyTorchtop1_torch = np.argmax(output.numpy())torch_class_key = class_id_to_key[top1_torch]print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key]))
print("Torch top-1 id: {}, class name: {}".format(top1_torch, key_to_classname[torch_class_key]))

输出结果:

Relay top-1 id: 282, class name: tiger cat
Torch top-1 id: 282, class name: tiger cat

该脚本总运行时间:(1分26.552秒)

下载 Python 源代码:micro_pytorch.py

下载 Jupyter notebook:micro_pytorch.ipynb

http://www.xdnf.cn/news/5620.html

相关文章:

  • 如何快速入门大模型?
  • 【套题】GESP C++四级认证各题详解/详细代码
  • 查看购物车
  • sql语句面经手撕(定制整理版)
  • MYSQL 全量,增量备份与恢复
  • HTTP3
  • 一次IPA被破解后的教训(附Ipa Guard等混淆工具实测)
  • [Java] 输入输出方法+猜数字游戏
  • 支持私有化部署的小天互连即时通讯平台:助力企业数字化转型的通讯利器
  • lenis选项卡举例
  • LeetCode 373 查找和最小的 K 对数字题解
  • Git安装教程及常用命令
  • 【DeepSeek问答记录】请结合实例,讲解一下pytorch的DataLoader的使用方法
  • 11 配置Hadoop集群-免密登录
  • 一文读懂如何使用MCP创建服务器
  • ARMV8 RK3399 u-boot TPL启动流程分析 --crt0.S
  • 恰到好处TDR
  • SID310S/D/Q-10MHz, 低噪声, 轨至轨, CMOS 运算放大器
  • 二叉树路径总和
  • 10:00开始面试,10:08就出来了,问的问题有点变态。。。
  • wordcount在mapreduce的例子
  • 解读RTOS:第二篇 · 线程/任务管理与调度策略
  • WebGIS开发新突破:揭秘未来地理信息系统的神秘面纱
  • 回答 | 图形数据库neo4j社区版可以应用小型企业嘛?
  • 宇树科技安全漏洞揭示智能机器人行业隐忧
  • 视频翻译软件有哪些?推荐5款视频翻译工具[特殊字符][特殊字符]
  • 树莓派4 yolo 11l.pt性能优化后的版本
  • 摆脱拖延症的详细计划示例
  • Java根据文件名前缀自动分组图片文件
  • 社交APP如何借助游戏盾守护业务稳定