|
The LayerNorm layer in the Transformer uses the keepdim parameter when it computes the mean and standard deviation:
import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    "Construct a layernorm module (see citation for details)."

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # learnable scale
        self.b_2 = nn.Parameter(torch.zeros(features))  # learnable shift
        self.eps = eps

    def forward(self, x):
        # keepdim=True keeps the last dimension with size 1, so mean and std
        # broadcast against x in the arithmetic below.
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
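To see why forward() asks for keepdim=True, here is a minimal sketch that compares the two settings on a made-up input. The shape (2, 3, 4) and the variable names are assumptions for illustration only, not part of the original post.

```python
import torch

x = torch.randn(2, 3, 4)  # hypothetical (batch, sequence, features) input

# With keepdim=True the reduced axis stays with size 1, so the result
# broadcasts against x.
mean_keep = x.mean(-1, keepdim=True)   # shape (2, 3, 1)
centered = x - mean_keep               # shape (2, 3, 4)

# Without keepdim the last axis is dropped; shapes (2, 3, 4) and (2, 3)
# do not line up under broadcasting, so the subtraction raises an error.
mean_drop = x.mean(-1)                 # shape (2, 3)
try:
    x - mean_drop
    broadcast_ok = True
except RuntimeError:
    broadcast_ok = False

print(mean_keep.shape, mean_drop.shape, broadcast_ok)
```

Without keepdim=True, the mean would have to be reshaped by hand (for example with unsqueeze(-1)) before it could be subtracted from x.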
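The keepdim parameter makes a reduction operation preserve the number of dimensions of the input: the reduced dimension is kept with size 1 instead of being removed.
Example: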
>>> import torch
>>> a = torch.arange(12).reshape(3, 4)
>>> print(a)
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
>>> print(torch.sum(a, dim=0, keepdim=True))
tensor([[12, 15, 18, 21]])
>>> print(torch.sum(a, dim=0, keepdim=True).shape)
torch.Size([1, 4])
>>> print(torch.sum(a, dim=0, keepdim=False))
tensor([12, 15, 18, 21])
>>> print(torch.sum(a, dim=0, keepdim=False).shape)
torch.Size([4])
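With keepdim=True, the output tensor has the same number of dimensions as the input; the summed dimension is kept with size 1, giving shape [1, 4].
With keepdim=False, the output tensor loses the first dimension, giving shape [4]. |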
|
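As a follow-up, the sketch below checks two things on the same tensor a: that keepdim=True gives the same result as reducing and then putting the axis back with unsqueeze, and that a sum over dim=1 with keepdim=True can be used directly in a broadcast division. The names kept, restored, row_sums and row_fractions are made up for this example.

```python
import torch

a = torch.arange(12).reshape(3, 4)

# keepdim=True matches reducing first and then re-inserting the axis.
kept = torch.sum(a, dim=0, keepdim=True)       # shape (1, 4)
restored = torch.sum(a, dim=0).unsqueeze(0)    # shape (1, 4)
same = torch.equal(kept, restored)

# Reducing along dim=1 keeps a column of row sums with shape (3, 1),
# which broadcasts against a so each row becomes fractions of its total.
row_sums = torch.sum(a, dim=1, keepdim=True)
row_fractions = a / row_sums

print(same, row_sums.shape, row_fractions.sum(dim=1))
```

Each row of row_fractions should sum to roughly 1, which only works without extra reshaping because row_sums kept its second dimension.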