鱼C论坛

 找回密码
 立即注册
查看: 2257|回复: 0

[学习笔记] 从LayerNorm中看keepdim参数的作用

[复制链接]
发表于 2023-8-8 11:42:03 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
Transformer的LayerNorm层中做均值计算等运算时用到了keepdim参数:
  1. class LayerNorm(nn.Module):
  2.     "Construct a layernorm module (see citation for details)."
  3.     def __init__(self, features, eps=1e-6):
  4.         super(LayerNorm, self).__init__()
  5.         self.a_2 = nn.Parameter(torch.ones(features))
  6.         self.b_2 = nn.parameter(torch.zeros(features))
  7.         self.eps = eps
  8.    
  9.     def forward(self, x):
  10.         mean = x.mean(-1, keepdim=True)
  11.         std = x.std(-1, keepdim=True)
  12.         return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
复制代码


keepdim参数的作用是可以使得在做归并操作时原来的数据维度不变。

示例:
  1. >>>a = torch.arange(12).reshape(3, 4)
  2. >>>print(a)
  3. tensor([[ 0,  1,  2,  3],
  4.         [ 4,  5,  6,  7],
  5.         [ 8,  9, 10, 11]])

  6. >>>print(torch.sum(a, dim=0, keepdim=True))
  7. >>>print(torch.sum(a, dim=0, keepdim=True).shape)
  8. tensor([[12, 15, 18, 21]])
  9. torch.Size([1, 4])

  10. >>>print(torch.sum(a, dim=0, keepdim=False))
  11. >>>print(torch.sum(a, dim=0, keepdim=False).shape)
  12. tensor([12, 15, 18, 21])
  13. torch.Size([4])
复制代码


使用 keepdim=True后,输出张量的维度不变。
keepdim = False后,输出张量丢失了第一个维度。
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-6-10 21:05

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表