基于Transformer 实现车辆检测与车牌识别(一)
系统架构概述
我们将构建一个两阶段系统:
车辆检测阶段:使用 Deformable DETR 检测图像中的车辆
车牌识别阶段:使用 TrOCR (Transformer-based OCR) 识别裁剪出的车牌区域
环境配置
首先安装所需的依赖库:
bash
# 创建虚拟环境
python -m venv vehicle-transformer
source vehicle-transformer/bin/activate  # Linux/Mac
# 或 vehicle-transformer\Scripts\activate  # Windows

# 安装核心依赖
pip install torch torchvision torchaudio
pip install transformers
pip install opencv-python
pip install Pillow
pip install matplotlib
pip install scipy
pip install pyyaml
pip install timm
pip install easydict
pip install shapely
第一阶段:车辆检测
1. 准备数据集
创建数据集目录结构:
text
dataset/
├── train/
│   ├── images/
│   └── labels.json
├── val/
│   ├── images/
│   └── labels.json
└── test/
    ├── images/
    └── labels.json
使用 COCO 格式的标注文件 labels.json
:
json
{
  "images": [
    {"id": 1, "file_name": "image_001.jpg", "width": 1920, "height": 1080}
  ],
  "annotations": [
    {"id": 1, "image_id": 1, "category_id": 1, "bbox": [x, y, width, height], "area": area, "iscrowd": 0}
  ],
  "categories": [
    {"id": 1, "name": "car"},
    {"id": 2, "name": "license_plate"}
  ]
}
2. 实现车辆检测模型
创建 vehicle_detector.py
:
python
import torch
import torchvision
from transformers import DeformableDetrConfig, DeformableDetrForObjectDetection
from PIL import Image
import cv2
import numpy as np
import os


class VehicleDetector:
    """Deformable-DETR detector for vehicles and license plates.

    Loads a fine-tuned checkpoint when `model_path` exists; otherwise builds
    a fresh (untrained) model with 2 labels (vehicle, license plate).
    """

    def __init__(self, model_path=None):
        if model_path and os.path.exists(model_path):
            # Load a previously fine-tuned checkpoint
            self.model = DeformableDetrForObjectDetection.from_pretrained(model_path)
        else:
            # Fresh model: 300 object queries, 2 target classes
            config = DeformableDetrConfig(
                num_queries=300,
                num_labels=2,  # vehicle and license plate
            )
            self.model = DeformableDetrForObjectDetection(config)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

        # Shorter side -> 800 px, then ImageNet normalization
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(800),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def preprocess_image(self, image):
        """Convert a BGR ndarray or PIL image to a normalized (1, 3, H, W) tensor."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        return self.transform(image).unsqueeze(0).to(self.device)

    def detect(self, image, confidence_threshold=0.7):
        """Detect objects in `image`.

        Returns a list of dicts with "label", "score" and "box" where the box
        is [x, y, w, h] in ORIGINAL-IMAGE PIXELS (so downstream crops work).
        """
        # Remember the original size so normalized boxes can be mapped back
        if isinstance(image, np.ndarray):
            orig_size = (image.shape[1], image.shape[0])  # (width, height)
        else:
            orig_size = image.size  # PIL: (width, height)

        inputs = self.preprocess_image(image)
        with torch.no_grad():
            outputs = self.model(inputs)
        return self.postprocess(outputs, confidence_threshold, orig_size)

    def postprocess(self, outputs, confidence_threshold, target_size=None):
        """Convert raw model outputs into detection dicts.

        `target_size` is (width, height) of the original image; when given,
        boxes are scaled from normalized coordinates to pixels. Omitting it
        keeps the old behavior (normalized [0, 1] boxes) for compatibility.
        """
        logits = outputs.logits[0]
        boxes = outputs.pred_boxes[0]

        # BUGFIX: Deformable DETR trains its classification head with
        # sigmoid/focal loss (no softmax "no-object" class), so per-class
        # probabilities come from sigmoid, not softmax.
        probs = logits.sigmoid()
        scores, labels = probs.max(-1)

        # Drop low-confidence queries
        keep = scores > confidence_threshold
        scores = scores[keep]
        labels = labels[keep]
        boxes = boxes[keep]

        scale_w, scale_h = (1.0, 1.0) if target_size is None else target_size

        # pred_boxes are normalized [cx, cy, w, h]; convert to [x, y, w, h]
        results = []
        for score, label, box in zip(scores, labels, boxes):
            x_center, y_center, width, height = box.tolist()
            x = (x_center - width / 2) * scale_w
            y = (y_center - height / 2) * scale_h
            results.append({
                "label": self.model.config.id2label[label.item()],
                "score": score.item(),
                "box": [x, y, width * scale_w, height * scale_h],
            })
        return results

    def train(self, train_dataset, val_dataset, epochs=50, batch_size=4, lr=1e-4):
        """Fine-tune on a CocoDataset yielding (PIL image, (category_ids, coco_boxes)).

        BUGFIXES vs. the original sketch:
        - Samples are fetched by integer index (map-style torch Datasets do
          not support the `dataset[i:i+bs]` slicing the original assumed).
        - Targets use the keys Deformable DETR expects ("class_labels",
          "boxes") with boxes normalized [cx, cy, w, h] in [0, 1], converted
          from COCO pixel [x, y, w, h].
        - The epoch loss is averaged over optimizer steps, not dataset size.
        Samples are forwarded one at a time because Resize(800) produces
        variable spatial sizes that cannot be stacked without padding.
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        n = len(train_dataset)
        for epoch in range(epochs):
            self.model.train()
            total_loss = 0.0
            num_steps = 0
            for start in range(0, n, batch_size):
                optimizer.zero_grad()
                batch_indices = range(start, min(start + batch_size, n))
                batch_loss = 0.0
                for j in batch_indices:
                    img, (labels, coco_boxes) = train_dataset[j]
                    w, h = img.size
                    # COCO [x, y, w, h] pixels -> normalized [cx, cy, w, h]
                    norm_boxes = [
                        [(bx + bw / 2) / w, (by + bh / 2) / h, bw / w, bh / h]
                        for bx, by, bw, bh in coco_boxes
                    ]
                    target = {
                        "class_labels": torch.tensor(labels, dtype=torch.long, device=self.device),
                        "boxes": torch.tensor(norm_boxes, dtype=torch.float, device=self.device),
                    }
                    pixel_values = self.preprocess_image(img)
                    outputs = self.model(pixel_values=pixel_values, labels=[target])
                    batch_loss = batch_loss + outputs.loss
                batch_loss = batch_loss / len(batch_indices)
                batch_loss.backward()
                optimizer.step()
                total_loss += batch_loss.item()
                num_steps += 1
            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / max(num_steps, 1)}")
            # Run validation after every epoch
            self.validate(val_dataset)

    def validate(self, val_dataset):
        """Validation hook (metric computation intentionally omitted)."""
        self.model.eval()
        pass

    def save_model(self, path):
        """Persist model weights/config in Hugging Face format under `path`."""
        self.model.save_pretrained(path)
