Understanding the keepdim parameter through LayerNorm
The LayerNorm layer in the Transformer uses the keepdim parameter when computing the mean and standard deviation:

class LayerNorm(nn.Module):
    "Construct a layernorm module (see citation for details)."
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
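The module can be exercised with a small sketch (the (batch, seq_len, features) input shape below is an assumption for illustration). Because mean and std are reduced with keepdim=True, they keep a trailing size-1 dimension and broadcast cleanly against x:

```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    "Construct a layernorm module (see citation for details)."
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)  # shape (batch, seq_len, 1)
        std = x.std(-1, keepdim=True)    # shape (batch, seq_len, 1)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

ln = LayerNorm(8)
x = torch.randn(2, 5, 8)  # assumed (batch, seq_len, features) input
out = ln(x)
print(out.shape)          # torch.Size([2, 5, 8]) -- input shape is preserved
```

With the default parameters (a_2 all ones, b_2 all zeros), each position's features are normalized to roughly zero mean.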
The keepdim parameter controls whether a reduction operation (sum, mean, std, ...) retains the reduced dimension as size 1, so that the output keeps the same number of dimensions as the input.
Example:

>>> a = torch.arange(12).reshape(3, 4)
>>> print(a)
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
>>> print(torch.sum(a, dim=0, keepdim=True))
tensor([[12, 15, 18, 21]])
>>> print(torch.sum(a, dim=0, keepdim=True).shape)
torch.Size([1, 4])
>>> print(torch.sum(a, dim=0, keepdim=False))
tensor([12, 15, 18, 21])
>>> print(torch.sum(a, dim=0, keepdim=False).shape)
torch.Size([4])

With keepdim=True, the output tensor keeps the same number of dimensions as the input: the summed dimension is retained with size 1.
With keepdim=False, the reduced dimension is dropped entirely, and the output loses its first dimension.
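This is exactly why LayerNorm needs keepdim=True: the per-position statistics must broadcast back against x. A minimal sketch (the input shape here is an assumption) shows what goes wrong without it:

```python
import torch

x = torch.randn(2, 3, 4)  # assumed (batch, seq_len, features) input

# keepdim=True retains the reduced axis as size 1, so the result
# broadcasts back against x -- this is what LayerNorm relies on.
mean_kept = x.mean(-1, keepdim=True)   # shape (2, 3, 1)
print((x - mean_kept).shape)           # torch.Size([2, 3, 4])

# keepdim=False drops the axis; the trailing dimensions no longer align.
mean_dropped = x.mean(-1)              # shape (2, 3)
try:
    _ = x - mean_dropped
except RuntimeError as e:
    print("broadcasting failed:", e)
```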