import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlockWithAttention(nn.Module):
    """Residual block that applies an attention module to the main branch
    before adding the shortcut connection."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ResidualBlockWithAttention, self).__init__()
        # Only the first conv carries the stride; the second keeps the spatial
        # size, otherwise stride > 1 would downsample twice and the residual
        # add would fail against the once-strided shortcut.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.attention = AttentionModule(out_channels)
        # Identity shortcut, replaced by a 1x1 projection whenever the
        # channel count or spatial size changes.
        self.shortcut = nn.Sequential()
        if in_channels != out_channels or stride != 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.attention(out)       # re-weight features before the residual add
        out = out + self.shortcut(x)    # residual connection
        out = F.relu(out)
        return out
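
The block references an AttentionModule that is not defined in the post. A minimal squeeze-and-excitation-style channel attention would make it runnable; the sketch below is only a plausible placeholder under that assumption (the class name matches the reference above, but the SE structure and the reduction ratio of 16 are my guesses, not the original author's implementation):

class AttentionModule(nn.Module):
    """Hypothetical SE-style channel attention; the original post does not
    show its definition, so this is a placeholder, not the author's code."""
    def __init__(self, channels, reduction=16):
        super(AttentionModule, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)   # squeeze: global average per channel
        self.fc = nn.Sequential(
            nn.Linear(channels, max(channels // reduction, 1)),
            nn.ReLU(inplace=True),
            nn.Linear(max(channels // reduction, 1), channels),
            nn.Sigmoid()                      # per-channel weights in (0, 1)
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        w = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)
        return x * w                          # excite: re-weight each channel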
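A quick smoke test, assuming both definitions above (batch size and feature-map shapes chosen arbitrarily):

block = ResidualBlockWithAttention(in_channels=64, out_channels=128, stride=2)
x = torch.randn(2, 64, 32, 32)
y = block(x)
print(y.shape)  # torch.Size([2, 128, 16, 16]) -- stride 2 halves the spatial size

Note that with stride=2 the shortcut becomes the 1x1 projection branch, so both paths arrive at 128 channels and 16x16 resolution before the add.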