|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
class DeeperAutoencoderWithMoreAttention(nn.Module):
def __init__(self):
super(DeeperAutoencoderWithMoreAttention, self).__init__()
# Encoder
self.encoder = nn.Sequential(
ResidualBlockWithAttention(3, 32),
ResidualBlockWithAttention(32, 32),
nn.MaxPool2d(2, stride=2),
ResidualBlockWithAttention(32, 64),
ResidualBlockWithAttention(64, 64),
nn.MaxPool2d(2, stride=2),
ResidualBlockWithAttention(64, 128),
ResidualBlockWithAttention(128, 128),
nn.MaxPool2d(2, stride=2),
ResidualBlockWithAttention(128, 256),
ResidualBlockWithAttention(256, 256),
nn.MaxPool2d(2, stride=2),
ResidualBlockWithAttention(256, 512),
nn.MaxPool2d(2, stride=2),
)
# Decoder
self.decoder = nn.Sequential(
nn.Upsample(scale_factor=2),
ResidualBlockWithAttention(512, 512),
nn.Upsample(scale_factor=2),
ResidualBlockWithAttention(512, 256),
ResidualBlockWithAttention(256, 256),
nn.Upsample(scale_factor=2),
ResidualBlockWithAttention(256, 128),
ResidualBlockWithAttention(128, 128),
nn.Upsample(scale_factor=2),
ResidualBlockWithAttention(128, 64),
ResidualBlockWithAttention(64, 64),
nn.Upsample(scale_factor=2),
ResidualBlockWithAttention(64, 32),
ResidualBlockWithAttention(32, 32),
nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
nn.ReLU()
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
内容有亿点多,请做好准备
以上代码定义了一个名为DeeperAutoencoderWithMoreAttention的深度自动编码器模型。下面是对代码的详细解释:
创建类 DeeperAutoencoderWithMoreAttention 并继承自 nn.Module,这是一个 PyTorch 中的基类,用于定义神经网络模块。
在 __init__ 方法中定义模型的结构。该模型由编码器(encoder)和解码器(decoder)两部分组成。
编码器部分按照顺序连接了多个 ResidualBlockWithAttention 残差块和池化层,用于逐步提取输入数据的特征并减小空间维度。具体结构如下:
输入通道数为 3,输出通道数为 32 的 ResidualBlockWithAttention 残差块。
输出通道数为 32 的 ResidualBlockWithAttention 残差块。
使用 2x2 的最大池化层进行下采样。
输出通道数为 32 的 ResidualBlockWithAttention 残差块。
输出通道数为 64 的 ResidualBlockWithAttention 残差块。
使用 2x2 的最大池化层进行下采样。
输出通道数为 64 的 ResidualBlockWithAttention 残差块。
输出通道数为 128 的 ResidualBlockWithAttention 残差块。
使用 2x2 的最大池化层进行下采样。
输出通道数为 128 的 ResidualBlockWithAttention 残差块。
输出通道数为 256 的 ResidualBlockWithAttention 残差块。
使用 2x2 的最大池化层进行下采样。
输出通道数为 256 的 ResidualBlockWithAttention 残差块。
输出通道数为 512 的 ResidualBlockWithAttention 残差块。
使用 2x2 的最大池化层进行下采样。
解码器部分按照顺序连接了多个 ResidualBlockWithAttention 残差块、上采样层和一个卷积层,用于恢复编码器输出的特征到原始输入图像的尺寸。具体结构如下:
使用 2 倍的上采样因子对输入进行上采样。
输出通道数为 512 的 ResidualBlockWithAttention 残差块。
使用 2 倍的上采样因子对输入进行上采样。
输出通道数为 256 的 ResidualBlockWithAttention 残差块。
输出通道数为 256 的 ResidualBlockWithAttention 残差块。
使用 2 倍的上采样因子对输入进行上采样。
输出通道数为 128 的 ResidualBlockWithAttention 残差块。
输出通道数为 128 的 ResidualBlockWithAttention 残差块。
使用 2 倍的上采样因子对输入进行上采样。
输出通道数为 64 的 ResidualBlockWithAttention 残差块。
输出通道数为 64 的 ResidualBlockWithAttention 残差块。
使用 2 倍的上采样因子对输入进行上采样。
输出通道数为 32 的 ResidualBlockWithAttention 残差块。
输出通道数为 32 的 ResidualBlockWithAttention 残差块。
使用大小为 3x3、步长为 1、填充为 1 的卷积层将通道数为 32 的特征图转换为 3 个通道的图像。
使用 ReLU 激活函数。
forward 方法定义了前向传播过程。输入图像经过编码器和解码器后返回解码器输出的结果。
以下是模块图示意图:
- ----------------------------
- | |
- | DeeperAutoencoder |
- --------------------------------------------- |
- | | |
- | V |
- Input Image ----------- -------------- --------- ----------- Output Image
- --------------->| Encoder |--------->| Decoder |-------->| ReLU |------>| ReLU |------> ------------
- ----------- -------------- --------- -----------
- Encoder:
- ------------------------------------------------------------------------------------------------------------------
- | | | | |
- | ResidualBlockWithAttention | ResidualBlockWithAttention | MaxPool2d | ResidualBlockWithAttention |
- | (3, 32) | (32, 32) | (kernel_size=2, stride=2)| (256, 512) |
- | | | | |
- ------------------------------------------------------------------------------------------------------------------
-
- Decoder:
- ------------------------------------------------------------------------------------------------------------------
- | Upsample | | | |
- | (scale_factor=2) | | | |
- | | | | |
- | ResidualBlockWithAttention (512, 512)| ResidualBlockWithAttention| Conv2d | ReLU |
- | | (64, 128) |(32, 3, kernel_size=3, | |
- | | | stride=1, padding=1)| |
- | | | | |
- ------------------------------------------------------------------------------------------------------------------
复制代码
|
|