3. 训练车辆检测模型
创建 train_detector.py
:
python
from vehicle_detector import VehicleDetector
from dataset_utils import CocoDataset
import argparse


def _build_parser():
    """Assemble the command-line interface for detector training."""
    parser = argparse.ArgumentParser(description="Train vehicle detection model")
    parser.add_argument("--data_path", type=str, required=True, help="Path to dataset")
    parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size for training")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--output_dir", type=str, default="./models/vehicle_detector", help="Output directory for model")
    return parser


def main():
    """Train a vehicle detector on COCO-format data and persist the result."""
    args = _build_parser().parse_args()

    # One dataset per split, both expected under data_path/{train,val}
    datasets = {split: CocoDataset(args.data_path, split) for split in ("train", "val")}

    detector = VehicleDetector()
    detector.train(datasets["train"], datasets["val"], args.epochs, args.batch_size, args.lr)
    detector.save_model(args.output_dir)


if __name__ == "__main__":
    main()
第二阶段:车牌识别
1. 准备车牌数据集
车牌数据集需要包含裁剪的车牌图像和对应的文本标签。
2. 实现车牌识别模型
创建 license_plate_recognizer.py
:
python
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import cv2
import numpy as np
import os


class LicensePlateRecognizer:
    """TrOCR-based text recognizer for cropped license-plate images."""

    def __init__(self, model_path=None):
        if model_path and os.path.exists(model_path):
            # Fine-tuned checkpoint (model + processor saved together)
            self.model = VisionEncoderDecoderModel.from_pretrained(model_path)
            self.processor = TrOCRProcessor.from_pretrained(model_path)
        else:
            # Printed-text TrOCR base model as the starting point
            self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
            self.model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def preprocess_image(self, image):
        """Accept a BGR ndarray or a PIL image; return an RGB PIL image."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        return image

    def recognize(self, image):
        """Return the recognized text for one plate crop."""
        image = self.preprocess_image(image)
        pixel_values = self.processor(image, return_tensors="pt").pixel_values.to(self.device)
        with torch.no_grad():
            generated_ids = self.model.generate(pixel_values)
        # Decode generated token ids back to a plain string
        return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def train(self, train_dataset, val_dataset, epochs=10, batch_size=8, lr=5e-5):
        """Fine-tune on (PIL image, text) pairs.

        BUGFIXES vs. the original sketch:
        - Samples are fetched by integer index (map-style torch Datasets do
          not support the `dataset[i:i+bs]` slicing the original assumed).
        - Images and texts go through the image processor and the tokenizer
          separately; padded label positions are replaced with -100 so the
          cross-entropy loss ignores them.
        - The epoch loss is averaged over batches, not dataset size.
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        n = len(train_dataset)
        for epoch in range(epochs):
            self.model.train()
            total_loss = 0.0
            num_batches = 0
            for start in range(0, n, batch_size):
                samples = [train_dataset[j] for j in range(start, min(start + batch_size, n))]
                images, texts = zip(*samples)

                pixel_values = self.processor(
                    images=list(images), return_tensors="pt"
                ).pixel_values.to(self.device)
                labels = self.processor.tokenizer(
                    list(texts), return_tensors="pt", padding=True
                ).input_ids.to(self.device)
                # -100 marks positions the loss should ignore (pad tokens)
                labels[labels == self.processor.tokenizer.pad_token_id] = -100

                outputs = self.model(pixel_values=pixel_values, labels=labels)
                loss = outputs.loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1
            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / max(num_batches, 1)}")
            # Run validation after every epoch
            self.validate(val_dataset)

    def validate(self, val_dataset):
        """Validation hook (metric computation intentionally omitted)."""
        self.model.eval()
        pass

    def save_model(self, path):
        """Persist both the model and the processor under `path`."""
        self.model.save_pretrained(path)
        self.processor.save_pretrained(path)
