License Plate Recognition with LPRNet, Including ONNX and TensorRT Inference
This article doesn't produce any technology of its own; it just carries it over!!!
Preface
Recently my company had a project that required license plate recognition. I started with the CCPD2019 dataset, but its class distribution turned out to be very imbalanced, so I switched to the CBLPRD-330k_v1 dataset, which can be downloaded directly from ModelScope, so I won't provide a copy here.
References
While researching online, I found a gem of a post whose training framework is very practical. Reference links:
https://zhuanlan.zhihu.com/p/684048137
https://github.com/MapleAura/MapleLPRNet
Training
Just follow the original author's README to train; it works very well.
Conversion
The original author provides an export.py conversion script, but since I use LPRNetV2, I made a few changes. The code is as follows:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import logging

import torch

from model.lprnet import LPRNet, CHARS, LPRNetV2
from model.stnet import STNet
from utils.general import set_logging

logger = logging.getLogger(__name__)
set_logging()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='/home/project_python/MapleLPRNet/runs/exp5/weights/final.pt', help='weights path')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    parser.add_argument('--img-size', default=(94, 24), help='the image size')
    parser.add_argument('--dropout_rate', default=0.5, help='dropout rate.')
    opts = parser.parse_args()

    # Print arguments
    logger.info("args: %s" % opts)

    # Dummy input for tracing
    device = torch.device('cuda')
    img = torch.randn((opts.batch_size, 3, opts.img_size[1], opts.img_size[0]),
                      device=device, dtype=torch.float32)

    # Build the network
    model = LPRNetV2(8, True, class_num=len(CHARS), dropout_rate=opts.dropout_rate).to(device)
    logger.info("Build network is successful.")

    # Load weights
    ckpt = torch.load(opts.weights, map_location=device)
    model.load_state_dict(ckpt["model"])
    model.eval()

    # Free memory
    del ckpt

    # ONNX export
    try:
        import onnx

        print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
        f = opts.weights.replace('.pt', '.onnx')  # filename
        torch.onnx.export(model, img, f, verbose=True, opset_version=12,
                          input_names=['images'], output_names=['output'],
                          dynamic_axes={'images': {0: 'batch_size'},
                                        'output': {0: 'batch_size'}})

        # Checks
        onnx_model = onnx.load(f)  # load onnx model
        onnx.checker.check_model(onnx_model)  # check onnx model
        # print(onnx.helper.printable_graph(onnx_model.graph))  # print a human readable model
        print('ONNX export success, saved as %s' % f)
    except Exception as e:
        print('ONNX export failure: %s' % e)

    # Finish
    print('Export complete.')
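The TensorRT section further below loads a serialized engine (final.engine), but the build step isn't shown in this post. As a reference, here is a minimal sketch of serializing the exported ONNX model with the TensorRT 8.x Python API (the same API generation the inference script below uses); the paths and the pinned batch-1 profile are my assumptions, and trtexec --onnx=... --saveEngine=... is an equivalent command-line route:

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def build_engine(onnx_path, engine_path):
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(onnx_path, 'rb') as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError("Failed to parse the ONNX model")
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GB
    # The ONNX model was exported with a dynamic batch axis, so pin a
    # batch-1 profile (min/opt/max all 1x3x24x94) to match the validation loop.
    profile = builder.create_optimization_profile()
    profile.set_shape('images', (1, 3, 24, 94), (1, 3, 24, 94), (1, 3, 24, 94))
    config.add_optimization_profile(profile)
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("Engine build failed")
    with open(engine_path, 'wb') as f:
        f.write(serialized)

build_engine('final.onnx', 'final.engine')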
ONNX Inference
import numpy as np
import onnxruntime as ort
import cv2
import torch
import time

CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
         '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
         '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁',
         '新', '学', '港', '澳', '警', '使', '领', '应', '急', '挂',
         '临',
         '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
         'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
         'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
         'W', 'X', 'Y', 'Z', '-']


def load_image(file, img_size):
    image = cv2.imread(file)
    # Resize to the network input size
    image = cv2.resize(image, img_size)
    # Normalize to roughly [-1, 1]
    image = (image.astype('float32') - 127.5) * 0.007843
    # HWC -> CHW, to tensor
    image = torch.from_numpy(image.transpose((2, 0, 1))).contiguous()
    return image


def decode(preds):
    last_chars_idx = len(CHARS) - 1
    # Greedy decode: argmax per time step, then collapse repeats and blanks
    pred_labels = []
    labels = []
    for i in range(preds.shape[0]):
        pred = preds[i, :, :]
        pred_label = []
        for j in range(pred.shape[1]):
            pred_label.append(np.argmax(pred[:, j], axis=0))
        no_repeat_blank_label = []
        pre_c = -1
        for c in pred_label:  # drop repeated labels and blank labels
            if (pre_c == c) or (c == last_chars_idx):
                if c == last_chars_idx:
                    pre_c = c
                continue
            no_repeat_blank_label.append(c)
            pre_c = c
        pred_labels.append(no_repeat_blank_label)
    for _, label in enumerate(pred_labels):
        lb = ""
        for i in label:
            lb += CHARS[i]
        labels.append(lb)
    return labels, pred_labels


if __name__ == '__main__':
    val_txt = "/home/project_python/LPRNet_Mine/dataset/CBLPRD-330k_v1/val.txt"
    onnx_model = "/home/project_python/MapleLPRNet/runs/exp5/weights/final.onnx"
    error_output = "./111/error_input.txt"
    image_size = (94, 24)

    session = ort.InferenceSession(onnx_model, providers=['CUDAExecutionProvider'])
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    right_num = 0
    start = time.time()
    with open(val_txt, 'r') as f:
        lines = f.readlines()
    all_num = len(lines)
    for line in lines:
        image_path = line.split(" ")[0]
        image = load_image(image_path, image_size)
        image = np.expand_dims(image.numpy(), axis=0)
        pred = session.run([output_name], {input_name: image})[0]
        labels, pred_labels = decode(pred)
        # GAN-generated labels may contain I/O, which real plates never use
        if labels[0] == line.split(" ")[1].replace("I", "1").replace("O", "0"):
            right_num += 1
        else:
            with open(error_output, "a") as ef:
                ef.write(line)
            print("target:", line.split(" ")[1], "predict:", labels[0])
    print("acc:", right_num / all_num)
    print("time:", time.time() - start)
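decode is a standard greedy CTC decoder: take the argmax at each of the 18 time steps, then collapse consecutive repeats and drop the blank symbol ('-', the last entry in CHARS). A toy example with made-up indices (a hypothetical five-symbol vocabulary, not the real CHARS) shows the collapsing rule:

import numpy as np

# Hypothetical toy vocabulary: the last index ('-') is the CTC blank
TOY_CHARS = ['京', 'A', '1', '2', '-']
blank = len(TOY_CHARS) - 1

# One sequence of six argmax indices over time:
# 京 京 - A A 1  ->  collapse repeats, drop blanks  ->  京A1
steps = [0, 0, blank, 1, 1, 2]

result, pre_c = [], -1
for c in steps:
    if c == pre_c or c == blank:
        pre_c = c
        continue
    result.append(c)
    pre_c = c
print("".join(TOY_CHARS[i] for i in result))  # prints 京A1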
TensorRT Inference
import numpy as np
import tensorrt as trt
import cv2
import torch
import pycuda.driver as cuda
import pycuda.autoinit
import os
import time

CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
         '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
         '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁',
         '新', '学', '港', '澳', '警', '使', '领', '应', '急', '挂',
         '临',
         '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
         'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
         'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
         'W', 'X', 'Y', 'Z', '-']


def load_engine(engine_file_path):
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    with open(engine_file_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())


def create_context(engine):
    return engine.create_execution_context()


def allocate_buffers(engine):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()
    for binding in engine:
        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        bindings.append(int(device_mem))
        if engine.binding_is_input(binding):
            inputs.append({'host': host_mem, 'device': device_mem})
        else:
            outputs.append({'host': host_mem, 'device': device_mem})
    return inputs, outputs, bindings, stream


# -----------------------------
# Inference function
# -----------------------------
def infer(context, inputs, outputs, bindings, stream, input_data):
    # Transfer input data to the GPU
    [np.copyto(i['host'], input_data.ravel().astype(np.float32)) for i in inputs]
    [cuda.memcpy_htod_async(i['device'], i['host'], stream) for i in inputs]
    # Execute the model
    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
    # Transfer predictions back from the GPU
    [cuda.memcpy_dtoh_async(o['host'], o['device'], stream) for o in outputs]
    # Synchronize the stream
    stream.synchronize()
    # Return the host output
    return [o['host'] for o in outputs]


# -----------------------------
# Image preprocessing
# -----------------------------
def load_image(file, img_size):
    image = cv2.imread(file)
    if image is None:
        raise FileNotFoundError(f"Image file {file} not found.")
    image = cv2.resize(image, img_size)  # resize only; keep the BGR channel order
    image = (image.astype('float32') - 127.5) * 0.007843
    image = torch.from_numpy(image.transpose((2, 0, 1))).contiguous()
    return image


# -----------------------------
# Decode function
# -----------------------------
def decode(preds):
    last_chars_idx = len(CHARS) - 1
    pred_labels = []
    labels = []
    for i in range(preds.shape[0]):
        pred = preds[i, :, :]
        pred_label = []
        for j in range(pred.shape[1]):
            pred_label.append(np.argmax(pred[:, j], axis=0))
        no_repeat_blank_label = []
        pre_c = -1
        for c in pred_label:
            if (pre_c == c) or (c == last_chars_idx):
                if c == last_chars_idx:
                    pre_c = c
                continue
            no_repeat_blank_label.append(c)
            pre_c = c
        pred_labels.append(no_repeat_blank_label)
    for _, label in enumerate(pred_labels):
        lb = ""
        for i in label:
            lb += CHARS[i]
        labels.append(lb)
    return labels, pred_labels


# -----------------------------
# Main: batch validation
# -----------------------------
if __name__ == '__main__':
    right_num = 0
    all_num = 0
    error_output = "./111/error_input_trt.txt"
    trt_engine_path = "/home/project_python/MapleLPRNet/runs/exp5/weights/final.engine"  # replace with your TensorRT engine path
    val_txt = "/home/project_python/LPRNet_Mine/dataset/CBLPRD-330k_v1/val.txt"
    image_size = (94, 24)

    engine = load_engine(trt_engine_path)
    context = create_context(engine)
    inputs, outputs, bindings, stream = allocate_buffers(engine)

    start = time.time()
    with open(val_txt, 'r') as f:
        lines = f.readlines()
    all_num = len(lines)
    for line in lines:
        try:
            image_path = line.split(" ")[0]
            label = line.split(" ")[1]
        except ValueError:
            print(f"Invalid line: {line}")
            continue
        try:
            image = load_image(image_path, image_size)
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            continue
        image = np.expand_dims(image.numpy(), axis=0)

        # Inference
        pred = infer(context, inputs, outputs, bindings, stream, image)
        pred = np.array(pred).reshape((1, 76, 18))

        # Decode
        labels, _ = decode(pred)

        # Check against the ground truth
        target = label.replace("I", "1").replace("O", "0")
        predict = labels[0]
        if predict == target:
            right_num += 1
        else:
            print(f"target: {target}, predict: {predict}")
            with open(error_output, "a") as ef:
                ef.write(f"{image_path} {label} {predict}\n")

    print(f"acc: {right_num / all_num:.4f}")
    print(f"time: {time.time() - start:.4f}")
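One caveat: allocate_buffers and infer above use the binding-based API (get_binding_shape, binding_is_input, execute_async_v2), which is deprecated in later TensorRT 8.x releases and removed in TensorRT 10. On TensorRT 10 the allocation would instead enumerate tensors by name, roughly like the following sketch (my adaptation, not tested against the rest of this script):

def allocate_buffers_v3(engine):
    # TensorRT 10 style: I/O tensors are addressed by name, not binding index
    inputs, outputs, stream = [], [], cuda.Stream()
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        dtype = trt.nptype(engine.get_tensor_dtype(name))
        # For dynamic dims (-1), set the input shape on the context first
        # and read the resolved shape from context.get_tensor_shape(name)
        size = trt.volume(engine.get_tensor_shape(name))
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        item = {'name': name, 'host': host_mem, 'device': device_mem}
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            inputs.append(item)
        else:
            outputs.append(item)
    return inputs, outputs, stream

The inference call then binds each buffer with context.set_tensor_address(t['name'], int(t['device'])) and launches with context.execute_async_v3(stream.handle) instead of execute_async_v2.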
Notes
1. The image preprocessing reproduced in the original author's detect.py is wrong: it performs an extra channel-order conversion. Use the preprocessing given in this article instead (see the sketch after this list).
2. Chinese license plates never contain the letters I and O, but the dataset used here is GAN-generated and does contain them. The original author handles this in the dataloader, so when writing your own inference you must replace I and O in post-processing (also shown in the sketch below).
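Putting both notes together, here is a minimal sketch of the preprocessing and label normalization used in this article (the function names are mine, not the original author's):

import cv2
import numpy as np

def preprocess(path, img_size=(94, 24)):
    # Matches training: resize + normalize only.
    # Note: no cv2.cvtColor here - the extra BGR->RGB conversion in
    # detect.py is exactly the mistake described in note 1.
    image = cv2.imread(path)
    image = cv2.resize(image, img_size)
    image = (image.astype('float32') - 127.5) * 0.007843
    return np.expand_dims(image.transpose((2, 0, 1)), axis=0)

def normalize_plate(text):
    # Note 2: real Chinese plates never use I or O, but the GAN-generated
    # labels do, so map I->1 and O->0 before comparing predictions
    return text.replace("I", "1").replace("O", "0")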