Handsome_zhou, posted 2023-8-8 11:42:03

Understanding the keepdim parameter through LayerNorm

The LayerNorm layer of the Transformer uses the keepdim parameter when computing the mean and standard deviation:
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    "Construct a layernorm module (see citation for details)."
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))  # fixed: nn.Parameter, not nn.parameter
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)  # mean over the feature dim, shape (..., 1)
        std = x.std(-1, keepdim=True)    # std over the feature dim, shape (..., 1)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
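As a quick sanity check, the module can be applied to a batch of embeddings: the output has the same shape as the input, and each position's feature vector is normalized along the last dimension. The shapes below (batch 2, sequence length 5, model dimension 8) are illustrative choices, not values from the original post:

```python
import torch
import torch.nn as nn

# Minimal copy of the LayerNorm module above, repeated here so the snippet runs on its own.
class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

ln = LayerNorm(features=8)
x = torch.randn(2, 5, 8)       # (batch, seq_len, d_model)
y = ln(x)
print(y.shape)                 # torch.Size([2, 5, 8]) -- shape is unchanged
print(y.mean(-1).abs().max())  # per-position means are close to 0 after normalization
```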

The role of the keepdim parameter is to keep the reduced dimension in the output (with size 1), so a reduction operation does not change the number of dimensions of the tensor.

Example:
>>> a = torch.arange(12).reshape(3, 4)
>>> print(a)
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

>>> print(torch.sum(a, dim=0, keepdim=True))
>>> print(torch.sum(a, dim=0, keepdim=True).shape)
tensor([[12, 15, 18, 21]])
torch.Size([1, 4])

>>> print(torch.sum(a, dim=0, keepdim=False))
>>> print(torch.sum(a, dim=0, keepdim=False).shape)
tensor([12, 15, 18, 21])
torch.Size([4])

With keepdim=True, the summed-out dimension is kept as size 1, so the output is still 2-D.
With keepdim=False, the output drops dimension 0 and becomes 1-D.
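This is exactly why the LayerNorm code above passes keepdim=True: a mean of shape (3, 1) broadcasts cleanly against an input of shape (3, 4), while a mean of shape (3,) does not. A minimal sketch of the difference:

```python
import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# keepdim=True: mean has shape (3, 1) and broadcasts against (3, 4).
mean_kept = x.mean(-1, keepdim=True)
print((x - mean_kept).shape)  # torch.Size([3, 4])

# keepdim=False: mean has shape (3,), which cannot broadcast against (3, 4),
# because broadcasting aligns trailing dimensions (4 vs. 3 mismatch).
mean_dropped = x.mean(-1)
try:
    _ = x - mean_dropped
except RuntimeError as e:
    print("broadcast failed:", e)
```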