3. 训练车牌识别模型
创建 train_recognizer.py
:
python
from license_plate_recognizer import LicensePlateRecognizer
from dataset_utils import LicensePlateDataset
import argparse


def _build_parser():
    """Assemble the command-line interface for recognizer training."""
    parser = argparse.ArgumentParser(description="Train license plate recognition model")
    parser.add_argument("--data_path", type=str, required=True, help="Path to dataset")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for training")
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--output_dir", type=str, default="./models/lp_recognizer", help="Output directory for model")
    return parser


def main():
    """Train a plate-text recognizer on cropped plates and persist the result."""
    args = _build_parser().parse_args()

    # One dataset per split, both expected under data_path/{train,val}
    datasets = {split: LicensePlateDataset(args.data_path, split) for split in ("train", "val")}

    recognizer = LicensePlateRecognizer()
    recognizer.train(datasets["train"], datasets["val"], args.epochs, args.batch_size, args.lr)
    recognizer.save_model(args.output_dir)


if __name__ == "__main__":
    main()
数据集工具
创建 dataset_utils.py
:
python
import json
from PIL import Image
import os
import torch
from torch.utils.data import Dataset


class CocoDataset(Dataset):
    """COCO-format detection dataset.

    Each item is (PIL RGB image, (category_ids, bboxes)) where bboxes keep
    the COCO [x, y, width, height] layout from labels.json.
    """

    def __init__(self, data_path, split="train"):
        self.data_path = data_path
        self.split = split

        # Annotations live at data_path/<split>/labels.json
        labels_file = os.path.join(data_path, split, "labels.json")
        with open(labels_file, "r") as f:
            self.annotations = json.load(f)

        # image id -> image record
        self.image_id_to_info = {}
        for img in self.annotations["images"]:
            self.image_id_to_info[img["id"]] = img

        # image id -> all annotation records for that image
        self.image_id_to_anns = {}
        for ann in self.annotations["annotations"]:
            self.image_id_to_anns.setdefault(ann["image_id"], []).append(ann)

    def __len__(self):
        return len(self.annotations["images"])

    def __getitem__(self, idx):
        info = self.annotations["images"][idx]
        path = os.path.join(self.data_path, self.split, "images", info["file_name"])
        image = Image.open(path).convert("RGB")

        # Images without annotations yield empty label/box lists
        anns = self.image_id_to_anns.get(info["id"], [])
        labels = [a["category_id"] for a in anns]
        boxes = [a["bbox"] for a in anns]
        return image, (labels, boxes)


class LicensePlateDataset(Dataset):
    """Cropped-plate OCR dataset: each item is (PIL RGB image, text label)."""

    def __init__(self, data_path, split="train"):
        self.data_path = data_path
        self.split = split
        # labels.json here is a flat list of {"file_name", "text"} entries
        with open(os.path.join(data_path, split, "labels.json"), "r") as f:
            self.annotations = json.load(f)

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        entry = self.annotations[idx]
        path = os.path.join(self.data_path, self.split, "images", entry["file_name"])
        return Image.open(path).convert("RGB"), entry["text"]
端到端推理系统
创建 inference.py
:
python
from vehicle_detector import VehicleDetector
from license_plate_recognizer import LicensePlateRecognizer
import cv2
import numpy as np
from PIL import Image


