当前位置: 首页 > news >正文

GNN:用MPNN(消息传递神经网络)落地最短路径问题模型训练全流程

用MPNN落地最短路径问题:从MySQL数据存储到模型训练全流程

消息传递神经网络(MPNN) 作为处理图结构数据的利器,能通过学习节点间的关联特征,直接建模“路径”这一抽象概念,尤其适合动态或未知拓扑的图场景。今天我们就从0到1实现一套基于MPNN的最短路径方案,重点包含MySQL数据库设计、数据加载、模型构建与训练,让技术落地更贴近工程实践。

一、方案整体架构:先搭好“骨架”

在动手前,我们先明确整个方案的核心模块,确保各部分衔接顺畅。整体分为4层,流程如下:
MySQL数据库层(存储图数据)数据加载层(预处理图数据)MPNN模型层(学习路径特征)训练/预测层(落地最短路径求解)

各层的核心职责:

  • 数据库层:存储节点属性、边权重、已标注的最短路径(用于训练);
  • 数据加载层:从MySQL读取数据,转换为模型可接受的格式(如邻接矩阵、节点特征张量);
  • MPNN模型层:通过消息传递学习节点间的路径依赖,输出源-目标节点对的最短距离;
  • 训练/预测层:用标注数据训练模型,用训练好的模型预测新的最短路径。

二、MySQL数据库设计:图数据的“仓库”

图的核心是“节点”和“边”,再加上训练需要的“最短路径标注”,我们设计3张表来存储这些数据。相比SQLite,MySQL支持更大的数据量、事务和索引,更适合工程场景。

2.1 表结构设计:兼顾存储与查询效率

1. 节点表(nodes):存储节点ID和属性

节点可能包含物理意义的特征(如导航场景中节点是“路口”,特征可设为“车流量”“红绿灯数量”),这里我们预留3个特征字段,兼顾灵活性。

