鱼C论坛

 找回密码
 立即注册
查看: 538|回复: 0

关于一个“基于EEG的情绪识别 正则化图神经网络”代码求助

[复制链接]
发表于 2022-10-27 17:11:07 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.autograd import Function
  5. from torch_geometric.nn import SGConv, global_add_pool
  6. from torch_scatter import scatter_add


  7. def maybe_num_nodes(index, num_nodes=None):
  8.     return index.max().item() + 1 if num_nodes is None else num_nodes


  9. def add_remaining_self_loops(edge_index,
  10.                              edge_weight=None,
  11.                              fill_value=1,
  12.                              num_nodes=None):
  13.     num_nodes = maybe_num_nodes(edge_index, num_nodes)
  14.     row, col = edge_index

  15.     mask = row != col
  16.     inv_mask = 1 - mask
  17.     loop_weight = torch.full(
  18.         (num_nodes, ),
  19.         fill_value,
  20.         dtype=None if edge_weight is None else edge_weight.dtype,
  21.         device=edge_index.device)

  22.     if edge_weight is not None:
  23.         assert edge_weight.numel() == edge_index.size(1)
  24.         remaining_edge_weight = edge_weight[inv_mask]
  25.         if remaining_edge_weight.numel() > 0:
  26.             loop_weight[row[inv_mask]] = remaining_edge_weight
  27.         edge_weight = torch.cat([edge_weight[mask], loop_weight], dim=0)

  28.     loop_index = torch.arange(0, num_nodes, dtype=row.dtype, device=row.device)
  29.     loop_index = loop_index.unsqueeze(0).repeat(2, 1)
  30.     edge_index = torch.cat([edge_index[:, mask], loop_index], dim=1)

  31.     return edge_index, edge_weight


  32. class NewSGConv(SGConv):
  33.     def __init__(self, num_features, num_classes, K=1, cached=False,
  34.                  bias=True):
  35.         super(NewSGConv, self).__init__(num_features, num_classes, K=K, cached=cached, bias=bias)

  36.     # allow negative edge weights
  37.     @staticmethod
  38.     def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
  39.         if edge_weight is None:
  40.             edge_weight = torch.ones((edge_index.size(1), ),
  41.                                      dtype=dtype,
  42.                                      device=edge_index.device)

  43.         fill_value = 1 if not improved else 2
  44.         edge_index, edge_weight = add_remaining_self_loops(
  45.             edge_index, edge_weight, fill_value, num_nodes)
  46.         row, col = edge_index
  47.         deg = scatter_add(torch.abs(edge_weight), row, dim=0, dim_size=num_nodes)
  48.         deg_inv_sqrt = deg.pow(-0.5)
  49.         deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

  50.         return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

  51.     def forward(self, x, edge_index, edge_weight=None):
  52.         """"""
  53.         if not self.cached or self.cached_result is None:
  54.             edge_index, norm = NewSGConv.norm(
  55.                 edge_index, x.size(0), edge_weight, dtype=x.dtype)

  56.             for k in range(self.K):
  57.                 x = self.propagate(edge_index, x=x, norm=norm)
  58.             self.cached_result = x

  59.         return self.lin(self.cached_result)

  60.     def message(self, x_j, norm):
  61.         # x_j: (batch_size*num_nodes*num_nodes, num_features)
  62.         # norm: (batch_size*num_nodes*num_nodes, )
  63.         return norm.view(-1, 1) * x_j


  64. class ReverseLayerF(Function):
  65.     @staticmethod
  66.     def forward(ctx, x, alpha):
  67.         ctx.alpha = alpha
  68.         return x.view_as(x)

  69.     @staticmethod
  70.     def backward(ctx, grad_output):
  71.         output = grad_output.neg() * ctx.alpha
  72.         return output, None


  73. class SymSimGCNNet(torch.nn.Module):
  74.     def __init__(self, num_nodes, learn_edge_weight, edge_weight, num_features, num_hiddens, num_classes, K, dropout=0.5, domain_adaptation=""):
  75.         """
  76.             num_nodes: number of nodes in the graph
  77.             learn_edge_weight: if True, the edge_weight is learnable
  78.             edge_weight: initial edge matrix
  79.             num_features: feature dim for each node/channel
  80.             num_hiddens: a tuple of hidden dimensions
  81.             num_classes: number of emotion classes
  82.             K: number of layers
  83.             dropout: dropout rate in final linear layer
  84.             domain_adaptation: RevGrad
  85.         """
  86.         super(SymSimGCNNet, self).__init__()
  87.         self.domain_adaptation = domain_adaptation
  88.         self.num_nodes = num_nodes
  89.         self.xs, self.ys = torch.tril_indices(self.num_nodes, self.num_nodes, offset=0)
  90.         edge_weight = edge_weight.reshape(self.num_nodes, self.num_nodes)[self.xs, self.ys] # strict lower triangular values
  91.         self.edge_weight = nn.Parameter(edge_weight, requires_grad=learn_edge_weight)
  92.         self.dropout = dropout
  93.         self.conv1 = NewSGConv(num_features=num_features, num_classes=num_hiddens[0], K=K)
  94.         self.fc = nn.Linear(num_hiddens[0], num_classes)
  95.         if self.domain_adaptation in ["RevGrad"]:
  96.             self.domain_classifier = nn.Linear(num_hiddens[0], 2)

  97.     def forward(self, data, alpha=0):
  98.         batch_size = len(data.y)
  99.         x, edge_index = data.x, data.edge_index
  100.         edge_weight = torch.zeros((self.num_nodes, self.num_nodes), device=edge_index.device)
  101.         edge_weight[self.xs.to(edge_weight.device), self.ys.to(edge_weight.device)] = self.edge_weight
  102.         edge_weight = edge_weight + edge_weight.transpose(1,0) - torch.diag(edge_weight.diagonal()) # copy values from lower tri to upper tri
  103.         edge_weight = edge_weight.reshape(-1).repeat(batch_size)
  104.         x = F.relu(self.conv1(x, edge_index, edge_weight))
  105.         
  106.         # domain classification
  107.         domain_output = None
  108.         if self.domain_adaptation in ["RevGrad"]:
  109.             reverse_x = ReverseLayerF.apply(x, alpha)
  110.             domain_output = self.domain_classifier(reverse_x)
  111.         x = global_add_pool(x, data.batch, size=batch_size)
  112.         x = F.dropout(x, p=self.dropout, training=self.training)
  113.         x = self.fc(x)
  114.         return x, domain_output
复制代码


这是我下载的源码，但我在里面找不到提取（加载）数据集的代码。有没有大佬能帮忙看看，这里哪一部分是提取数据集的代码？
因为我是新手，想跑一跑别人的代码来学习一下。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-3-29 07:22

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表