消息传递的图神经网络_图神经网络 信息传递 公式-程序员宅基地

一、消息传递范式介绍

消息传递范式是一种聚合邻接节点信息来更新中心节点信息的范式,它将卷积算子推广到不规则数据领域,实现了图与神经网络的连接。此范式包含三个步骤:(1)邻接节点信息变换;(2)邻接节点信息聚合到中心节点;(3)聚合信息变换。

消息传递图神经网络可以描述为:
x i ( k ) = γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i )   ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) , \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right), xi(k)=γ(k)(xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i)),
x i ( k − 1 ) ∈ R F \mathbf{x}^{(k-1)}_i\in\mathbb{R}^F xi(k1)RF表示(k-1)层中节点i的节点特征, e j , i ∈ R D \mathbf{e}_{j,i} \in \mathbb{R}^D ej,iRD表示从节点j到节点i的边的特征, □ \square 表示可微分的、具有排列不变形的函数,具有排列不变形的函数有和函数、均值函数和最大值函数。 γ \gamma γ ϕ \phi ϕ表示可微分的函数。

二、Pytorch Geometric中的MessagePassing基类

Pytorch Geometric提供了MessagePassing类,实现了消息传播的自动处理,继承该基类可以方便地构造消息传递图神经网络,我们只需要定义函数 ϕ \phi ϕ(即message函数)和函数 γ \gamma γ(即update函数),以及消息聚合方案(aggr=“add”、aggr="mean"或aggr=“max”)。

  • MessagePassing(aggr=“add”, flow=“source_to_target”, node_dim=-2):
    aggr: 定义要使用的聚合方案(“add”、“mean"或"max”)
    flow: 定义消息传递的流向(“source_to_target"或"target_to_source”)
    node_dim: 定义沿着哪个轴线传播

  • MessagePassing.propagate(edge_index, size=None, **kwargs):
    开始传播消息的起始调用。它以edge_index(边的端点的索引)和flow(消息的流向)以及一些额外的数据为参数。
    size=(N,M)设置对称邻接矩阵的形状。

  • MessagePassing.message(…)接受最初传递给propagate函数的所有参数。

  • MessagePassing.aggregate(…)将从源节点传递过来的消息聚合在目标节点上,一般可选的聚合方式有sum,mean和max。

  • MessagePassing.message_and_aggregate(…)融合了邻接节点信息变换和邻接节点信息聚合。

  • MessagePassing.update(aggr_out, …)为每个节点更新节点表征,即实现 γ \gamma γ函数。该函数以聚合函数的输出为第一参数,并接收所有传递给propagate函数的参数。

三、继承MessagePassing类的GCNConv

GCNConv的数学定义为:
x i ( k ) = ∑ j ∈ N ( i ) ∪ { i } 1 deg ⁡ ( i ) ⋅ deg ⁡ ( j ) ⋅ ( Θ ⋅ x j ( k − 1 ) ) , \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right), xi(k)=jN(i){ i}deg(i) deg(j) 1(Θxj(k1)),
其中相邻节点的特征通过权重矩阵 Θ \mathbf{\Theta} Θ进行转换,然后按端点的度进行归一化处理,最后进行加总。这个公式可以分为以下几个步骤:

  1. 向邻接矩阵添加自环边。
  2. 线性转换节点特征矩阵。
  3. 计算归一化系数。
  4. 归一化j中的节点特征。
  5. 将相邻节点特征相加。
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        # "Add" aggregation (Step 5).
        # flow='source_to_target' 表示消息从源节点传播到目标节点
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

# 初始化和调用
conv = GCNConv(16, 32)
x = conv(x, edge_index)

四、复写message函数

class GCNConv(MessagePassing):
    def forward(self, x, edge_index):
        # ....
        return self.propagate(edge_index, x=x, norm=norm, d=d)
    def message(self, x_j, norm, d_i):
        # x_j has shape [E, out_channels]
        return norm.view(-1, 1) * x_j * d_i

五、覆写aggregate函数

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        
    def forward(self, x, edge_index):
        # ....
        return self.propagate(edge_index, x=x, norm=norm, d=d)

    def aggregate(self, inputs, index, ptr, dim_size):
        print(self.aggr)
        print("`aggregate` is called")
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)

六、覆写aggregate函数

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        
    def forward(self, x, edge_index):
        # ....
        return self.propagate(edge_index, x=x, norm=norm, d=d)

    def aggregate(self, inputs, index, ptr, dim_size):
        print(self.aggr)
        print("`aggregate` is called")
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)

七、覆写message_and_aggregate函数

from torch_sparse import SparseTensor

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        
    def forward(self, x, edge_index):
        # ....
        adjmat = SparseTensor(row=edge_index[0], col=edge_index[1], value=torch.ones(edge_index.shape[1]))
        # 此处传的不再是edge_idex,而是SparseTensor类型的Adjancency Matrix
        return self.propagate(adjmat, x=x, norm=norm, d=d)
    
    def message(self, x_j, norm, d_i):
        # x_j has shape [E, out_channels]
        return norm.view(-1, 1) * x_j * d_i # 这里不管正确性
    
    def aggregate(self, inputs, index, ptr, dim_size):
        print(self.aggr)
        print("`aggregate` is called")
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
    
    def message_and_aggregate(self, adj_t, x, norm):
        print('`message_and_aggregate` is called')

