Implementing Multimodal Semantic Retrieval with GME-Qwen2-VL-7B
1. The GME-Qwen2-VL-7B Model
The GME-Qwen2-VL series is a family of unified multimodal embedding models trained on top of Qwen2-VL, with support for dynamic resolution. The models accept three types of input: text, images, and image-text pairs. All input types are mapped into the same general-purpose vector representation with strong retrieval performance, so knowledge vectorization is no longer limited to text. On top of this model you can build a wide range of scenarios such as text-to-text, text-to-image, image-to-text, and image-to-image search.
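To make the idea concrete, here is a minimal sketch of cross-modal scoring, assuming the GmeQwen2VL wrapper (gme_inference.py) that ships with the model and is used in the deployment section below; the query text and image paths are placeholders.

import torch
from PIL import Image
from gme_inference import GmeQwen2VL

gme = GmeQwen2VL(model_path="gme-Qwen2-VL-7B-Instruct", device="cuda:0")

# One text query against a few candidate images (placeholder paths).
query_vec = gme.get_text_embeddings(texts=["a cat lying in the sun"])
image_vecs = gme.get_image_embeddings(images=[Image.open(p) for p in ["img/1.png", "img/2.png"]])

# Text and image embeddings share one vector space, so cosine similarity ranks
# the images for the text query (assuming tensor-like outputs, as the .tolist()
# calls in the service code below suggest).
scores = torch.nn.functional.cosine_similarity(
    torch.as_tensor(query_vec), torch.as_tensor(image_vecs)
)
print(scores)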
GME-Qwen2-VL-7B on ModelScope:
https://modelscope.cn/models/iic/gme-Qwen2-VL-7B-Instruct
This article deploys the GME-Qwen2-VL-7B model locally and implements a text-to-image search example; the result looks as follows:
2. Deploying GME-Qwen2-VL-7B
Download the model:
modelscope download --model="iic/gme-Qwen2-VL-7B-Instruct" --local_dir gme-Qwen2-VL-7B-Instruct
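If you prefer the Python SDK, a roughly equivalent download looks like the sketch below (assuming a recent modelscope version whose snapshot_download accepts local_dir):

# Sketch: download the model with the ModelScope Python SDK instead of the CLI.
from modelscope import snapshot_download

snapshot_download("iic/gme-Qwen2-VL-7B-Instruct", local_dir="gme-Qwen2-VL-7B-Instruct")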
Then place gme_inference.py (included in the downloaded model directory) together with the service code below, which loads the model and starts the API service:
import time
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import uvicorn, json
from gme_inference import GmeQwen2VL
import torch
import base64
from io import BytesIO
from PIL import Image

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the GME model once at startup.
model_path = "gme-Qwen2-VL-7B-Instruct"
gme = GmeQwen2VL(model_path=model_path, device="cuda:0")


def base64_to_image(image_base64: str):
    # Decode a base64 string into a PIL image.
    image_data = base64.b64decode(image_base64)
    image_file = BytesIO(image_data)
    image = Image.open(image_file)
    return image


@app.post("/v1/embeddings")
async def embeddings(request: Request):
    messages = await request.json()
    texts = messages.get('texts')
    images = messages.get('images')
    if not texts and not images:
        return {
            "code": 400,
            "message": "Pass at least one of texts or images!",
        }
    t = time.time()
    if images:
        images = [base64_to_image(b) for b in images if b]
    # Choose the embedding type according to the inputs:
    # text + image -> fused embeddings, text only -> text embeddings,
    # image only -> image embeddings.
    if texts and images:
        embeds = gme.get_fused_embeddings(texts=texts, images=images).tolist()
    elif texts:
        embeds = gme.get_text_embeddings(texts=texts).tolist()
    else:
        embeds = gme.get_image_embeddings(images=images).tolist()
    use_time = time.time() - t
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    return {
        "code": 200,
        "message": "success",
        "data": embeds,
        "use_time": use_time
    }


if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=8848, workers=1)
After startup, the service occupies roughly 17.5 GB of GPU memory.
3. API Call Example
import base64
import requests


def encode_image(image_path):
    # Read an image file and encode it as a base64 string.
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def embeds(texts: list = None, images: list = None):
    # Call the embedding service; image paths are base64-encoded first.
    if not texts and not images:
        raise Exception("embed content is empty!")
    if images:
        images = [encode_image(p) for p in images if p]
    response = requests.post(
        "http://127.0.0.1:8848/v1/embeddings",
        json={"texts": texts, "images": images}
    )
    return response.json()["data"]


def main():
    texts = ["你好呀"]
    images = ["img/1.png"]
    print(embeds(texts=texts, images=images))


if __name__ == '__main__':
    main()
Call result: each returned embedding is a 3584-dimensional vector.
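Because the service returns plain float lists, the 3584-dimensional vectors can be compared directly on the client side. Below is a small sketch that reuses the embeds helper above; the cosine helper and the query text are illustrative, and numpy is assumed to be installed.

import numpy as np


def cosine(a, b):
    # Cosine similarity between two embedding vectors.
    a, b = np.asarray(a), np.asarray(b)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


text_vec = embeds(texts=["a cat"])[0]        # text embedding from the service
image_vec = embeds(images=["img/1.png"])[0]  # image embedding from the service
print(len(text_vec))                         # 3584
print(cosine(text_vec, image_vec))           # higher means a closer text-image match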
4. A Text-to-Image Search Example
Here I prepared some images of cats and dogs. They are vectorized with the GME-Qwen2-VL-7B model and persisted into a Milvus vector store:
import json
import os
import base64
import requests
from pymilvus import MilvusClient, DataType

client = MilvusClient("http://127.0.0.1:19530")
collection_name = "gme_vl_test"


def create_collection():
    client.drop_collection(collection_name=collection_name)
    schema = MilvusClient.create_schema(
        auto_id=False,
        enable_dynamic_field=False,
    )
    schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=255)
    schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=3584)
    schema.add_field(field_name="content", datatype=DataType.VARCHAR, max_length=5000)
    schema.verify()
    index_params = client.prepare_index_params()
    index_params.add_index(
        field_name="vector",
        index_type="IVF_FLAT",
        metric_type="L2",
        params={"nlist": 1024}
    )
    # Create the collection
    client.create_collection(
        collection_name=collection_name,
        schema=schema,
        index_params=index_params
    )


def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def embeds(texts: list = None, images: list = None):
    if not texts and not images:
        raise Exception("embed content is empty!")
    if images:
        images = [encode_image(p) for p in images if p]
    response = requests.post(
        "http://127.0.0.1:8848/v1/embeddings",
        json={"texts": texts, "images": images}
    )
    return response.json()["data"]


def to_milvus():
    # Embed every image under img/ and upsert it into Milvus.
    for index, img in enumerate(os.listdir("img")):
        img_path = os.path.join("img", img)
        embed = embeds(images=[img_path])
        content = {"type": "img", "content": img}
        client.upsert(
            collection_name=collection_name,
            data={
                "id": str(index),
                "vector": embed[0],
                "content": json.dumps(content, ensure_ascii=False)
            }
        )
        print("save ----> ", img_path)


def main():
    ## Create the collection
    create_collection()
    ## Persist the vectors
    to_milvus()


if __name__ == '__main__':
    main()
Retrieve images with a text query:
import base64
import requests
from pymilvus import MilvusClient
import matplotlib.pyplot as plt
from PIL import Image
import json

# Use a font that can render Chinese characters in the figure title.
plt.rcParams['font.sans-serif'] = ['SimHei']

client = MilvusClient("http://127.0.0.1:19530")
collection_name = "gme_vl_test"


def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def embeds(texts: list = None, images: list = None):
    if not texts and not images:
        raise Exception("embed content is empty!")
    if images:
        images = [encode_image(p) for p in images if p]
    response = requests.post(
        "http://127.0.0.1:8848/v1/embeddings",
        json={"texts": texts, "images": images}
    )
    return response.json()["data"]


def main():
    while True:
        question = input("Enter a query (q to quit): ")
        if not question:
            continue
        if question == "q":
            break
        # Embed the text query and retrieve the two closest image vectors.
        vec = embeds(texts=[question])
        res = client.search(collection_name, data=vec, limit=2, output_fields=["content"])
        plt.figure()
        plt.suptitle(f"Query: {question}", fontsize=20, fontweight='bold')
        for index, item in enumerate(res[0]):
            # The content field was stored as a JSON string; parse it to get the file name.
            img_name = json.loads(item["entity"]["content"])["content"]
            plt.subplot(1, 2, index + 1)
            image = Image.open(f"img/{img_name}")
            plt.imshow(image)
            plt.axis('off')
        plt.show()


if __name__ == '__main__':
    main()
After running, enter a query in the console and the two closest images are displayed.
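The same collection also covers the image-to-image scenario mentioned at the beginning: embed a query image through the service and search with that vector instead of a text vector. A minimal sketch, appended to the retrieval script above and reusing its client and embeds helper; img/query.png is a placeholder path:

import json

# Image-to-image retrieval: embed the query image, then search the same collection.
vec = embeds(images=["img/query.png"])
res = client.search(collection_name, data=vec, limit=2, output_fields=["content"])
for item in res[0]:
    meta = json.loads(item["entity"]["content"])  # stored as a JSON string during ingestion
    print(item["distance"], meta["content"])      # L2 distance and matching image file name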