Help needed with the code for "EEG-based emotion recognition with a Regularized Graph Neural Network"
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch_geometric.nn import SGConv, global_add_pool
from torch_scatter import scatter_add
def maybe_num_nodes(index, num_nodes=None):
    return index.max().item() + 1 if num_nodes is None else num_nodes
def add_remaining_self_loops(edge_index,
                             edge_weight=None,
                             fill_value=1,
                             num_nodes=None):
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    row, col = edge_index

    mask = row != col   # non-self-loop edges
    inv_mask = ~mask    # self-loop edges that already exist
    loop_weight = torch.full(
        (num_nodes, ),
        fill_value,
        dtype=None if edge_weight is None else edge_weight.dtype,
        device=edge_index.device)

    if edge_weight is not None:
        assert edge_weight.numel() == edge_index.size(1)
        # keep the weights of self-loops that are already present
        remaining_edge_weight = edge_weight[inv_mask]
        if remaining_edge_weight.numel() > 0:
            loop_weight[row[inv_mask]] = remaining_edge_weight
        edge_weight = torch.cat([edge_weight[mask], loop_weight], dim=0)

    loop_index = torch.arange(0, num_nodes, dtype=row.dtype, device=row.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)
    edge_index = torch.cat([edge_index[:, mask], loop_index], dim=1)

    return edge_index, edge_weight
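# Illustrative example (my own sanity check, not from the repo): with
#   edge_index = [[0, 1], [1, 0]], edge_weight = [0.5, -0.3], num_nodes = 2
# the function returns
#   edge_index = [[0, 1, 0, 1], [1, 0, 0, 1]], edge_weight = [0.5, -0.3, 1.0, 1.0]
# i.e. the two missing self-loops are appended with weight fill_value = 1.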
class NewSGConv(SGConv):
    def __init__(self, num_features, num_classes, K=1, cached=False,
                 bias=True):
        super(NewSGConv, self).__init__(num_features, num_classes, K=K, cached=cached, bias=bias)

    # allow negative edge weights
    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)

        fill_value = 1 if not improved else 2
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)
        row, col = edge_index
        # degrees are computed from |w|, so negative weights are allowed
        deg = scatter_add(torch.abs(edge_weight), row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        # symmetric normalization: D^{-1/2} * W * D^{-1/2}
        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
    def forward(self, x, edge_index, edge_weight=None):
        if not self.cached or self.cached_result is None:
            edge_index, norm = NewSGConv.norm(
                edge_index, x.size(0), edge_weight, dtype=x.dtype)

            # K rounds of propagation: x <- (normalized A) @ x
            for k in range(self.K):
                x = self.propagate(edge_index, x=x, norm=norm)
            self.cached_result = x

        return self.lin(self.cached_result)

    def message(self, x_j, norm):
        # x_j: (batch_size*num_nodes*num_nodes, num_features)
        # norm: (batch_size*num_nodes*num_nodes, )
        return norm.view(-1, 1) * x_j
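# Note: NewSGConv overrides SGConv's normalization so that node degrees are
# computed from abs(edge_weight), which keeps D^{-1/2} well-defined when the
# learned edge weights go negative.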
class ReverseLayerF(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha
        return output, None
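# ReverseLayerF is a gradient reversal layer (RevGrad, as in DANN): identity in
# the forward pass, and in the backward pass it multiplies the incoming gradient
# by -alpha. Typical use (sketch):
#   feat = ReverseLayerF.apply(feat, alpha)
#   domain_logits = domain_classifier(feat)
# so minimizing the domain loss pushes the feature extractor toward
# domain-invariant features.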
class SymSimGCNNet(torch.nn.Module):
    def __init__(self, num_nodes, learn_edge_weight, edge_weight, num_features, num_hiddens, num_classes, K, dropout=0.5, domain_adaptation=""):
        """
        num_nodes: number of nodes in the graph (EEG channels)
        learn_edge_weight: if True, the edge weights are learnable
        edge_weight: initial edge weight matrix, reshapeable to (num_nodes, num_nodes)
        num_features: feature dim for each node/channel
        num_hiddens: hidden dimension (an int; fed to NewSGConv and the final nn.Linear)
        num_classes: number of emotion classes
        K: number of propagation steps of the SGConv layer
        dropout: dropout rate before the final linear layer
        domain_adaptation: "RevGrad" to enable the gradient-reversal domain classifier
        """
        super(SymSimGCNNet, self).__init__()
        self.domain_adaptation = domain_adaptation
        self.num_nodes = num_nodes
        self.xs, self.ys = torch.tril_indices(self.num_nodes, self.num_nodes, offset=0)
        edge_weight = edge_weight.reshape(self.num_nodes, self.num_nodes)[self.xs, self.ys]  # lower-triangular (incl. diagonal) values
        self.edge_weight = nn.Parameter(edge_weight, requires_grad=learn_edge_weight)
        self.dropout = dropout
        self.conv1 = NewSGConv(num_features=num_features, num_classes=num_hiddens, K=K)
        self.fc = nn.Linear(num_hiddens, num_classes)
        if self.domain_adaptation in ["RevGrad"]:
            self.domain_classifier = nn.Linear(num_hiddens, 2)
    def forward(self, data, alpha=0):
        batch_size = len(data.y)
        x, edge_index = data.x, data.edge_index
        # rebuild the symmetric (num_nodes, num_nodes) weight matrix from the
        # learnable lower-triangular parameters
        edge_weight = torch.zeros((self.num_nodes, self.num_nodes), device=edge_index.device)
        edge_weight[self.xs.to(edge_weight.device), self.ys.to(edge_weight.device)] = self.edge_weight
        edge_weight = edge_weight + edge_weight.transpose(1, 0) - torch.diag(edge_weight.diagonal())  # copy values from lower tri to upper tri
        edge_weight = edge_weight.reshape(-1).repeat(batch_size)
        x = F.relu(self.conv1(x, edge_index, edge_weight))

        # domain classification via gradient reversal (only if enabled)
        domain_output = None
        if self.domain_adaptation in ["RevGrad"]:
            reverse_x = ReverseLayerF.apply(x, alpha)
            domain_output = self.domain_classifier(reverse_x)

        x = global_add_pool(x, data.batch, size=batch_size)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc(x)
        return x, domain_output
This is the source code I downloaded, but I can't find the part that loads/extracts the dataset anywhere in it. Could someone take a look and tell me which part is supposed to load the dataset?
I'm a beginner and would just like to run someone else's code to learn from it.
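In the meantime, to at least check that the model itself runs, here is a minimal smoke test I sketched myself (it is NOT from the repo): it feeds random tensors shaped like SEED data (62 channels x 5 band features per sample, 3 classes), builds a fully connected graph per sample, and assumes a torch_geometric 1.x install, since the code relies on SGConv.cached_result, which newer PyG versions removed.

# Smoke test (my own sketch, not part of the original repo)
import torch
from torch_geometric.data import Data, Batch

num_nodes, num_features, num_classes = 62, 5, 3

# dense edge_index over all node pairs, row-major, matching the row-major
# layout of edge_weight.reshape(-1) in SymSimGCNNet.forward
row = torch.arange(num_nodes).repeat_interleave(num_nodes)
col = torch.arange(num_nodes).repeat(num_nodes)
edge_index = torch.stack([row, col], dim=0)

samples = [Data(x=torch.randn(num_nodes, num_features),
                edge_index=edge_index,
                y=torch.tensor([0]))
           for _ in range(4)]
batch = Batch.from_data_list(samples)

model = SymSimGCNNet(num_nodes=num_nodes,
                     learn_edge_weight=True,
                     edge_weight=torch.rand(num_nodes, num_nodes),
                     num_features=num_features,
                     num_hiddens=64,
                     num_classes=num_classes,
                     K=2)
logits, _ = model(batch)
print(logits.shape)  # expected: torch.Size([4, 3])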