Python 实现将图像发送到指定的 API 进行推理
Python 实现将图像发送到指定的 API 进行推理
flyfish
整体功能概述
该代码实现了一个图像处理器,可将指定文件夹中的图像文件依次发送到一个或多个 API 进行推理,接收 API 的响应并将推理结果保存到文件中。同时,代码具备错误处理、状态记录和断点续传功能,即使在程序中断(如服务器故障、用户手动中断)后,也能在重启时接着之前的进度继续处理图像。
具体功能模块
- 初始化(
__init__
方法)- 读取提示文本文件,获取推理所需的提示信息。
- 检查并确保图像文件夹和结果保存文件夹存在,若不存在则创建。
- 填充任务队列,将图像文件路径和对应的结果保存路径添加到队列中。
- 加载之前保存的程序状态,包括已处理图像列表、总推理时间和推理次数。
- 读取提示文本(
_read_prompt
方法):从指定的文本文件中读取提示信息,若文件不存在则输出错误信息并退出程序。 - 确保文件夹存在(
_ensure_folder_exists
方法):检查指定的文件夹是否存在,若不存在则创建该文件夹。 - 图像编码(
_encode_image_to_base64
方法):将图像文件编码为 Base64 字符串,以便在 HTTP 请求中传输。 - 发送请求(
_send_request
方法):向指定的 API 发送 POST 请求,包含提示信息和 Base64 编码的图像数据。若请求失败,会进行重试,每次重试间隔 5 秒。 - 处理响应(
_process_response
方法):解析 API 的响应,将推理结果和推理时间保存到文件中。若响应不是有效的 JSON 数据或处理过程中出现其他错误,会记录错误信息到日志文件中。 - 填充任务队列(
_populate_task_queue
方法):遍历图像文件夹,将所有符合条件(.jpg
、.jpeg
、.png
)的图像文件路径和对应的结果保存路径添加到任务队列中。 - 加载状态(
_load_status
方法):从状态文件中加载之前保存的程序状态,包括已处理图像列表、总推理时间和推理次数。若状态文件不存在,则返回初始值。 - 保存状态(
_save_status
方法):将当前的程序状态(已处理图像列表、总推理时间和推理次数)保存到状态文件中,以便程序中断后能接着之前的进度继续处理。 - 记录错误信息(
_log_error
方法):将错误信息和对应的图像文件名记录到错误日志文件中,方便后续排查问题。 - 工作线程(
worker
方法):从任务队列中取出任务,对图像进行编码、发送请求、处理响应,并更新程序状态。若图像已处理过,则跳过该任务。 - 处理图像(
process_images
方法):为每个 API 地址启动一个工作线程,并行处理图像。若用户手动中断程序(按下Ctrl+C
),会保存当前的程序状态后退出。最后,输出总推理时间和平均推理时间。
主程序
在 if __name__ == "__main__"
块中,创建了 ImageProcessor
类的实例,并调用 process_images
方法开始处理图像。用户可以根据需要修改提示文本文件路径、图像文件夹路径、结果保存文件夹路径和 API 地址列表。
import requests
import base64
import os
import time
import threading
import queue
import jsonclass ImageProcessor:def __init__(self, prompt_file_path, images_folder, result_folder, api_urls):"""初始化 ImageProcessor 类。:param prompt_file_path: 提示文本文件的路径:param images_folder: 图像文件夹的路径:param result_folder: 结果保存文件夹的路径:param api_urls: 发送请求的 API 地址列表"""self.prompt = self._read_prompt(prompt_file_path)self.images_folder = images_folderself.result_folder = result_folderself.api_urls = api_urlsself._ensure_folder_exists(self.images_folder)self._ensure_folder_exists(self.result_folder)self.task_queue = queue.Queue()self._populate_task_queue()self.status_file = os.path.join(self.result_folder, "status.json")self.error_log_file = os.path.join(self.result_folder, "error_log.txt")self.processed_images, self.total_inference_time, self.inference_count = self._load_status()def _read_prompt(self, prompt_file_path):"""读取提示文本文件。:param prompt_file_path: 提示文本文件的路径:return: 提示文本"""try:with open(prompt_file_path, "r", encoding="utf-8") as f:return f.read().strip()except FileNotFoundError:print("Error: prompt file not found.")exit(1)def _ensure_folder_exists(self, folder):"""确保指定的文件夹存在,如果不存在则创建。:param folder: 文件夹路径"""if not os.path.exists(folder):os.makedirs(folder)def _encode_image_to_base64(self, image_path):"""将图像文件编码为 Base64 字符串。:param image_path: 图像文件的路径:return: Base64 编码的图像字符串"""with open(image_path, 'rb') as file:image_data = file.read()return base64.b64encode(image_data).decode('utf-8')def _send_request(self, base64_image, api_url):"""发送 POST 请求到 API 进行预测。:param base64_image: Base64 编码的图像字符串:param api_url: 发送请求的 API 地址:return: 响应对象"""data = {"prompt": self.prompt,"image": base64_image}while True:try:response = requests.post(api_url, json=data)response.raise_for_status()return responseexcept requests.RequestException as e:print(f"请求出错: {e},等待 60 秒后重试...")time.sleep(60)def _process_response(self, response, result_path, inference_time, image_name):"""处理 API 响应并保存结果到文件。:param response: 响应对象:param result_path: 结果保存文件的路径:param inference_time: 推理时间:param image_name: 图像文件名:return: 是否成功处理"""try:result = response.json()output_text = result.get("output", "")with open(result_path, "w", encoding="utf-8") as f:f.write(f"推理结果: {output_text}\n")f.write(f"本次推理时间: {inference_time:.4f} 秒")current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())print(f"{current_time} 处理完成,本次推理时间: {inference_time:.4f} 秒")return Trueexcept ValueError:print(f"响应不是有效的 JSON 数据: {response.text}")self._log_error(image_name, f"响应不是有效的 JSON 数据: {response.text}")except Exception as e:print(f"处理时发生错误: {e}")self._log_error(image_name, f"处理时发生错误: {e}")return Falsedef _populate_task_queue(self):"""填充任务队列,将图像文件路径和结果保存路径添加到队列中。"""filenames = sorted(os.listdir(self.images_folder))for filename in filenames:if filename.endswith(('.jpg', '.jpeg', '.png')):image_path = os.path.join(self.images_folder, filename)result_filename = os.path.splitext(filename)[0] + ".txt"result_path = os.path.join(self.result_folder, result_filename)self.task_queue.put((image_path, result_path))def _load_status(self):"""从状态文件中加载已处理图像、总推理时间和推理次数。:return: 已处理图像集合、总推理时间、推理次数"""if os.path.exists(self.status_file):with open(self.status_file, "r", encoding="utf-8") as f:status = json.load(f)return set(status.get("processed_images", [])), status.get("total_inference_time", 0), status.get("inference_count", 0)return set(), 0, 0def _save_status(self):"""将已处理图像、总推理时间和推理次数保存到状态文件。"""status = {"processed_images": list(self.processed_images),"total_inference_time": self.total_inference_time,"inference_count": self.inference_count}with open(self.status_file, "w", encoding="utf-8") as f:json.dump(status, f)def _log_error(self, image_name, error_message):"""记录错误信息到错误日志文件。:param image_name: 图像文件名:param error_message: 错误信息"""with open(self.error_log_file, "a", encoding="utf-8") as f:current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())f.write(f"{current_time} - 图像 {image_name} 发生错误: {error_message}\n")def worker(self, api_url):"""工作线程函数,从任务队列中取出任务并处理。:param api_url: 发送请求的 API 地址"""while not self.task_queue.empty():image_path, result_path = self.task_queue.get()image_name = os.path.splitext(os.path.basename(image_path))[0]if image_name in self.processed_images:self.task_queue.task_done()continuestart_time = time.time()base64_image = self._encode_image_to_base64(image_path)response = self._send_request(base64_image, api_url)end_time = time.time()inference_time = end_time - start_timecurrent_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())print(f"{current_time} 正在处理 {os.path.basename(image_path)}...")success = self._process_response(response, result_path, inference_time, image_name)if success:self.total_inference_time += inference_timeself.inference_count += 1self.processed_images.add(image_name)self._save_status()self.task_queue.task_done()def process_images(self):"""启动多个工作线程处理图像,每个线程对应一个 API 地址。"""threads = []try:for api_url in self.api_urls:thread = threading.Thread(target=self.worker, args=(api_url,))thread.start()threads.append(thread)for thread in threads:thread.join()except KeyboardInterrupt:print("接收到中断信号,正在保存当前状态...")self._save_status()print("当前状态已保存,程序退出。")if self.inference_count > 0:average_inference_time = self.total_inference_time / self.inference_countprint(f"总推理时间: {self.total_inference_time:.4f} 秒")print(f"平均推理时间: {average_inference_time:.4f} 秒")if __name__ == "__main__":prompt_file = "prompt.md"images_folder = "images"result_folder = "result"api_urls = ['http://192.168.1.2:3714/predict', 'http://192.168.1.2:3713/predict']processor = ImageProcessor(prompt_file, images_folder, result_folder, api_urls)processor.process_images()