class VehicleLicenseSystem:
    """End-to-end pipeline: detect vehicles/plates, then OCR each plate crop."""

    def __init__(self, detector_path, recognizer_path):
        self.detector = VehicleDetector(detector_path)
        self.recognizer = LicensePlateRecognizer(recognizer_path)

    def process_image(self, image_path, output_path=None):
        """Detect and recognize on one image file.

        Draws results on the image; writes it to `output_path` when given,
        otherwise shows an interactive window. Returns a list of result
        dicts ({"type", "bbox", "confidence"} plus "text" for plates).

        Raises ValueError when the image cannot be read.

        NOTE(review): assumes detector boxes are pixel [x, y, w, h] in the
        original image — confirm against VehicleDetector.postprocess.
        """
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"无法读取图像: {image_path}")

        img_h, img_w = image.shape[:2]
        detections = self.detector.detect(image)
        results = []
        for detection in detections:
            label = detection["label"]
            score = detection["score"]
            x, y, w, h = detection["box"]
            x, y, w, h = int(x), int(y), int(w), int(h)

            # BUGFIX: clamp to image bounds — raw boxes can extend past the
            # frame, and negative indices would silently crop a wrong region.
            x = max(0, min(x, img_w - 1))
            y = max(0, min(y, img_h - 1))
            w = max(0, min(w, img_w - x))
            h = max(0, min(h, img_h - y))

            roi = image[y:y + h, x:x + w]

            if label == "license_plate":
                if roi.size == 0:
                    # Degenerate crop — nothing to recognize
                    continue
                plate_text = self.recognizer.recognize(roi)
                results.append({
                    "type": "license_plate",
                    "text": plate_text,
                    "bbox": [x, y, w, h],
                    "confidence": score,
                })
                # Green box + recognized text for plates
                cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 0), 2)
                cv2.putText(image, plate_text, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
            else:
                # Vehicles get a blue box only
                results.append({
                    "type": "vehicle",
                    "bbox": [x, y, w, h],
                    "confidence": score,
                })
                cv2.rectangle(image, (x, y), (x + w, y + h), (255, 0, 0), 2)

        if output_path:
            cv2.imwrite(output_path, image)
        else:
            cv2.imshow("Result", image)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        return results


def main():
    """CLI entry point: run the pipeline on one image and print the results."""
    import argparse
    parser = argparse.ArgumentParser(description="Vehicle and License Plate Detection and Recognition")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument("--output", type=str, help="Path to output image")
    parser.add_argument("--detector", type=str, default="./models/vehicle_detector", help="Path to detector model")
    parser.add_argument("--recognizer", type=str, default="./models/lp_recognizer", help="Path to recognizer model")
    args = parser.parse_args()

    system = VehicleLicenseSystem(args.detector, args.recognizer)
    results = system.process_image(args.image, args.output)

    for result in results:
        if result["type"] == "license_plate":
            print(f"车牌: {result['text']} (置信度: {result['confidence']:.2f})")
        else:
            print(f"车辆: (置信度: {result['confidence']:.2f})")


if __name__ == "__main__":
    main()
训练和执行流程
1. 准备数据集
车辆检测数据集:使用COCO格式,包含车辆和车牌标注
车牌识别数据集:包含裁剪的车牌图像和对应的文本标签
2. 训练车辆检测模型
bash
python train_detector.py --data_path /path/to/dataset --epochs 50 --output_dir ./models/vehicle_detector
3. 训练车牌识别模型
bash
python train_recognizer.py --data_path /path/to/license_plate_dataset --epochs 10 --output_dir ./models/lp_recognizer
4. 运行推理
bash
python inference.py --image /path/to/test_image.jpg --output /path/to/output.jpg
优化和改进建议
数据增强:在训练过程中应用随机裁剪、旋转、颜色抖动等增强技术
模型优化:
使用更大的预训练模型提高精度
使用知识蒸馏技术压缩模型大小
应用量化技术加速推理
后处理优化:
添加非极大值抑制(NMS)处理重叠检测
添加车牌格式验证和校正
部署优化:
使用ONNX或TensorRT加速推理
实现批处理提高吞吐量
添加API接口方便集成
注意事项
此实现需要大量标注数据进行训练,如果数据不足,可以考虑使用预训练模型进行微调
车牌识别对图像质量要求较高,建议在实际应用中添加图像预处理步骤(如对比度增强、去模糊等)
不同国家和地区的车牌格式不同,可能需要针对特定格式调整模型和后续处理
这个方案提供了一个完整的基于Transformer的车辆检测和车牌识别系统。实际应用中可能需要根据具体需求和数据特点进行调整和优化。