[Learning Notes] Viewing the Input/Output Dimensions of Each Layer of a Model

Posted on 2022-10-15 13:36:47

This post was last edited by Handsome_zhou on 2022-10-15 14:42.

Once a model has been built, instantiate it and you can print out the input/output dimensions of its layers.


import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

Output:
[screenshot: 15-1.jpg]
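For reference (a reconstruction of what the screenshot shows, following PyTorch's standard module repr), print(model) produces a nested summary of the layer definitions along these lines:

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

Note that this shows each layer's configured in_features/out_features, not the shapes of tensors from an actual forward pass.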



import json
import multiprocessing  # Python multiprocessing
import os
import torch
from torch import nn
from d2l import torch as d2l

# Load the pretrained bert.small model.
# Note: load_pretrained_model is a helper defined in the D2L book chapter
# on fine-tuning BERT; it is not shown in this post.
devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model('bert.small', num_hiddens=256, ffn_num_hiddens=512,
                                    num_heads=4, num_layers=2, dropout=0.1, max_len=512,
                                    devices=devices)

class BERTClassifier(nn.Module):
    def __init__(self, bert):
        super(BERTClassifier, self).__init__()
        self.encoder = bert.encoder
        self.hidden = bert.hidden
        self.output = nn.Linear(256, 3)

    def forward(self, inputs):
        tokens_X, segments_X, valid_lens_x = inputs
        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
        # Classify on the representation of the <cls> token (position 0)
        return self.output(self.hidden(encoded_X[:, 0, :]))

net = BERTClassifier(bert).to(devices[0])  # instantiate the model (the original used .to(device), but only `devices` is defined here)
print(net)

Output:
[screenshot: 15-2.jpg]
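print(net) again shows only the static module tree. As a complement (a minimal sketch of my own, not part of the original post), you can register a forward hook on every leaf layer and run one dummy batch; the hooks then print the actual input/output tensor shapes of each layer:

import torch

def print_layer_shapes(model, dummy_input):
    # Register a forward hook on every leaf module; each hook prints the
    # shapes of the tensors entering and leaving that layer.
    hooks = []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf layers only
            def hook(mod, inp, out, name=name):
                print(f"{name:30s} in: {tuple(inp[0].shape)}  out: {tuple(out.shape)}")
            hooks.append(module.register_forward_hook(hook))
    with torch.no_grad():
        model(dummy_input)
    for h in hooks:  # remove hooks so later forward passes stay silent
        h.remove()

# Example with the NeuralNetwork model from the first snippet:
# print_layer_shapes(model, torch.rand(1, 28, 28, device=device))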


Input/output dimensions of each layer of the textCNN model:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math

class textCNN(nn.Module):
    def __init__(self, param):
        super(textCNN, self).__init__()
        ci = 1  # input channel size
        kernel_num = param['kernel_num']  # output channel size
        kernel_size = param['kernel_size']
        vocab_size = param['vocab_size']
        embed_dim = param['embed_dim']
        dropout = param['dropout']
        class_num = param['class_num']
        self.param = param
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=1)
        self.conv11 = nn.Conv2d(ci, kernel_num, (kernel_size[0], embed_dim))
        self.conv12 = nn.Conv2d(ci, kernel_num, (kernel_size[1], embed_dim))
        self.conv13 = nn.Conv2d(ci, kernel_num, (kernel_size[2], embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(len(kernel_size) * kernel_num, class_num)

    def init_embed(self, embed_matrix):
        self.embed.weight = nn.Parameter(torch.Tensor(embed_matrix))

    @staticmethod
    def conv_and_pool(x, conv):
        # x: (batch, 1, sentence_length, embed_dim)
        x = conv(x)
        # x: (batch, kernel_num, H_out, 1)
        x = F.relu(x.squeeze(3))
        # x: (batch, kernel_num, H_out)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        # x: (batch, kernel_num)
        return x

    def forward(self, x):
        # x: (batch, sentence_length)
        x = self.embed(x)
        # x: (batch, sentence_length, embed_dim)
        # TODO init embed matrix with pre-trained
        x = x.unsqueeze(1)
        # x: (batch, 1, sentence_length, embed_dim)
        x1 = self.conv_and_pool(x, self.conv11)  # (batch, kernel_num)
        x2 = self.conv_and_pool(x, self.conv12)  # (batch, kernel_num)
        x3 = self.conv_and_pool(x, self.conv13)  # (batch, kernel_num)
        x = torch.cat((x1, x2, x3), 1)  # (batch, 3 * kernel_num)
        x = self.dropout(x)
        logit = F.log_softmax(self.fc1(x), dim=1)
        return logit

    def init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


textCNN_param = {
    # 'vocab_size': len(word2ind),
    'vocab_size': 180,
    'embed_dim': 60,
    # 'class_num': len(label_w2n),
    'class_num': 190,
    "kernel_num": 16,
    "kernel_size": [3, 4, 5],
    "dropout": 0.5,
}
dataLoader_param = {
    'batch_size': 128,
    'shuffle': True,
}

net = textCNN(textCNN_param)
print(net)

Output:
[screenshot: 15-3.jpg]
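As a quick sanity check (assumed usage, not part of the original post), pushing a dummy batch of token indices through the instantiated network confirms that the output dimensions are (batch_size, class_num):

# Dummy batch: 128 sentences of length 20, token ids drawn from [0, vocab_size)
x = torch.randint(0, textCNN_param['vocab_size'], (dataLoader_param['batch_size'], 20))
logit = net(x)
print(logit.shape)  # torch.Size([128, 190]), i.e. (batch_size, class_num)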