PyTorch 与 Spring AI 集成实战
目录
- 一、前言
- 二、发布PyTorch 模型
- REST API 部署方式
- ONNX 转换方式
- 三、构建 PyTorch 服务端
- 模型保存
- FastAPI 服务
- 四、Spring AI 调用 PyTorch 模型
- 使用 RestTemplate 访问
- 在 Tool Calling 中集成
- 五、实战演练:智能客服识别情绪
- 用户提问
- PromptTemplate 示例:
- Spring AI 调用 PyTorch 接口进行二次验证:
- 六、总结
- 七、参考
一、前言
大多数深度学习模型仍由 Python 和 PyTorch 驱动,但越来越多的企业希望将这些模型嵌入到 Java 微服务中运行。
Spring AI 提供了灵活的方式,结合 RESTful 接口、容器部署、Tool Calling 和 Agent 架构,使 Java 与 PyTorch 模型之间的协作不再是梦。
本篇将带你完成:
- PyTorch 模型部署为服务(REST API)
- Spring AI 调用 PyTorch 模型进行问答、分类或推理
- 实战示例:中文情感分析模型接入
二、发布PyTorch 模型
Java 无法直接运行 PyTorch 模型,但可以通过以下两种方式调用:
REST API 部署方式
也是本篇推荐使用 FastAPI 或 Flask 将模型包装为 HTTP 接口,第三节将重点介绍。
ONNX 转换方式
ONNX转换适用于通用模型,将模型转换为 ONNX 格式,用 JNI/ONNX Runtime 调用。详见《标准化模型格式ONNX介绍:打通AI模型从训练到部署的环节》
本篇我们将采用第一种REST API 部署方式:用 Python + FastAPI 部署 PyTorch 模型,由 Java 远程调用。
三、构建 PyTorch 服务端
模型保存
# train.py
import torch
model = MyModel()
... # 训练代码
torch.save(model.state_dict(), "sentiment_model.pt")
FastAPI 服务
# app.py
from fastapi import FastAPI, Request
import torch
import torch.nn.functional as Fapp = FastAPI()
model = MyModel()
model.load_state_dict(torch.load("sentiment_model.pt"))
model.eval()@app.post("/predict")
async def predict(request: Request):data = await request.json()text = data["text"]# TODO: text preprocessing & tokenizingwith torch.no_grad():output = model(text)pred = F.softmax(output, dim=1).tolist()return {"result": pred}
四、Spring AI 调用 PyTorch 模型
使用 RestTemplate 访问
@RestController
public class InferenceController {@Autowired RestTemplate restTemplate;@PostMapping("/ai/sentiment")public String classify(@RequestBody String text) {HttpHeaders headers = new HttpHeaders();headers.setContentType(MediaType.APPLICATION_JSON);Map<String, String> body = Map.of("text", text);HttpEntity<Map<String, String>> req = new HttpEntity<>(body, headers);String url = "http://localhost:8000/predict";ResponseEntity<String> resp = restTemplate.postForEntity(url, req, String.class);return resp.getBody();}
}
在 Tool Calling 中集成
@AiFunction(name = "sentiment")
public String analyzeSentiment(@AiParam("text") String text) {return classify(text);
}
注册为 Spring AI 工具:
List<ToolSpecification> tools = FunctionCallingTools.fromBeans(appContext);
chatClient = new FunctionCallingChatClient(chatClient, tools);
现在,模型就可以被 LLM 调用啦!
五、实战演练:智能客服识别情绪
用户提问
“你们的服务真的太烂了,我再也不会买了!”
PromptTemplate 示例:
String prompt = "请判断以下内容的用户情绪类别(积极、消极、中性):{{text}}";
LLM 返回:消极
Spring AI 调用 PyTorch 接口进行二次验证:
String pyResult = analyzeSentiment(userText);
可用于模型投票融合、异常拦截等场景。
六、总结
通过本文,我们完成了从 PyTorch 模型训练、FastAPI 部署,到 Spring AI 调用推理的完整闭环。
Spring AI 可以将自研模型作为 Tool,嵌入智能 Agent 流程中,与大语言模型协同。
七、参考
《Java驱动AI革命:Spring AI八篇进阶指南——从架构基础到企业级智能系统实战》