八、覆写update函数

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')

    def update(self, inputs: Tensor) -> Tensor:
        return inputs
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/Lazybones_3/article/details/118061456

智能推荐

视频教程-微信小程序支付java-Java-程序员宅基地

文章浏览阅读354次。微信小程序支付java 我是一名java高级开发工程师,已有七年的工作经验,..._微信小程序支付教学视频

HarmonyOS Next从入门到精通实战精品课

Canvas完成矩形的绘制;华为闹钟的订阅和取消;华为闹钟的基本绘制;华为闹钟的时针-分针-秒针的绘制;华为闹钟的任务列表的样式;关于State的状态更新的须知;关于样式的简单介绍;关于vp和fp的介绍;知乎数据的真实的渲染;实现底部组件的封装;华为闹钟的添加闹钟;完成知乎小案例的UI布局;ForEach的商品列表案例;ForEach的商品列表的Grid布局;ForEach的key的一个简单介绍;Watch的刷题案例-做题的思路;State嵌套更新的实现方式;你画我猜的签字板实现;

vue2实现复制粘贴功能

【代码】vue2实现复制粘贴功能。

SpringBoot引入Layui样式总是出现404

本文主要介绍的是在项目开发中SrpingBoot引入Layui样式出现404的解决方案

STM32与Proteus的串口仿真详细教程与源程序

解决了STM32的Proteus串口收发问题。

深入理解C语言中的 extern`和 static

extern关键字用于声明一个变量或函数,表示其定义在另一个文件或本文件的其他位置。使用extern可以在多个文件之间共享全局变量或函数。static关键字用于声明变量或函数的作用域为仅限于定义它们的文件,同时保持它们的值在函数调用之间持久存在。理解并正确使用extern和static关键字对于管理大型C语言项目中的变量和函数作用域、链接属性和生命周期至关重要。希望这篇文章能帮助你更好地掌握这些概念。

随便推点

VUE3与Uniapp 四 (Class变量和内联样式)

【代码】VUE3与Uniapp 三 (Class变量和内联样式)

前端打包过大如何解决?

前端开发完毕部署到线上是,执行npm run build。当打包过大时,部署到服务端后加载缓慢,如何优化?前端打包成gzip,用new CompressionWebpackPlugin来压缩。服务端nginx设置。我们可以通过执行npm run analyze。可以看到各个包文件大小的区别。当打包过大时,通过压缩gzip的方式,可以看个开始和打包和压缩的大小。

c语言上机入门实验十,C语言入门学习-C上机实验三要求-程序员宅基地

文章浏览阅读1.2k次。2+ x3- x4 +…… -x2n + x2n+1……的值。当某项的绝对值小于10-6时终止。(当x为0.5时,和值为0.333334)【系统函数fabs(x)的功能是计算x的绝对值,前面需加math . h头文件】3.一个球从100m高度自由落下,每次落地后反跳回原高度的一半,再落下,再反弹。编程计算:“它在第10次落地后,反弹多高”;“从第一次落下到第十次反弹,总共经过了多少米”。 (结果:..._入x,x>1,求以下数列之和,当某项绝对值小于10-6为止。

R语言实现PVAR(面板向量自回归模型)_pvar模型r语言-程序员宅基地

文章浏览阅读1.9w次,点赞48次,收藏173次。这次研究了一个问题,要用PVAR(面板向量自回归模型),在网上找的教程基本上都是Stata或者Eviews的教程,而鲜有R实现PVAR的教程,这里总结分享一下我摸索的PVAR用R实现的整个过程。..._pvar模型r语言

you2php镜像仓库,10-S2I镜像定制-程序员宅基地

文章浏览阅读209次。1.概述Source to Image流程为应用的容器化提供了一个标准,实现了自动化。OpenShift默认提供Java WildFly、PHP、Python、Ruby及Perl的S2I Builder镜像。但是现实中的需求是多样化的,特殊的应用构建环境需要用户定制S2I的Builder Image来满足。S2I Builder镜像从本质上来说也是一个普通的Docker镜像,只是在镜像中会加入S2..._php 生成s2i防伪码

matlab版大学物理学,MATLAB可视化大学物理学(第2版)-程序员宅基地

文章浏览阅读2.5k次。前言这是一本将大学基础物理和MATLAB相结合的教材。一、物理部分的构思物理部分分为14章,完全按照大学基础物理的内容顺序编排。与一般的大学物理教材相比,物理部前言这是一本将大学基础物理和MATLAB相结合的教材。一、 物理部分的构思物理部分分为14章,完全按照大学基础物理的内容顺序编排。与一般的大学物理教材相比,物理部分的构思有以下创意。1. 各章先列出大学基础物理的基本内容,将其作为主要线索..._matlab 物理学

推荐文章

热门文章

相关标签