当前位置: 首页 > news >正文

【DGL学习1】GCN example

DGL学习

用DGL实现一个简单的GCN-cora的例子。
参考:https://docs.dgl.ai/tutorials/blitz/1_introduction.html#sphx-glr-tutorials-blitz-1-introduction-py

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
Using backend: pytorch
import dgl.data
# Load the Cora citation-network dataset (downloads on first use,
# then reads from the local cache). It exposes node counts, features,
# class count and the standard train/val/test split masks.
dataset = dgl.data.CoraGraphDataset()
print(dataset)
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
<dgl.data.citation_graph.CoraGraphDataset object at 0x7fcc7c0922d0>
# Cora consists of a single graph; index 0 retrieves it.
g = dataset[0]
print(g)
Graph(num_nodes=2708, num_edges=10556,
      ndata_schemes={'feat': Scheme(shape=(1433,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'train_mask': Scheme(shape=(), dtype=torch.bool)},
      edata_schemes={})
# Inspect the node data (features, labels, split masks) and the edge
# data (empty for Cora — no edge features).
print(g.ndata)
print(g.edata)
{'feat': tensor([[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]), 'label': tensor([3, 4, 4,  ..., 3, 3, 3]), 'val_mask': tensor([False, False, False,  ..., False, False, False]), 'test_mask': tensor([False, False, False,  ...,  True,  True,  True]), 'train_mask': tensor([ True,  True,  True,  ..., False, False, False])}
{}
# Feature matrix: one 1433-dim bag-of-words vector per node;
# labels: one class id per node.
print(g.ndata['feat'].shape)
print(g.ndata['label'].shape)
torch.Size([2708, 1433])
torch.Size([2708])
from dgl.nn import GraphConv
class GCN(nn.Module):def __init__(self, in_feats, h_feats, num_classes):super(GCN, self).__init__()self.conv1 = GraphConv(in_feats, h_feats)self.conv2 = GraphConv(h_feats, num_classes)def forward(self, g, in_feat):x = F.relu(self.conv1(g, in_feat))x = F.softmax(self.conv2(g, x))return xmodel = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)
print(model)
GCN(
  (conv1): GraphConv(in=1433, out=16, normalization=both, activation=None)
  (conv2): GraphConv(in=16, out=7, normalization=both, activation=None)
)
# Select GPU when available, then move both the model and the graph to
# that device so all tensors live in the same place.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = model.to(device)
g = g.to(device)
cpu
# Pull the boolean split masks and the feature/label tensors out of the
# graph once, so train() can index them directly.
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
feat = g.ndata['feat']
label = g.ndata['label']
def train():optimizer = torch.optim.Adam(model.parameters(), lr=0.01)criterion = nn.CrossEntropyLoss()out = model(g, feat)loss = criterion(out[train_mask], label[train_mask])optimizer.zero_grad()loss.backward()optimizer.step()pred = out.argmax(dim=1)train_acc = (pred[train_mask] == label[train_mask]).float().mean()val_acc = (pred[val_mask] == label[val_mask]).float().mean()test_acc = (pred[test_mask] == label[test_mask]).float().mean()return loss.item(), train_acc, val_acc, test_acc
def main():best_val_acc = 0best_test_acc = 0for epoch in range(100):loss, train_acc, val_acc, test_acc = train()if best_val_acc < val_acc:best_val_acc = val_accbest_test_acc = test_accprint('epoch:{:03d}, train_acc:{:.4f}, val_acc:{:.4f}, test_acc:{:.4f}'.format(epoch, train_acc, val_acc, test_acc))print('best_val_acc:', best_val_acc)print('best_test_acc:', best_test_acc)
if __name__ == '__main__':main()
epoch:000, train_acc:0.1357, val_acc:0.1220, test_acc:0.1130
epoch:001, train_acc:0.2143, val_acc:0.1960, test_acc:0.1920
epoch:002, train_acc:0.6857, val_acc:0.4620, test_acc:0.4530
epoch:003, train_acc:0.5143, val_acc:0.2640, test_acc:0.2750
epoch:004, train_acc:0.5000, val_acc:0.2760, test_acc:0.2900
epoch:005, train_acc:0.5500, val_acc:0.3300, test_acc:0.3360
epoch:006, train_acc:0.6143, val_acc:0.3200, test_acc:0.3310
epoch:007, train_acc:0.6357, val_acc:0.3260, test_acc:0.3390
epoch:008, train_acc:0.6143, val_acc:0.3420, test_acc:0.3530
epoch:009, train_acc:0.6500, val_acc:0.3420, test_acc:0.3610
epoch:010, train_acc:0.6357, val_acc:0.3260, test_acc:0.3470
epoch:011, train_acc:0.6429, val_acc:0.3160, test_acc:0.3390
epoch:012, train_acc:0.6571, val_acc:0.3320, test_acc:0.3430
epoch:013, train_acc:0.6500, val_acc:0.3220, test_acc:0.3460
epoch:014, train_acc:0.6357, val_acc:0.3220, test_acc:0.3410
epoch:015, train_acc:0.6571, val_acc:0.3340, test_acc:0.3480
epoch:016, train_acc:0.6500, val_acc:0.3320, test_acc:0.3480
epoch:017, train_acc:0.6571, val_acc:0.3400, test_acc:0.3540
epoch:018, train_acc:0.6714, val_acc:0.3460, test_acc:0.3530
epoch:019, train_acc:0.6929, val_acc:0.3520, test_acc:0.3640
epoch:020, train_acc:0.6929, val_acc:0.3560, test_acc:0.3720
epoch:021, train_acc:0.6929, val_acc:0.3580, test_acc:0.3720
epoch:022, train_acc:0.7071, val_acc:0.3740, test_acc:0.3820
epoch:023, train_acc:0.7071, val_acc:0.3700, test_acc:0.3870
epoch:024, train_acc:0.7071, val_acc:0.3800, test_acc:0.3910
epoch:025, train_acc:0.7071, val_acc:0.3880, test_acc:0.4000
epoch:026, train_acc:0.7286, val_acc:0.3820, test_acc:0.4070
epoch:027, train_acc:0.7643, val_acc:0.4200, test_acc:0.4240
epoch:028, train_acc:0.7857, val_acc:0.4240, test_acc:0.4280
epoch:029, train_acc:0.8000, val_acc:0.4500, test_acc:0.4540
epoch:030, train_acc:0.8214, val_acc:0.4680, test_acc:0.4690
epoch:031, train_acc:0.8286, val_acc:0.4820, test_acc:0.4810
epoch:032, train_acc:0.8286, val_acc:0.4920, test_acc:0.4830
epoch:033, train_acc:0.8286, val_acc:0.5060, test_acc:0.4980
epoch:034, train_acc:0.8286, val_acc:0.5080, test_acc:0.5090
epoch:035, train_acc:0.8357, val_acc:0.5120, test_acc:0.5170
epoch:036, train_acc:0.8429, val_acc:0.5120, test_acc:0.5210
epoch:037, train_acc:0.8571, val_acc:0.5140, test_acc:0.5310
epoch:038, train_acc:0.8714, val_acc:0.5240, test_acc:0.5420
epoch:039, train_acc:0.8714, val_acc:0.5340, test_acc:0.5450
epoch:040, train_acc:0.8929, val_acc:0.5480, test_acc:0.5480
epoch:041, train_acc:0.9071, val_acc:0.5640, test_acc:0.5630
epoch:042, train_acc:0.9071, val_acc:0.5640, test_acc:0.5680
epoch:043, train_acc:0.9214, val_acc:0.5780, test_acc:0.5830
epoch:044, train_acc:0.9286, val_acc:0.5840, test_acc:0.5890
epoch:045, train_acc:0.9429, val_acc:0.6080, test_acc:0.6110
epoch:046, train_acc:0.9429, val_acc:0.6120, test_acc:0.6150
epoch:047, train_acc:0.9429, val_acc:0.6360, test_acc:0.6450
epoch:048, train_acc:0.9429, val_acc:0.6500, test_acc:0.6490
epoch:049, train_acc:0.9429, val_acc:0.6600, test_acc:0.6610
epoch:050, train_acc:0.9571, val_acc:0.6760, test_acc:0.6670
epoch:051, train_acc:0.9571, val_acc:0.6820, test_acc:0.6820
epoch:052, train_acc:0.9643, val_acc:0.7000, test_acc:0.7000
epoch:053, train_acc:0.9643, val_acc:0.7020, test_acc:0.7000
epoch:054, train_acc:0.9643, val_acc:0.7140, test_acc:0.7100
epoch:055, train_acc:0.9643, val_acc:0.7240, test_acc:0.7230
epoch:056, train_acc:0.9643, val_acc:0.7260, test_acc:0.7220
epoch:057, train_acc:0.9643, val_acc:0.7460, test_acc:0.7400
epoch:058, train_acc:0.9643, val_acc:0.7360, test_acc:0.7370
epoch:059, train_acc:0.9643, val_acc:0.7540, test_acc:0.7400
epoch:060, train_acc:0.9643, val_acc:0.7560, test_acc:0.7500
epoch:061, train_acc:0.9643, val_acc:0.7600, test_acc:0.7510
epoch:062, train_acc:0.9714, val_acc:0.7580, test_acc:0.7520
epoch:063, train_acc:0.9714, val_acc:0.7620, test_acc:0.7560
epoch:064, train_acc:0.9714, val_acc:0.7620, test_acc:0.7550
epoch:065, train_acc:0.9714, val_acc:0.7660, test_acc:0.7550
epoch:066, train_acc:0.9714, val_acc:0.7720, test_acc:0.7650
epoch:067, train_acc:0.9786, val_acc:0.7700, test_acc:0.7590
epoch:068, train_acc:0.9714, val_acc:0.7800, test_acc:0.7750
epoch:069, train_acc:0.9786, val_acc:0.7740, test_acc:0.7660
epoch:070, train_acc:0.9714, val_acc:0.7740, test_acc:0.7630
epoch:071, train_acc:0.9786, val_acc:0.7780, test_acc:0.7730
epoch:072, train_acc:0.9786, val_acc:0.7720, test_acc:0.7660
epoch:073, train_acc:0.9857, val_acc:0.7820, test_acc:0.7820
epoch:074, train_acc:0.9786, val_acc:0.7800, test_acc:0.7700
epoch:075, train_acc:0.9857, val_acc:0.7900, test_acc:0.7830
epoch:076, train_acc:0.9786, val_acc:0.7780, test_acc:0.7730
epoch:077, train_acc:0.9929, val_acc:0.7840, test_acc:0.7770
epoch:078, train_acc:0.9786, val_acc:0.7780, test_acc:0.7760
epoch:079, train_acc:0.9929, val_acc:0.7920, test_acc:0.7800
epoch:080, train_acc:0.9857, val_acc:0.7820, test_acc:0.7750
epoch:081, train_acc:0.9929, val_acc:0.7920, test_acc:0.7840
epoch:082, train_acc:0.9929, val_acc:0.7820, test_acc:0.7770
epoch:083, train_acc:0.9929, val_acc:0.7960, test_acc:0.7870
epoch:084, train_acc:0.9929, val_acc:0.7860, test_acc:0.7760
epoch:085, train_acc:0.9929, val_acc:0.7940, test_acc:0.7860
epoch:086, train_acc:0.9929, val_acc:0.7860, test_acc:0.7780
epoch:087, train_acc:0.9929, val_acc:0.7940, test_acc:0.7910
epoch:088, train_acc:0.9929, val_acc:0.7820, test_acc:0.7800
epoch:089, train_acc:0.9929, val_acc:0.7920, test_acc:0.7890
epoch:090, train_acc:0.9929, val_acc:0.7880, test_acc:0.7780
epoch:091, train_acc:0.9929, val_acc:0.7900, test_acc:0.7870
epoch:092, train_acc:0.9929, val_acc:0.7820, test_acc:0.7780
epoch:093, train_acc:0.9929, val_acc:0.7860, test_acc:0.7870
epoch:094, train_acc:0.9929, val_acc:0.7800, test_acc:0.7770
epoch:095, train_acc:0.9929, val_acc:0.7880, test_acc:0.7880
epoch:096, train_acc:0.9929, val_acc:0.7860, test_acc:0.7830
epoch:097, train_acc:0.9929, val_acc:0.7880, test_acc:0.7870
epoch:098, train_acc:0.9929, val_acc:0.7820, test_acc:0.7830
epoch:099, train_acc:0.9929, val_acc:0.7920, test_acc:0.7920
best_val_acc: tensor(0.7960)
best_test_acc: tensor(0.7870)
http://www.xdnf.cn/news/804385.html

相关文章:

  • 优米网视频-在路上第四期:林正刚-外企职场心态
  • 《饭局也疯狂》范伟 黄渤 刘桦 最新喜剧大片下载,DVD 816MB 480P普清下载!
  • Ubuntu10.04版本下載地址
  • Web课程设计:旅游景点网站设计——北京故宫(9页) HTML+CSS+JavaScript 简单DIV布局个人介绍网页模板代码
  • 我的fedora9安装后配置
  • GetLastError返回代码含义
  • Ubuntu 9.04 全部官方衍生版本下载
  • 5个编写技巧,有效提高单元测试实践
  • Delphi 2009 安装序列号
  • win2003 序列号 windows2003 sp2可用序列号大全(准版与企业版)
  • 25款实用的桌面版博客编辑器
  • 重磅!9个中文免费电子书网站合集来了
  • 番茄花园版Windows XP作者被拘留!
  • 视频编码与封装方式详解
  • 基于Java+Vue中国风音乐推介网站设计和实现(源码+LW+部署讲解)
  • 通用vue组件化登录页面
  • Google AdSense申请完全手册
  • WinRAR4.0注册码
  • CSS3中强大的filter(滤镜)属性使用详细解说
  • MySQL procedure详解
  • c++基础(三)
  • MyEclipse 9.0 正式版发布—for linux ,for windows
  • 布拉德皮特不完全档案及星路历程
  • SQL基础——存储过程 语法
  • python实战演练—— 制作超级玛丽游戏!
  • 一些软件所有版本下载地址 (第一期)
  • 水晶报表相关官方软件下载
  • 【知乎问答】有哪些特殊的搜索引擎?
  • 网站管理后台帐号密码暴力破解方法
  • 统一身份认证实现,推广的可能性及优缺点?