图像分类标注小工具
不说废话
上代码
import os
import cv2
import shutil
import csv


class ImageLabeler:
    """Interactive, keyboard-driven tool that sorts images into class folders.

    Every .jpg/.jpeg/.png file found recursively under ``input_dir`` is shown
    in an OpenCV window.

    Key handling in ``label_images``:
        1..9   copy the image to ``output_dir/<class_name>/<relative path>``
               and append a row to the CSV log (key '1' -> class_names[0];
               only the first nine classes are reachable this way)
        s      skip the current image without recording it
        space  undo the most recent label made in this session
        ESC    quit

    Images already listed in the CSV log are filtered out at construction
    time, so an interrupted session can be resumed by running again. The undo
    history lives only in memory: labels from earlier runs cannot be undone
    with the space key.
    """

    def __init__(self, input_dir, output_dir, class_names,
                 csv_path='label_log.csv', preview_size=(800, 800)):
        """
        Args:
            input_dir: root folder scanned recursively for images.
            output_dir: root folder that receives one sub-folder per class.
            class_names: class folder names; digit key N selects class_names[N-1].
            csv_path: CSV log with header ``image_name,label``.
            preview_size: ``(width, height)`` passed to ``cv2.resize`` for the
                preview window (aspect ratio is not preserved).
        """
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.class_names = class_names
        self.csv_path = csv_path
        self.preview_size = preview_size
        self.image_files = self._get_image_files()
        # Relative paths already recorded in the CSV log; these are not shown again.
        self.labeled_images = self._read_labeled_images()
        self.image_files = [f for f in self.image_files if f not in self.labeled_images]
        self.index = 0
        # Undo stack of labels applied during this session (popped by the space key).
        self.history = []
        self._create_class_folders()
        self._init_csv()

    def _get_image_files(self):
        """Return the sorted paths, relative to input_dir, of all .jpg/.jpeg/.png files."""
        image_paths = []
        for root, _, files in os.walk(self.input_dir):
            for f in files:
                if f.lower().endswith(('.jpg', '.jpeg', '.png')):
                    full_path = os.path.join(root, f)
                    image_paths.append(os.path.relpath(full_path, self.input_dir))
        return sorted(image_paths)

    def _read_labeled_images(self):
        """Return the set of image names already recorded in the CSV log.

        A missing or empty log yields an empty set. The first row is treated
        as the header, and blank rows are ignored.
        """
        if not os.path.exists(self.csv_path):
            return set()
        with open(self.csv_path, mode='r', newline='') as file:
            reader = csv.reader(file)
            # Default of None: an existing-but-empty log must not raise StopIteration.
            next(reader, None)
            return {row[0] for row in reader if row}

    def _create_class_folders(self):
        """Create output_dir/<class_name> for every class (no-op if present)."""
        for class_name in self.class_names:
            os.makedirs(os.path.join(self.output_dir, class_name), exist_ok=True)

    def _init_csv(self):
        """Write the CSV header if the log is missing or empty."""
        # An empty file needs the header too: _read_labeled_images always drops
        # the first row, so a headerless log would silently lose its first label.
        if not os.path.exists(self.csv_path) or os.path.getsize(self.csv_path) == 0:
            with open(self.csv_path, mode='w', newline='') as file:
                csv.writer(file).writerow(['image_name', 'label'])

    def _resize_preview(self, img):
        """Resize img to exactly self.preview_size for display."""
        return cv2.resize(img, self.preview_size, interpolation=cv2.INTER_AREA)

    def _write_csv(self, image_name, label):
        """Append one ``image_name,label`` row to the log."""
        with open(self.csv_path, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([image_name, label])

    def _remove_last_csv_entry(self):
        """Drop the last line of the log; the header line is never removed."""
        with open(self.csv_path, mode='r', newline='') as file:
            lines = file.readlines()
        if len(lines) <= 1:
            print("CSV 中没有可删除的记录。")
            return
        with open(self.csv_path, mode='w', newline='') as file:
            file.writelines(lines[:-1])

    def _undo_last(self):
        """Revert the latest label: delete its copy, drop its CSV row, rewind the index."""
        if not self.history:
            print("无历史记录可撤销。")
            return
        last = self.history.pop()
        self.index = last['index']
        if os.path.exists(last['copied_path']):
            os.remove(last['copied_path'])
        self._remove_last_csv_entry()
        print(f"撤销: {last['image_name']} → {last['label']}")

    def _apply_label(self, img_path, img_name, class_name, total):
        """Copy the image into its class folder, log it, push undo info, and advance."""
        dst_path = os.path.join(self.output_dir, class_name, img_name)
        # img_name may contain sub-directories, so create the whole parent path.
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        shutil.copy(img_path, dst_path)
        self._write_csv(img_name, class_name)
        self.history.append({
            'index': self.index,
            'image_name': img_name,
            'label': class_name,
            'copied_path': dst_path,
        })
        print(f"{img_name} → {class_name} ({self.index + 1}/{total})")
        self.index += 1

    def label_images(self):
        """Show each pending image and act on the pressed key until done or ESC."""
        print("开始图像标注:")
        print("按数字键 1、2、3... 进行分类:【1:埃及;2:希腊;3:罗马和方形;4:其他】")
        print("按 空格键 回退,按 ESC 退出")
        total = len(self.image_files)
        try:
            while self.index < total:
                img_name = self.image_files[self.index]
                img_path = os.path.join(self.input_dir, img_name)
                img = cv2.imread(img_path)
                if img is None:
                    print(f"无法读取图像:{img_path}")
                    self.index += 1
                    continue

                progress_title = f"[{self.index + 1}/{total}] Label: {img_name}"
                cv2.imshow(progress_title, self._resize_preview(img))
                cv2.moveWindow(progress_title, 100, 100)
                # Keep only the low byte: some OpenCV versions/backends return
                # extra bits in the key code, which would defeat the comparisons.
                key = cv2.waitKey(0) & 0xFF
                cv2.destroyWindow(progress_title)

                if key == 27:  # ESC
                    print("退出标注工具。")
                    break
                if key == ord('s'):
                    print(f"跳过: {img_name}")
                    self.index += 1
                elif key == 32:  # space
                    self._undo_last()
                else:
                    class_index = key - ord('1')
                    if 0 <= class_index < len(self.class_names):
                        self._apply_label(img_path, img_name,
                                          self.class_names[class_index], total)
                    else:
                        # self.index is unchanged, so the same image is shown again.
                        print("无效按键,请重新选择。")
        finally:
            # Close any OpenCV window even if copying or logging raised.
            cv2.destroyAllWindows()


if __name__ == '__main__':
    input_folder = 'batch_0002'
    output_folder = 'labeled_images'
    categories = ['1_Egyptian', '2_Greek', '3_Roman', '4_Other']  # customizable class names
    labeler = ImageLabeler(
        input_dir=input_folder,
        output_dir=output_folder,
        class_names=categories,
        csv_path='label_log.csv',
        preview_size=(800, 800),  # (640, 480) is another option
    )
    labeler.label_images()