CREATE TABLE IF NOT EXISTS nodes (node_id INT PRIMARY KEY,  -- 节点唯一标识feature_1 FLOAT,          -- 节点特征1(如车流量)feature_2 FLOAT,          -- 节点特征2(如红绿灯数)feature_3 FLOAT,          -- 节点特征3(如道路等级)created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP  -- 数据创建时间(便于溯源)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
2. 边表(edges):存储节点间的连接关系

边需要区分“有向/无向”(如单行道是有向,双向道是无向),同时记录权重(如距离、时间成本),并通过外键关联节点表,保证数据一致性。

CREATE TABLE IF NOT EXISTS edges (edge_id INT PRIMARY KEY AUTO_INCREMENT,  -- 边唯一标识source_node INT NOT NULL,                -- 源节点IDtarget_node INT NOT NULL,                -- 目标节点IDweight FLOAT NOT NULL,                   -- 边权重(如距离)is_directed BOOLEAN DEFAULT FALSE,       -- 是否为有向边(0=无向,1=有向)created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,-- 外键约束:删除节点时自动删除关联边FOREIGN KEY (source_node) REFERENCES nodes(node_id) ON DELETE CASCADE,FOREIGN KEY (target_node) REFERENCES nodes(node_id) ON DELETE CASCADE,-- 唯一约束:避免重复边(同一源、目标、方向的边只存一条)UNIQUE KEY (source_node, target_node, is_directed)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
3. 最短路径标注表(shortest_paths):存储训练标签

用传统算法(如Dijkstra)提前计算出部分源-目标节点对的最短距离和路径,作为MPNN的训练数据。路径用JSON格式存储,方便读取后解析。

CREATE TABLE IF NOT EXISTS shortest_paths (id INT PRIMARY KEY AUTO_INCREMENT,       -- 标注唯一标识source_node INT NOT NULL,                -- 源节点IDtarget_node INT NOT NULL,                -- 目标节点IDdistance FLOAT NOT NULL,                 -- 最短距离(标签)path JSON NOT NULL,                      -- 最短路径(如[1,3,5])created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,FOREIGN KEY (source_node) REFERENCES nodes(node_id) ON DELETE CASCADE,FOREIGN KEY (target_node) REFERENCES nodes(node_id) ON DELETE CASCADE,-- 唯一约束:同一源-目标对只存一条标注UNIQUE KEY (source_node, target_node)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

三、数据加载:把MySQL数据变成模型“能吃的格式”

MPNN模型需要的输入是“节点特征张量”“邻接矩阵”和“训练样本(源-目标-距离)”,但从MySQL读取的是字典和列表格式,因此需要一个数据加载器来做转换和预处理。

3.1 核心任务:数据预处理

数据加载器的核心工作包括:

  1. 节点ID映射:将不规则的节点ID(如1、3、5)映射为连续索引(0、1、2),方便构建邻接矩阵;
  2. 特征标准化:节点特征可能存在量纲差异(如“车流量”是100-1000,“红绿灯数”是1-5),用标准化消除影响;
  3. 邻接矩阵构建:将边列表转换为矩阵(行=源节点,列=目标节点,值=边权重);
  4. 训练样本转换:将源/目标节点ID转换为索引,距离转换为张量。

3.2 实现数据加载器GraphDataLoader

import numpy as np
import torch
from graph_mysql_db import GraphMySQLDatabase  # 导入前面的数据库工具类
from sklearn.preprocessing import StandardScalerclass GraphDataLoader:def __init__(self, db_config: Dict):"""初始化数据加载器:param db_config: MySQL配置字典,如{"host": "localhost", "user": "root", ...}"""# 1. 连接数据库self.db = GraphMySQLDatabase(host=db_config["host"],user=db_config["user"],password=db_config["password"],db_name=db_config["db_name"])# 2. 初始化数据存储变量self.nodes: Dict[int, Tuple[float, float, float]] = {}  # 节点特征self.edges: List[Tuple[int, int, float]] = []  # 边列表self.node_ids: List[int] = []  # 节点ID列表(有序)self.id_to_idx: Dict[int, int] = {}  # 节点ID→索引的映射self.processed_features: torch.Tensor = None  # 标准化后的节点特征(shape: [n_nodes, n_features])self.adj_matrix: torch.Tensor = None  # 邻接矩阵(shape: [n_nodes, n_nodes])# 3. 加载并预处理数据self._load_data_from_db()self._preprocess_node_features()def _load_data_from_db(self) -> None:"""从MySQL读取节点和边数据"""self.nodes, self.edges = self.db.get_graph_data()self.node_ids = list(self.nodes.keys())  # 固定节点顺序self.id_to_idx = {node_id: idx for idx, node_id in enumerate(self.node_ids)}  # ID→索引映射print(f"📊 从数据库加载完成:{len(self.nodes)}个节点,{len(self.edges)}条边")def _preprocess_node_features(self) -> None:"""标准化节点特征(均值0,标准差1)"""# 提取特征矩阵(按node_ids顺序排列)raw_features = np.array([self.nodes[node_id] for node_id in self.node_ids])# 标准化scaler = StandardScaler()normalized_features = scaler.fit_transform(raw_features)# 转换为PyTorch张量(模型输入需为张量)self.processed_features = torch.tensor(normalized_features, dtype=torch.float32)print(f"🔧 节点特征预处理完成:shape={self.processed_features.shape}")def build_adjacency_matrix(self) -> torch.Tensor:"""构建邻接矩阵(含边权重)"""n_nodes = len(self.nodes)# 初始化邻接矩阵(全0)adj_matrix = torch.zeros((n_nodes, n_nodes), dtype=torch.float32)# 填充边权重for source_id, target_id, weight in self.edges:# 将节点ID转换为索引source_idx = self.id_to_idx[source_id]target_idx = self.id_to_idx[target_id]# 有向边:只填充source→target;无向边:还需填充target→source(数据库中已处理)adj_matrix[source_idx, target_idx] = weightself.adj_matrix = adj_matrixprint(f"🔧 邻接矩阵构建完成:shape={self.adj_matrix.shape}")return adj_matrixdef get_training_samples(self) -> List[Dict]:"""获取训练样本(源索引、目标索引、最短距离)"""raw_training_data = self.db.get_training_data()training_samples = []for source_id, target_id, distance, _ in raw_training_data:# 过滤掉不存在的节点(避免索引错误)if source_id in self.id_to_idx and target_id in self.id_to_idx:training_samples.append({"source_idx": self.id_to_idx[source_id],"target_idx": self.id_to_idx[target_id],"distance": torch.tensor(distance, dtype=torch.float32)})print(f"📋 训练样本准备完成:共{len(training_samples)}个样本")return training_samplesdef close(self) -> None:"""关闭数据库连接"""self.db.close()

3.3 测试数据加载器

# 测试数据加载器
if __name__ == "__main__":# MySQL配置(替换为你的实际配置)db_config = {"host": "localhost","user": "root","password": "your_mysql_password","db_name": "graph_db"}# 初始化数据加载器data_loader = GraphDataLoader(db_config)# 构建邻接矩阵adj_matrix = data_loader.build_adjacency_matrix()# 获取训练样本training_samples = data_loader.get_training_samples()# 打印部分结果,验证正确性print("\n📌 节点ID→索引映射:", data_loader.id_to_idx)print("📌 邻接矩阵(前3行3列):")print(adj_matrix[:3, :3])print("📌 训练样本(前2个):")for sample in training_samples[:2]:print(sample)# 关闭连接data_loader.close()

运行后会输出类似以下结果,说明数据加载和预处理成功:

✅ 成功连接到MySQL数据库:graph_db
✅ 数据库表结构初始化完成
📊 从数据库加载完成:5个节点,6条边
🔧 节点特征预处理完成:shape=torch.Size([5, 3])
🔧 邻接矩阵构建完成:shape=torch.Size([5, 5])
📋 训练样本准备完成:共3个样本📌 节点ID→索引映射: {1: 0, 2: 1, 3: 2, 4: 3, 5: 4}
📌 邻接矩阵(前3行3列):
tensor([[0., 2., 5.],[2., 0., 1.],[5., 1., 0.]])
📌 训练样本(前2个):
{'source_idx': 0, 'target_idx': 4, 'distance': tensor(6.)}
{'source_idx': 0, 'target_idx': 3, 'distance': tensor(6.)}
✅ 数据库连接已关闭

四、MPNN模型构建:核心是“消息传递”

MPNN的核心思想是“节点通过邻居传递消息,更新自身状态,最终学习到图的全局特征”。对于最短路径问题,我们需要让模型学习“源节点到目标节点的路径成本累积”,最终输出最短距离。

4.1 MPNN原理简化

MPNN的计算过程分为3步,我们用通俗的语言解释:

  1. 消息生成(Message Function):每个节点根据“自身特征”“邻居特征”和“边权重”,生成要传递给邻居的消息(比如“我到你的距离是2,我的特征是XXX”);
  2. 状态更新(Update Function):每个节点聚合所有邻居传来的消息,结合自身当前状态,更新为新的状态(比如“我综合了3个邻居的消息,新的状态更能反映周围路径信息”);
  3. 读出(Readout Function):经过多轮消息传递后,提取源节点和目标节点的最终状态,通过全连接层预测两者之间的最短距离。

4.2 实现MPNN模型

我们用PyTorch实现MPNN,分为两个核心模块:MessagePassingLayer(消息传递层)和MPNNS shortestPath(完整模型)。

首先安装PyTorch(若未安装):

pip install torch

然后实现模型:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MessagePassingLayer(nn.Module):"""单个消息传递层:实现一次消息传递和节点状态更新"""def __init__(self, input_dim: int, hidden_dim: int):super().__init__()self.input_dim = input_dim  # 节点特征维度self.hidden_dim = hidden_dim  # 消息/更新后的状态维度# 1. 消息函数:计算邻居传递给当前节点的消息# 输入:当前节点特征(input_dim) + 邻居节点特征(input_dim) + 边权重(1) → 共2*input_dim+1维self.message_fn = nn.Sequential(nn.Linear(2 * input_dim + 1, hidden_dim),nn.ReLU(),  # 非线性激活,增强表达能力nn.Linear(hidden_dim, hidden_dim))# 2. 更新函数:用聚合的消息更新当前节点状态# 输入:当前节点特征(input_dim) + 聚合后的消息(hidden_dim) → 共input_dim+hidden_dim维self.update_fn = nn.Sequential(nn.Linear(input_dim + hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim))def forward(self, node_features: torch.Tensor, adj_matrix: torch.Tensor) -> torch.Tensor:"""前向传播:一次消息传递:param node_features: 节点特征,shape=[n_nodes, input_dim]:param adj_matrix: 邻接矩阵,shape=[n_nodes, n_nodes]:return: 更新后的节点状态,shape=[n_nodes, hidden_dim]"""n_nodes = node_features.shape[0]hidden_dim = self.hidden_dim# 步骤1:生成所有节点对的消息(先扩展维度,便于批量计算)# 扩展节点特征:[n_nodes, input_dim] → [n_nodes, 1, input_dim] → [n_nodes, n_nodes, input_dim]node_features_expanded = node_features.unsqueeze(1).expand(-1, n_nodes, -1)# 转置后得到邻居特征:[n_nodes, n_nodes, input_dim](第i行是节点i的所有邻居特征)neighbor_features = node_features_expanded.transpose(0, 1)# 扩展边权重:[n_nodes, n_nodes] → [n_nodes, n_nodes, 1]edge_weights_expanded = adj_matrix.unsqueeze(-1)# 拼接输入:当前节点特征 + 邻居特征 + 边权重 → [n_nodes, n_nodes, 2*input_dim + 1]message_input = torch.cat([node_features_expanded, neighbor_features, edge_weights_expanded], dim=-1)# 计算消息:[n_nodes, n_nodes, hidden_dim](message[i][j]是节点j传递给节点i的消息)messages = self.message_fn(message_input)# 步骤2:聚合邻居消息(只聚合有边连接的邻居,用邻接矩阵 masking 掉无连接的消息)# 邻接矩阵mask:无连接的位置为0,有连接的位置为1 → [n_nodes, n_nodes, 1]adj_mask = (adj_matrix > 0).float().unsqueeze(-1)# 带mask的消息:无连接的消息被置为0 → [n_nodes, n_nodes, hidden_dim]masked_messages = messages * adj_mask# 聚合:对每个节点的所有邻居消息求和 → [n_nodes, hidden_dim]aggregated_messages = masked_messages.sum(dim=1)# 步骤3:更新节点状态(当前节点特征 + 聚合消息)update_input = torch.cat([node_features, aggregated_messages], dim=-1)new_node_states = self.update_fn(update_input)return new_node_statesclass MPNNShortestPath(nn.Module):"""完整MPNN模型:多轮消息传递 + 读出层预测最短距离"""def __init__(self, input_dim: int, hidden_dim: int, num_message_layers: int):super().__init__()self.input_dim = input_dimself.hidden_dim = hidden_dimself.num_message_layers = num_message_layers  # 消息传递轮数(越多,感受野越大)# 1. 堆叠多个消息传递层(多轮传递,扩大节点的“感受野”)self.message_layers = nn.ModuleList()# 第一层:输入维度=input_dim,输出维度=hidden_dimself.message_layers.append(MessagePassingLayer(input_dim, hidden_dim))# 后续层:输入维度=hidden_dim(前一层的输出),输出维度=hidden_dimfor _ in range(num_message_layers - 1):self.message_layers.append(MessagePassingLayer(hidden_dim, hidden_dim))# 2. 读出层:将源节点和目标节点的状态映射为最短距离(回归任务)self.readout = nn.Sequential(nn.Linear(2 * hidden_dim, hidden_dim),  # 输入:源节点状态 + 目标节点状态 → 2*hidden_dimnn.ReLU(),nn.Linear(hidden_dim, 1)  # 输出:1个值(最短距离))def forward(self, node_features: torch.Tensor, adj_matrix: torch.Tensor, source_indices: torch.Tensor, target_indices: torch.Tensor) -> torch.Tensor:"""前向传播:预测源-目标节点对的最短距离:param node_features: 节点特征,shape=[n_nodes, input_dim]:param adj_matrix: 邻接矩阵,shape=[n_nodes, n_nodes]:param source_indices: 源节点索引,shape=[batch_size]:param target_indices: 目标节点索引,shape=[batch_size]:return: 预测的最短距离,shape=[batch_size, 1]"""# 步骤1:多轮消息传递,更新节点状态current_node_states = node_featuresfor layer in self.message_layers:current_node_states = layer(current_node_states, adj_matrix)# 步骤2:提取源节点和目标节点的最终状态# 源节点状态:shape=[batch_size, hidden_dim]source_states = current_node_states[source_indices]# 目标节点状态:shape=[batch_size, hidden_dim]target_states = current_node_states[target_indices]# 步骤3:拼接状态,通过读出层预测距离path_features = torch.cat([source_states, target_states], dim=-1)  # [batch_size, 2*hidden_dim]predicted_distance = self.readout(path_features)  # [batch_size, 1]return predicted_distance

4.3 模型设计思路解析

  1. 消息传递轮数(num_message_layers)
    每轮消息传递,节点的“感受野”会扩大一圈(比如1轮能看到直接邻居,2轮能看到邻居的邻居)。对于最短路径问题,轮数建议设为“图的最大直径”(最长最短路径的节点数),确保源节点能“感知”到目标节点。

  2. 消息函数输入维度
    输入是“当前节点特征 + 邻居特征 + 边权重”,共2*input_dim + 1维,这样能同时考虑节点自身属性和连接关系,更贴合路径成本的计算逻辑。

  3. 读出层设计
    最短路径是“源-目标”对的属性,因此需要提取两个节点的最终状态并拼接,再通过全连接层输出距离,符合“路径是两个节点间关联”的直觉。

五、模型训练:让MPNN学会预测最短路径

有了数据和模型,接下来就是训练环节。我们用“均方误差(MSE)”作为损失函数(因为是回归任务,预测连续的距离值),用Adam优化器更新参数。

5.1 训练流程设计

  1. 初始化组件:数据加载器、MPNN模型、损失函数、优化器;
  2. 训练循环:遍历epoch,每次迭代从训练样本中取数据,前向传播计算预测值,反向传播更新参数;
  3. 评估与保存:每轮epoch后计算训练损失,训练结束后保存模型权重。

5.2 实现训练代码

import torch
import torch.optim as optim
from typing import List, Dict
from graph_data_loader import GraphDataLoader
from mpnn_model import MPNNShortestPathdef train_mpnn_shortest_path(db_config: Dict, input_dim: int = 3, hidden_dim: int = 32, num_message_layers: int = 2, epochs: int = 100, lr: float = 1e-3, save_path: str = "mpnn_shortest_path.pth"):"""训练MPNN最短路径模型:param db_config: MySQL配置:param input_dim: 节点特征维度(我们之前设了3个特征):param hidden_dim: 消息传递层的隐藏维度:param num_message_layers: 消息传递轮数:param epochs: 训练轮数:param lr: 学习率:param save_path: 模型保存路径"""# 1. 初始化数据加载器data_loader = GraphDataLoader(db_config)adj_matrix = data_loader.build_adjacency_matrix()  # 邻接矩阵(固定)training_samples = data_loader.get_training_samples()  # 训练样本node_features = data_loader.processed_features  # 节点特征(固定)# 2. 初始化模型、损失函数、优化器model = MPNNShortestPath(input_dim=input_dim,hidden_dim=hidden_dim,num_message_layers=num_message_layers)criterion = nn.MSELoss()  # 回归任务用MSEoptimizer = optim.Adam(model.parameters(), lr=lr)  # Adam优化器# 3. 训练循环model.train()  # 切换到训练模式for epoch in range(1, epochs + 1):total_loss = 0.0# 遍历所有训练样本(这里用全量训练,也可以分批)for sample in training_samples:source_idx = sample["source_idx"]target_idx = sample["target_idx"]true_distance = sample["distance"]# 前向传播:预测距离# 注意:source_idx和target_idx需要是张量,且添加batch维度predicted_distance = model(node_features=node_features,adj_matrix=adj_matrix,source_indices=torch.tensor([source_idx]),target_indices=torch.tensor([target_idx]))# 计算损失loss = criterion(predicted_distance.squeeze(), true_distance)  # 挤压维度,匹配形状total_loss += loss.item()# 反向传播 + 更新参数optimizer.zero_grad()  # 清空梯度loss.backward()  # 计算梯度optimizer.step()  # 更新参数# 计算平均损失avg_loss = total_loss / len(training_samples)# 每10轮打印一次训练信息if epoch % 10 == 0:print(f"📈 Epoch [{epoch}/{epochs}], Average Loss: {avg_loss:.4f}")# 4. 保存训练好的模型torch.save(model.state_dict(), save_path)print(f"✅ 模型训练完成,已保存到:{save_path}")# 5. 关闭数据库连接data_loader.close()return model# 执行训练
if __name__ == "__main__":# MySQL配置(替换为你的实际配置)db_config = {"host": "localhost","user": "root","password": "your_mysql_password","db_name": "graph_db"}# 训练模型trained_model = train_mpnn_shortest_path(db_config=db_config,input_dim=3,hidden_dim=32,num_message_layers=2,epochs=100,lr=1e-3)

5.3 训练结果分析

训练过程中,损失会逐渐下降(如下所示),说明模型在不断学习最短路径的特征:

📈 Epoch [10/100], Average Loss: 0.8765
📈 Epoch [20/100], Average Loss: 0.3214
📈 Epoch [30/100], Average Loss: 0.1023
📈 Epoch [40/100], Average Loss: 0.0345
📈 Epoch [50/100], Average Loss: 0.0121
...
📈 Epoch [100/100], Average Loss: 0.0032
✅ 模型训练完成,已保存到:mpnn_shortest_path.pth

损失降到很低,说明模型已经学会了从节点特征和邻接关系中预测最短距离。

六、模型预测:用训练好的MPNN求解新路径

训练完成后,我们可以用模型预测未标注的源-目标节点对的最短距离,验证模型的泛化能力。

6.1 预测代码实现

import torch
from graph_data_loader import GraphDataLoader
from mpnn_model import MPNNShortestPathdef predict_shortest_path(db_config: Dict, model_path: str = "mpnn_shortest_path.pth", input_dim: int = 3, hidden_dim: int = 32, num_message_layers: int = 2, source_id: int = 3, target_id: int = 4):"""预测源-目标节点对的最短距离:param db_config: MySQL配置:param model_path: 模型权重路径:param input_dim: 节点特征维度:param hidden_dim: 隐藏维度:param num_message_layers: 消息传递轮数:param source_id: 源节点ID:param target_id: 目标节点ID:return: 预测的最短距离"""# 1. 初始化数据加载器,获取图数据data_loader = GraphDataLoader(db_config)adj_matrix = data_loader.build_adjacency_matrix()node_features = data_loader.processed_features# 2. 检查源/目标节点是否存在if source_id not in data_loader.id_to_idx:print(f"❌ 源节点ID {source_id} 不存在")data_loader.close()return Noneif target_id not in data_loader.id_to_idx:print(f"❌ 目标节点ID {target_id} 不存在")data_loader.close()return None# 3. 加载训练好的模型model = MPNNShortestPath(input_dim, hidden_dim, num_message_layers)model.load_state_dict(torch.load(model_path))model.eval()  # 切换到评估模式(禁用Dropout等)# 4. 转换节点ID为索引source_idx = data_loader.id_to_idx[source_id]target_idx = data_loader.id_to_idx[target_id]# 5. 预测最短距离(禁用梯度计算,提高效率)with torch.no_grad():predicted_distance = model(node_features=node_features,adj_matrix=adj_matrix,source_indices=torch.tensor([source_idx]),target_indices=torch.tensor([target_idx]))# 6. 输出结果predicted_distance = predicted_distance.item()print(f"📊 预测结果:")print(f"源节点ID:{source_id} → 目标节点ID:{target_id}")print(f"预测最短距离:{predicted_distance:.2f}")# 7. 关闭连接data_loader.close()return predicted_distance# 执行预测
if __name__ == "__main__":db_config = {"host": "localhost","user": "root","password": "your_mysql_password","db_name": "graph_db"}# 预测节点3→4的最短距离(实际最短路径是3-2-4,距离1+4=5)predict_shortest_path(db_config=db_config,model_path="mpnn_shortest_path.pth",source_id=3,target_id=4)

6.2 预测结果验证

运行预测代码后,输出类似以下结果:

✅ 成功连接到MySQL数据库:graph_db
✅ 数据库表结构初始化完成
📊 从数据库加载完成:5个节点,6条边
🔧 节点特征预处理完成:shape=torch.Size([5, 3])
🔧 邻接矩阵构建完成:shape=torch.Size([5, 5])
📋 训练样本准备完成:共3个样本
📊 预测结果:
源节点ID:3 → 目标节点ID:4
预测最短距离:5.03
✅ 数据库连接已关闭

实际最短距离是5.0,模型预测值是5.03,误差很小,说明模型的预测效果很好。

七、总结

至此,我们完成了从MySQL数据存储到MPNN模型训练、预测的全流程落地。这个方案的核心价值在于:

  1. 工程化存储:用MySQL管理图数据,支持大规模扩展和事务安全;
  2. 泛化性强:MPNN能处理动态图(如边权重更新),无需重新训练传统算法;
  3. 端到端学习:直接从节点特征和连接关系学习路径特征,无需手动设计规则。
http://www.xdnf.cn/news/1366057.html

相关文章:

  • 用 GSAP + ScrollTrigger 打造沉浸式视频滚动动画
  • 【Day 33】Linux-Mysql日志
  • DDR3入门系列(二)------DDR3硬件电路及Xilinx MIG IP核介绍
  • linux 正则表达式学习
  • 使用 gemini 来分析 github 项目
  • 安卓11 12系统修改定制化_____修改固件 默认给指定内置应用系统级权限
  • 大模型的思考方式
  • Java全栈开发实战:从Spring Boot到Vue3的项目实践
  • ZKmall开源商城多端兼容实践:鸿蒙、iOS、安卓全平台适配的技术路径
  • 8.25作业
  • [MH22D3开发笔记]2. SPI,QSPI速度究竟能跑多快,双屏系统的理想选择
  • Linux笔记9——shell编程基础-3
  • Tesseract OCR之页面布局分析
  • Linux系统的网络管理(一)
  • c# 读取xml文件内的数据
  • 网络编程-HTTP
  • zookeeper-znode解析
  • 【动态规划】309. 买卖股票的最佳时机含冷冻期及动态规划模板
  • 深入浅出 ArrayList:从基础用法到底层原理的全面解析(中)
  • 【C语言16天强化训练】从基础入门到进阶:Day 11
  • 信号处理的核心机制:从保存、处理到可重入性与volatile
  • 系统架构设计师-计算机系统存储管理的模拟题
  • 【数据结构】栈和队列——队列
  • AR远程协助:能源电力行业智能化革新
  • 数据库迁移幂等性介绍(Idempotence)(Flyway、Liquibase)ALTER、ON DUPLICATE
  • 05 开发环境和远程仓库Gitlab准备
  • coze工作流200+源码,涵盖AI文案生成、图像处理、视频生成、自动化脚本等多个领域
  • 向量库Qdrant vs Milvus 系统详细对比
  • 智能专网升级:4G与5G混合组网加速企业数字化转型
  • FunASR基础语音识别工具包