鱼C论坛

 找回密码
 立即注册
查看: 3302|回复: 15

[已解决]基于VGGNet网络的多标签分类任务

[复制链接]
发表于 2023-4-22 10:33:51 | 显示全部楼层 |阅读模式
10鱼币
在基于原项目代码的基础上,根据以下要求做出修改使其满足多标签分类任务,并高亮对应文件的相关代码段注明修改原因:

1.修改softmax激活函数为sigmoid激活函数使其在输出层输出多个标签类别。

2.修改loss函数CrossEntropyLoss为BCELoss。

3.增添评价指标Hamming Loss、Accuracy_exam、Precision_exam、Recall_exam、Fβ_exam(即基于样本的example-based评价指标),并输出这些评价结果。

4.可视化训练时的loss损失,横坐标为epochs,epochs为训练轮数,纵坐标分别为train_loss和val_loss,train_loss为训练过程中的损失,val_loss为验证过程中的损失。

5.可视化预测结果,展示的图片中有输出的多标签类别。

6.当涉及到变量及函数定义时,请使用项目文件中的相关变量及函数定义,并与上下代码段保证顺畅。新定义的变量及函数除外。

将修改好的项目代码文件逐个展示自己的完整内容,并说明每个文件的具体用法。


主要项目代码如下:
model.py:定义了VGG模型的架构,并提供了辅助函数和配置信息。
  1. import torch.nn as nn
  2. import torch

  3. # official pretrain weights
  4. model_urls = {
  5.     'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
  6.     'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
  7.     'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
  8.     'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
  9. }


  10. class VGG(nn.Module):
  11.     def __init__(self, features, num_classes=1000, init_weights=False):
  12.         super(VGG, self).__init__()
  13.         self.features = features
  14.         self.classifier = nn.Sequential(
  15.             nn.Linear(512*7*7, 4096),
  16.             nn.ReLU(True),
  17.             nn.Dropout(p=0.5),
  18.             nn.Linear(4096, 4096),
  19.             nn.ReLU(True),
  20.             nn.Dropout(p=0.5),
  21.             nn.Linear(4096, num_classes)
  22.         )
  23.         if init_weights:
  24.             self._initialize_weights()

  25.     def forward(self, x):
  26.         # N x 3 x 224 x 224
  27.         x = self.features(x)
  28.         # N x 512 x 7 x 7
  29.         x = torch.flatten(x, start_dim=1)
  30.         # N x 512*7*7
  31.         x = self.classifier(x)
  32.         return x

  33.     def _initialize_weights(self):
  34.         for m in self.modules():
  35.             if isinstance(m, nn.Conv2d):
  36.                 # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  37.                 nn.init.xavier_uniform_(m.weight)
  38.                 if m.bias is not None:
  39.                     nn.init.constant_(m.bias, 0)
  40.             elif isinstance(m, nn.Linear):
  41.                 nn.init.xavier_uniform_(m.weight)
  42.                 # nn.init.normal_(m.weight, 0, 0.01)
  43.                 nn.init.constant_(m.bias, 0)


  44. def make_features(cfg: list):
  45.     layers = []
  46.     in_channels = 3
  47.     for v in cfg:
  48.         if v == "M":
  49.             layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
  50.         else:
  51.             conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
  52.             layers += [conv2d, nn.ReLU(True)]
  53.             in_channels = v
  54.     return nn.Sequential(*layers)


  55. cfgs = {
  56.     'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  57.     'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  58.     'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
  59.     'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
  60. }


  61. def vgg(model_name="vgg16", **kwargs):
  62.     assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
  63.     cfg = cfgs[model_name]

  64.     model = VGG(make_features(cfg), **kwargs)
  65.     return model
复制代码

predict.py:基于训练好的VGG模型,对输入的图像进行分类预测,并输出预测结果。
  1. import os
  2. import json

  3. import torch
  4. from PIL import Image
  5. from torchvision import transforms
  6. import matplotlib.pyplot as plt

  7. from model import vgg


  8. def main():
  9.     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  10.     data_transform = transforms.Compose(
  11.         [transforms.Resize((224, 224)),
  12.          transforms.ToTensor(),
  13.          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

  14.     # load image
  15.     img_path = "../tulip.jpg"
  16.     assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
  17.     img = Image.open(img_path)
  18.     plt.imshow(img)
  19.     # [N, C, H, W]
  20.     img = data_transform(img)
  21.     # expand batch dimension
  22.     img = torch.unsqueeze(img, dim=0)

  23.     # read class_indict
  24.     json_path = './class_indices.json'
  25.     assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

  26.     with open(json_path, "r") as f:
  27.         class_indict = json.load(f)
  28.    
  29.     # create model
  30.     model = vgg(model_name="vgg16", num_classes=5).to(device)
  31.     # load model weights
  32.     weights_path = "./vgg16Net.pth"
  33.     assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
  34.     model.load_state_dict(torch.load(weights_path, map_location=device))

  35.     model.eval()
  36.     with torch.no_grad():
  37.         # predict class
  38.         output = torch.squeeze(model(img.to(device))).cpu()
  39.         predict = torch.softmax(output, dim=0)
  40.         predict_cla = torch.argmax(predict).numpy()

  41.     print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
  42.                                                  predict[predict_cla].numpy())
  43.     plt.title(print_res)
  44.     for i in range(len(predict)):
  45.         print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
  46.                                                   predict[i].numpy()))
  47.     plt.show()


  48. if __name__ == '__main__':
  49.     main()
复制代码

train.py:基于给定的数据集,训练VGG模型,最终得到可用于图像分类的模型文件。
  1. import os
  2. import sys
  3. import json

  4. import torch
  5. import torch.nn as nn
  6. from torchvision import transforms, datasets
  7. import torch.optim as optim
  8. from tqdm import tqdm

  9. from model import vgg


  10. def main():
  11.     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  12.     print("using {} device.".format(device))

  13.     data_transform = {
  14.         "train": transforms.Compose([transforms.RandomResizedCrop(224),
  15.                                      transforms.RandomHorizontalFlip(),
  16.                                      transforms.ToTensor(),
  17.                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
  18.         "val": transforms.Compose([transforms.Resize((224, 224)),
  19.                                    transforms.ToTensor(),
  20.                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

  21.     data_root = os.path.abspath(os.path.join(os.getcwd(), "../"))  # get data root path
  22.     image_path = os.path.join(data_root, "data_set", "PlantVillage")  # flower data set path
  23.     assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
  24.     train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
  25.                                          transform=data_transform["train"])
  26.     train_num = len(train_dataset)

  27.     # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
  28.     flower_list = train_dataset.class_to_idx
  29.     cla_dict = dict((val, key) for key, val in flower_list.items())
  30.     # write dict into json file
  31.     json_str = json.dumps(cla_dict, indent=4)
  32.     with open('class_indices.json', 'w') as json_file:
  33.         json_file.write(json_str)

  34.     batch_size = 32
  35.     nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
  36.     print('Using {} dataloader workers every process'.format(nw))

  37.     train_loader = torch.utils.data.DataLoader(train_dataset,
  38.                                                batch_size=batch_size, shuffle=True,
  39.                                                num_workers=nw)

  40.     validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
  41.                                             transform=data_transform["val"])
  42.     val_num = len(validate_dataset)
  43.     validate_loader = torch.utils.data.DataLoader(validate_dataset,
  44.                                                   batch_size=batch_size, shuffle=False,
  45.                                                   num_workers=nw)
  46.     print("using {} images for training, {} images for validation.".format(train_num,
  47.                                                                            val_num))

  48.     # test_data_iter = iter(validate_loader)
  49.     # test_image, test_label = test_data_iter.next()

  50.     model_name = "vgg16"
  51.     net = vgg(model_name=model_name, num_classes=4, init_weights=True)
  52.     net.to(device)
  53.     loss_function = nn.CrossEntropyLoss()
  54.     optimizer = optim.Adam(net.parameters(), lr=0.0001)

  55.     epochs = 30
  56.     best_acc = 0.0
  57.     save_path = './{}Net.pth'.format(model_name)
  58.     train_steps = len(train_loader)
  59.     for epoch in range(epochs):
  60.         # train
  61.         net.train()
  62.         running_loss = 0.0
  63.         train_bar = tqdm(train_loader, file=sys.stdout)
  64.         for step, data in enumerate(train_bar):
  65.             images, labels = data
  66.             optimizer.zero_grad()
  67.             outputs = net(images.to(device))
  68.             loss = loss_function(outputs, labels.to(device))
  69.             loss.backward()
  70.             optimizer.step()

  71.             # print statistics
  72.             running_loss += loss.item()

  73.             train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
  74.                                                                      epochs,
  75.                                                                      loss)

  76.         # validate
  77.         net.eval()
  78.         acc = 0.0  # accumulate accurate number / epoch
  79.         with torch.no_grad():
  80.             val_bar = tqdm(validate_loader, file=sys.stdout)
  81.             for val_data in val_bar:
  82.                 val_images, val_labels = val_data
  83.                 outputs = net(val_images.to(device))
  84.                 predict_y = torch.max(outputs, dim=1)[1]
  85.                 acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

  86.         val_accurate = acc / val_num
  87.         print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
  88.               (epoch + 1, running_loss / train_steps, val_accurate))

  89.         if val_accurate > best_acc:
  90.             best_acc = val_accurate
  91.             torch.save(net.state_dict(), save_path)

  92.     print('Finished Training')


  93. if __name__ == '__main__':
  94.     main()
复制代码
最佳答案
2023-4-22 10:33:52
黎明丿晓小 发表于 2023-4-22 16:42
在运行以下代码时,报错“RuntimeError: Error(s) in loading state_dict for VGG:
        size mismatch for  ...

这个错误是因为您试图将预训练模型的权重(输出层有4个神经元)加载到具有不同输出层结构的模型中(输出层有5个神经元)。这会导致权重和模型的形状不匹配,因此无法加载权重。

要解决这个问题,您需要确保模型的类别数量与预训练模型权重的类别数量一致。在这个例子中,您应该将模型的类别数量改为4:

  1. model = vgg(model_name="vgg16", num_classes=4).to(device)
复制代码


然后,您需要更新`class_indices.json`文件,确保它包含正确的类别数量和类别名称。您可以手动编辑文件或在训练模型时自动生成该文件。

请注意,如果您需要处理5个类别,您需要重新训练模型,以适应5个类别的任务。在这种情况下,您需要使用适当数量的类别重新训练模型,并确保`class_indices.json`文件包含正确的类别信息。
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2023-4-22 11:35:25 | 显示全部楼层
isdkz 发表于 2023-4-22 10:39
在基于原项目代码的基础上,根据要求做出修改以满足多标签分类任务。以下是修改后的代码文件及修改原因:
...

我有以下问题需要你来回答:
1.np.vstack中的np是numpy吗?
2.val_loss是如何定义的?
3.请在基于predict.py代码的基础上,修改部分代码使其满足以下图片要求:
   
   其中,red和dress为多标签预测,后面对应的数字为置信度,也就是概率,可以设置一个阈值threshold=0.5,当预测一个图片的输出结果时,某一个标签的置信度大于阈值即可显示在输出结果图片上面。
   
20190326113925310.png
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2023-4-22 11:46:15 | 显示全部楼层
isdkz 发表于 2023-4-22 11:38
1. 是的,`np`通常是`numpy`库的缩写。在Python程序中,您经常会看到这样的导入语句:`import numpy as n ...

我的意思是val_loss在以下代码是如何定义的:
  1. import os
  2. import sys
  3. import json

  4. import torch
  5. import torch.nn as nn
  6. from torchvision import transforms, datasets
  7. import torch.optim as optim
  8. from tqdm import tqdm
  9. from sklearn.metrics import hamming_loss,accuracy_score,precision_score,recall_score,fbeta_score
  10. import numpy as np
  11. import matplotlib.pyplot as plt

  12. from model import vgg

  13. # Define the evaluation function
  14. def evaluate(y_true, y_pred):
  15.     hamming = hamming_loss(y_true, y_pred)
  16.     accuracy = accuracy_score(y_true, y_pred)
  17.     precision = precision_score(y_true, y_pred, average='micro')
  18.     recall = recall_score(y_true, y_pred, average='micro')
  19.     fbeta = fbeta_score(y_true, y_pred, beta=1, average='micro')

  20.     return hamming, accuracy, precision, recall, fbeta

  21. def main():
  22.     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  23.     print("using {} device.".format(device))

  24.     data_transform = {
  25.         "train": transforms.Compose([transforms.RandomResizedCrop(224),
  26.                                      transforms.RandomHorizontalFlip(),
  27.                                      transforms.ToTensor(),
  28.                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
  29.         "val": transforms.Compose([transforms.Resize((224, 224)),
  30.                                    transforms.ToTensor(),
  31.                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

  32.     data_root = os.path.abspath(os.path.join(os.getcwd(), "../"))  # get data root path
  33.     image_path = os.path.join(data_root, "data_set", "PlantVillage")  # flower data set path
  34.     assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
  35.     train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
  36.                                          transform=data_transform["train"])
  37.     train_num = len(train_dataset)

  38.     # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
  39.     flower_list = train_dataset.class_to_idx
  40.     cla_dict = dict((val, key) for key, val in flower_list.items())
  41.     # write dict into json file
  42.     json_str = json.dumps(cla_dict, indent=4)
  43.     with open('class_indices.json', 'w') as json_file:
  44.         json_file.write(json_str)

  45.     batch_size = 32
  46.     nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
  47.     print('Using {} dataloader workers every process'.format(nw))

  48.     train_loader = torch.utils.data.DataLoader(train_dataset,
  49.                                                batch_size=batch_size, shuffle=True,
  50.                                                num_workers=nw)

  51.     validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
  52.                                             transform=data_transform["val"])
  53.     val_num = len(validate_dataset)
  54.     validate_loader = torch.utils.data.DataLoader(validate_dataset,
  55.                                                   batch_size=batch_size, shuffle=False,
  56.                                                   num_workers=nw)
  57.     print("using {} images for training, {} images for validation.".format(train_num,
  58.                                                                            val_num))

  59.     # test_data_iter = iter(validate_loader)
  60.     # test_image, test_label = test_data_iter.next()

  61.     model_name = "vgg16"
  62.     net = vgg(model_name=model_name, num_classes=4, init_weights=True)
  63.     net.to(device)
  64.     # Change the loss function to BCELoss
  65.     loss_function = nn.BCEWithLogitsLoss()
  66.     optimizer = optim.Adam(net.parameters(), lr=0.0001)

  67.     train_losses = []
  68.     val_losses = []

  69.     # Modify the training and validation loop to calculate evaluation metrics
  70.     epochs = 30
  71.     best_acc = 0.0
  72.     save_path = './{}Net.pth'.format(model_name)
  73.     train_steps = len(train_loader)
  74.     for epoch in range(epochs):
  75.         # train
  76.         net.train()
  77.         running_loss = 0.0
  78.         train_bar = tqdm(train_loader, file=sys.stdout)
  79.         for step, data in enumerate(train_bar):
  80.             images, labels = data
  81.             optimizer.zero_grad()
  82.             outputs = net(images.to(device))
  83.             loss = loss_function(outputs, labels.to(device))
  84.             loss.backward()
  85.             optimizer.step()

  86.             # print statistics
  87.             running_loss += loss.item()

  88.             train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
  89.                                                                      epochs,
  90.                                                                      loss)
  91.             train_losses.append(running_loss / train_steps)

  92.         # validate, Modify the validation part
  93.         net.eval()
  94.         acc = 0.0  # accumulate accurate number / epoch
  95.         val_steps = len(validate_loader)
  96.         val_true_labels = []
  97.         val_pred_labels = []
  98.         with torch.no_grad():
  99.             val_bar = tqdm(validate_loader, file=sys.stdout)
  100.             for val_data in val_bar:
  101.                 val_images, val_labels = val_data
  102.                 outputs = net(val_images.to(device))
  103.                 val_true_labels.append(val_labels.cpu().numpy())
  104.                 val_pred_labels.append((outputs > 0.5).float().cpu().numpy())
  105.                 val_losses.append(val_loss / val_steps)
  106.                 # predict_y = torch.max(outputs, dim=1)[1]
  107.                 # acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
  108.         val_true_labels = np.vstack(val_true_labels)
  109.         val_pred_labels = np.vstack(val_pred_labels)
  110.         hamming, accuracy, precision, recall, fbeta = evaluate(val_true_labels, val_pred_labels)
  111.         print('[epoch %d] train_loss: %.3f val_hamming: %.3f val_accuracy: %.3f val_precision: %.3f val_recall: %.3f')

  112.         # val_accurate = acc / val_num
  113.         # print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
  114.         #       (epoch + 1, running_loss / train_steps, val_accurate))
  115.         #
  116.         # if val_accurate > best_acc:
  117.         #     best_acc = val_accurate
  118.         #     torch.save(net.state_dict(), save_path)
  119.     plt.figure(figsize=(10, 5))
  120.     plt.plot(train_losses, label='Training Loss')
  121.     plt.plot(val_losses, label='Validation Loss')
  122.     plt.xlabel('Epochs')
  123.     plt.ylabel('Loss')
  124.     plt.legend()
  125.     plt.show()

  126.     print('Finished Training')


  127. if __name__ == '__main__':
  128.     main()
复制代码

另外,请在基于以下代码的基础上,修改部分代码使其满足以下图片要求:
其中,red和dress为多标签预测,后面对应的数字为置信度,也就是概率,可以设置一个阈值threshold=0.5,当预测一个图片的输出结果时,某一个标签的置信度大于阈值即可显示在输出结果图片上面。
  
  1. import os
  2. import json

  3. import torch
  4. from PIL import Image
  5. from torchvision import transforms
  6. import matplotlib.pyplot as plt

  7. from model import vgg


  8. def main():
  9.     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  10.     data_transform = transforms.Compose(
  11.         [transforms.Resize((224, 224)),
  12.          transforms.ToTensor(),
  13.          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

  14.     # load image
  15.     img_path = "./test.jpg"
  16.     assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
  17.     img = Image.open(img_path)
  18.     plt.imshow(img)
  19.     # [N, C, H, W]
  20.     img = data_transform(img)
  21.     # expand batch dimension
  22.     img = torch.unsqueeze(img, dim=0)

  23.     # read class_indict
  24.     json_path = './class_indices.json'
  25.     assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

  26.     with open(json_path, "r") as f:
  27.         class_indict = json.load(f)
  28.    
  29.     # create model
  30.     model = vgg(model_name="vgg16", num_classes=5).to(device)
  31.     # load model weights
  32.     weights_path = "./vgg16Net.pth"
  33.     assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
  34.     model.load_state_dict(torch.load(weights_path, map_location=device))

  35.     model.eval()
  36.     with torch.no_grad():
  37.         # predict class
  38.         output = torch.squeeze(model(img.to(device))).cpu()
  39.         predict = torch.softmax(output, dim=0)
  40.         predict_cla = torch.argmax(predict).numpy()

  41.     print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
  42.                                                  predict[predict_cla].numpy())
  43.     plt.title(print_res)
  44.     for i in range(len(predict)):
  45.         print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
  46.                                                   predict[i].numpy()))
  47.     plt.show()


  48. if __name__ == '__main__':
  49.     main()
复制代码
20190326113925310.png
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2023-4-22 12:24:20 | 显示全部楼层
isdkz 发表于 2023-4-22 11:50
首先,针对您提供的第一个代码片段,关于`val_loss`的定义问题。实际上,在这段代码中,并没有显式地定义 ...

在训练以下代码时报错ValueError:Target size (torch.Size([32])) must be the same as input size (torch.Size([32,4]))
  1. import os
  2. import sys
  3. import json

  4. import torch
  5. import torch.nn as nn
  6. from torchvision import transforms, datasets
  7. import torch.optim as optim
  8. from tqdm import tqdm
  9. from sklearn.metrics import hamming_loss,accuracy_score,precision_score,recall_score,fbeta_score
  10. import numpy as np
  11. import matplotlib.pyplot as plt

  12. from model import vgg

  13. # Define the evaluation function
  14. def evaluate(y_true, y_pred):
  15.     hamming = hamming_loss(y_true, y_pred)
  16.     accuracy = accuracy_score(y_true, y_pred)
  17.     precision = precision_score(y_true, y_pred, average='micro')
  18.     recall = recall_score(y_true, y_pred, average='micro')
  19.     fbeta = fbeta_score(y_true, y_pred, beta=1, average='micro')

  20.     return hamming, accuracy, precision, recall, fbeta

  21. def main():
  22.     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  23.     print("using {} device.".format(device))

  24.     data_transform = {
  25.         "train": transforms.Compose([transforms.RandomResizedCrop(224),
  26.                                      transforms.RandomHorizontalFlip(),
  27.                                      transforms.ToTensor(),
  28.                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
  29.         "val": transforms.Compose([transforms.Resize((224, 224)),
  30.                                    transforms.ToTensor(),
  31.                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

  32.     data_root = os.path.abspath(os.path.join(os.getcwd(), "../"))  # get data root path
  33.     image_path = os.path.join(data_root, "data_set", "PlantVillage")  # flower data set path
  34.     assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
  35.     train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
  36.                                          transform=data_transform["train"])
  37.     train_num = len(train_dataset)

  38.     # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
  39.     flower_list = train_dataset.class_to_idx
  40.     cla_dict = dict((val, key) for key, val in flower_list.items())
  41.     # write dict into json file
  42.     json_str = json.dumps(cla_dict, indent=4)
  43.     with open('class_indices.json', 'w') as json_file:
  44.         json_file.write(json_str)

  45.     batch_size = 32
  46.     nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
  47.     print('Using {} dataloader workers every process'.format(nw))

  48.     train_loader = torch.utils.data.DataLoader(train_dataset,
  49.                                                batch_size=batch_size, shuffle=True,
  50.                                                num_workers=nw)

  51.     validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
  52.                                             transform=data_transform["val"])
  53.     val_num = len(validate_dataset)
  54.     validate_loader = torch.utils.data.DataLoader(validate_dataset,
  55.                                                   batch_size=batch_size, shuffle=False,
  56.                                                   num_workers=nw)
  57.     print("using {} images for training, {} images for validation.".format(train_num,
  58.                                                                            val_num))

  59.     # test_data_iter = iter(validate_loader)
  60.     # test_image, test_label = test_data_iter.next()

  61.     model_name = "vgg16"
  62.     net = vgg(model_name=model_name, num_classes=4, init_weights=True)
  63.     net.to(device)
  64.     # Change the loss function to BCELoss
  65.     loss_function = nn.BCEWithLogitsLoss()
  66.     optimizer = optim.Adam(net.parameters(), lr=0.0001)

  67.     train_losses = []
  68.     val_losses = []

  69.     # Modify the training and validation loop to calculate evaluation metrics
  70.     epochs = 30
  71.     best_acc = 0.0
  72.     save_path = './{}Net.pth'.format(model_name)
  73.     train_steps = len(train_loader)
  74.     for epoch in range(epochs):
  75.         # train
  76.         net.train()
  77.         running_loss = 0.0
  78.         train_bar = tqdm(train_loader, file=sys.stdout)
  79.         for step, data in enumerate(train_bar):
  80.             images, labels = data
  81.             optimizer.zero_grad()
  82.             outputs = net(images.to(device))
  83.             loss = loss_function(outputs, labels.to(device))
  84.             loss.backward()
  85.             optimizer.step()

  86.             # print statistics
  87.             running_loss += loss.item()

  88.             train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
  89.                                                                      epochs,
  90.                                                                      loss)
  91.             train_losses.append(running_loss / train_steps)

  92.         # validate, Modify the validation part
  93.         net.eval()
  94.         validate_loss = 0.0
  95.         acc = 0.0  # accumulate accurate number / epoch
  96.         val_steps = len(validate_loader)
  97.         val_true_labels = []
  98.         val_pred_labels = []
  99.         with torch.no_grad():
  100.             val_bar = tqdm(validate_loader, file=sys.stdout)
  101.             for val_data in val_bar:
  102.                 val_images, val_labels = val_data
  103.                 outputs = net(val_images.to(device))
  104.                 val_loss = loss_function(outputs,val_labels.to(device))
  105.                 val_loss.backward()
  106.                 optimizer.step()
  107.                 validate_loss += val_loss.item()
  108.                 val_bar.desc = "val epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
  109.                                                                      epochs,
  110.                                                                      val_loss)
  111.                 val_losses.append(validate_loss / val_steps)
  112.                 val_true_labels.append(val_labels.cpu().numpy())
  113.                 val_pred_labels.append((outputs > 0.5).float().cpu().numpy())

  114.                 # predict_y = torch.max(outputs, dim=1)[1]
  115.                 # acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
  116.         val_true_labels = np.vstack(val_true_labels)
  117.         val_pred_labels = np.vstack(val_pred_labels)
  118.         hamming, accuracy, precision, recall, fbeta = evaluate(val_true_labels, val_pred_labels)
  119.         print('[epoch %d] train_loss: %.3f val_hamming: %.3f val_accuracy: %.3f val_precision: %.3f val_recall: %.3f')

  120.         # val_accurate = acc / val_num
  121.         # print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
  122.         #       (epoch + 1, running_loss / train_steps, val_accurate))
  123.         #
  124.         # if val_accurate > best_acc:
  125.         #     best_acc = val_accurate
  126.         #     torch.save(net.state_dict(), save_path)
  127.     plt.figure(figsize=(10, 5))
  128.     plt.plot(train_losses, label='Training Loss')
  129.     plt.plot(val_losses, label='Validation Loss')
  130.     plt.xlabel('Epochs')
  131.     plt.ylabel('Loss')
  132.     plt.legend()
  133.     plt.show()

  134.     print('Finished Training')


  135. if __name__ == '__main__':
  136.     main()
复制代码

请指出错误原因并更改相应代码段
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2023-4-22 12:31:09 | 显示全部楼层
isdkz 发表于 2023-4-22 12:25
根据您提供的代码,错误出现在处理多标签分类问题时的损失函数计算部分。`nn.BCEWithLogitsLoss`需要将标 ...

应该是“将标签的尺寸更改为与模型输入相同的尺寸”吧,另外,这样改会失去多标签分类任务的功能吗?
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2023-4-22 16:16:04 | 显示全部楼层
本帖最后由 黎明丿晓小 于 2023-4-22 16:17 编辑
isdkz 发表于 2023-4-22 12:43
抱歉,我在描述中的确有误,应该是“将标签的尺寸更改为与模型输出相同的尺寸”。

另外,不用担心,这 ...


运行以下代码时,在训练过程中报出警告“UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use zero_division parameter to control this behavior._warn_prf(average, modifier, msg_start, len(result))”,请解释报出警告的原因,并给出解决办法:
  1. import os
  2. import sys
  3. import json

  4. import torch
  5. import torch.nn as nn
  6. from torchvision import transforms, datasets
  7. import torch.optim as optim
  8. from tqdm import tqdm
  9. from sklearn.metrics import hamming_loss,accuracy_score,precision_score,recall_score,fbeta_score
  10. import numpy as np
  11. import matplotlib.pyplot as plt
  12. import warnings

  13. from model import vgg

  14. # Define the evaluation function
  15. def evaluate(y_true, y_pred):
  16.     hamming = hamming_loss(y_true, y_pred)
  17.     accuracy = accuracy_score(y_true, y_pred)
  18.     precision = precision_score(y_true, y_pred, average='micro')
  19.     recall = recall_score(y_true, y_pred, average='micro')
  20.     fbeta = fbeta_score(y_true, y_pred, beta=1, average='micro')

  21.     return hamming, accuracy, precision, recall, fbeta

  22. def main():
  23.     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  24.     print("using {} device.".format(device))

  25.     data_transform = {
  26.         "train": transforms.Compose([transforms.RandomResizedCrop(224),
  27.                                      transforms.RandomHorizontalFlip(),
  28.                                      transforms.ToTensor(),
  29.                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
  30.         "val": transforms.Compose([transforms.Resize((224, 224)),
  31.                                    transforms.ToTensor(),
  32.                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

  33.     data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
  34.     image_path = os.path.join(data_root, "data_set", "plantvillage_demo1")  # flower data set path
  35.     assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
  36.     train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
  37.                                          transform=data_transform["train"])
  38.     train_num = len(train_dataset)

  39.     # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
  40.     flower_list = train_dataset.class_to_idx
  41.     cla_dict = dict((val, key) for key, val in flower_list.items())
  42.     # write dict into json file
  43.     json_str = json.dumps(cla_dict, indent=4)
  44.     with open('class_indices.json', 'w') as json_file:
  45.         json_file.write(json_str)

  46.     batch_size = 32
  47.     nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
  48.     print('Using {} dataloader workers every process'.format(nw))

  49.     def label_to_onehot(labels, num_classes):
  50.         batch_size = len(labels)
  51.         one_hot = torch.zeros(batch_size, num_classes)
  52.         one_hot[torch.arange(batch_size), labels] = 1
  53.         return one_hot

  54.     train_loader = torch.utils.data.DataLoader(train_dataset,
  55.                                                batch_size=batch_size, shuffle=True,
  56.                                                num_workers=nw)

  57.     validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
  58.                                             transform=data_transform["val"])
  59.     val_num = len(validate_dataset)
  60.     validate_loader = torch.utils.data.DataLoader(validate_dataset,
  61.                                                   batch_size=batch_size, shuffle=False,
  62.                                                   num_workers=nw)
  63.     print("using {} images for training, {} images for validation.".format(train_num,
  64.                                                                            val_num))

  65.     # test_data_iter = iter(validate_loader)
  66.     # test_image, test_label = test_data_iter.next()

  67.     model_name = "vgg16"
  68.     net = vgg(model_name=model_name, num_classes=4, init_weights=True)
  69.     net.to(device)
  70.     # Change the loss function to BCELoss
  71.     loss_function = nn.BCEWithLogitsLoss()
  72.     optimizer = optim.Adam(net.parameters(), lr=0.0001)

  73.     train_losses = []
  74.     #val_losses = []

  75.     # Modify the training and validation loop to calculate evaluation metrics
  76.     #warnings.filterwarnings("ignore")
  77.     epochs = 30
  78.     best_acc = 0.0
  79.     save_path = './{}Net.pth'.format(model_name)
  80.     train_steps = len(train_loader)
  81.     for epoch in range(epochs):
  82.         # train
  83.         net.train()
  84.         running_loss = 0.0
  85.         train_bar = tqdm(train_loader, file=sys.stdout)
  86.         for step, data in enumerate(train_bar):
  87.             images, labels = data
  88.             labels = label_to_onehot(labels, num_classes=4).to(device)
  89.             optimizer.zero_grad()
  90.             outputs = net(images.to(device))
  91.             loss = loss_function(outputs, labels.to(device))
  92.             loss.backward()
  93.             optimizer.step()

  94.             # print statistics
  95.             running_loss += loss.item()

  96.             train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
  97.                                                                      epochs,
  98.                                                                      loss)
  99.             train_losses.append(running_loss / train_steps)

  100.         # validate, Modify the validation part
  101.         net.eval()
  102.         acc = 0.0  # accumulate accurate number / epoch
  103.         val_true_labels = []
  104.         val_pred_labels = []
  105.         with torch.no_grad():
  106.             val_bar = tqdm(validate_loader, file=sys.stdout)
  107.             for val_data in val_bar:
  108.                 val_images, val_labels = val_data
  109.                 #val_labels = label_to_onehot(val_labels, num_classes=4).to(device)
  110.                 outputs = net(val_images.to(device))
  111.                 val_true_labels.append(label_to_onehot(val_labels,num_classes=4).cpu().numpy())
  112.                 val_pred_labels.append((outputs > 0.5).float().cpu().numpy())

  113.                 predict_y = torch.max(outputs, dim=1)[1]
  114.                 acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
  115.         val_true_labels = np.vstack(val_true_labels)
  116.         val_pred_labels = np.vstack(val_pred_labels)
  117.         hamming, accuracy, precision, recall, fbeta = evaluate(val_true_labels, val_pred_labels)
  118.         val_accurate = acc / val_num
  119.         print('[epoch %d] train_loss: %.3f val_hamming: %.3f val_accuracy: %.3f val_precision: %.3f val_recall: %.3f fbeta: %.3f' %
  120.               (epoch + 1, running_loss / train_steps, hamming, val_accurate, precision, recall, fbeta))


  121.         # print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
  122.         #       (epoch + 1, running_loss / train_steps, val_accurate))

  123.         if val_accurate > best_acc:
  124.             best_acc = val_accurate
  125.             torch.save(net.state_dict(), save_path)
  126.     plt.figure(figsize=(10, 5))
  127.     plt.plot(train_losses, label='Training Loss')
  128.     plt.plot(val_accurate, label='Validation Accuracy')
  129.     plt.xlabel('Epochs')
  130.     plt.ylabel('Loss/Accuracy')
  131.     plt.legend()
  132.     plt.show()

  133.     print('Finished Training')


  134. if __name__ == '__main__':
  135.     main()
复制代码
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2023-4-22 16:42:56 | 显示全部楼层
isdkz 发表于 2023-4-22 16:21
这个警告出现的原因是,在计算precision(精确度)时,有些标签在预测结果中没有出现,导致在计算这些标 ...

在运行以下代码时,报错“RuntimeError: Error(s) in loading state_dict for VGG:
        size mismatch for classifier.6.weight: copying a param with shape torch.Size([4, 4096]) from checkpoint, the shape in current model is torch.Size([5, 4096]).
        size mismatch for classifier.6.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([5]).”
请给出错误原因以及修改意见
  1. import os
  2. import json

  3. import torch
  4. from PIL import Image
  5. from torchvision import transforms
  6. import matplotlib.pyplot as plt

  7. from model import vgg


  8. def main():
  9.     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  10.     data_transform = transforms.Compose(
  11.         [transforms.Resize((224, 224)),
  12.          transforms.ToTensor(),
  13.          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

  14.     # load image
  15.     img_path = "./test.jpg"
  16.     assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
  17.     img = Image.open(img_path)
  18.     plt.imshow(img)
  19.     # [N, C, H, W]
  20.     img = data_transform(img)
  21.     # expand batch dimension
  22.     img = torch.unsqueeze(img, dim=0)

  23.     # read class_indict
  24.     json_path = './class_indices.json'
  25.     assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

  26.     with open(json_path, "r") as f:
  27.         class_indict = json.load(f)
  28.    
  29.     # create model
  30.     model = vgg(model_name="vgg16", num_classes=5).to(device)
  31.     # load model weights
  32.     weights_path = "./vgg16Net.pth"
  33.     assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
  34.     model.load_state_dict(torch.load(weights_path, map_location=device))

  35.     model.eval()
  36.     threshold = 0.5
  37.     with torch.no_grad():
  38.         # predict class
  39.         output = torch.squeeze(model(img.to(device))).cpu()
  40.         probabilities = torch.sigmoid(output)
  41.         predicted_labels = (probabilities > threshold).float().numpy()
  42.         label_confidences = {class_indict[str(i)]: prob.item() for i, prob in enumerate(probabilities) if
  43.                              predicted_labels[i] == 1}

  44.     for label, confidence in label_confidences.items():
  45.         print("Label: {}    Confidence: {:.3f}".format(label, confidence))

  46.     plt.title("\n".join("{}: {:.3f}".format(label, confidence) for label, confidence in label_confidences.items()))
  47.     plt.show()
  48.     #     output = torch.squeeze(model(img.to(device))).cpu()
  49.     #     predict = torch.softmax(output, dim=0)
  50.     #     predict_cla = torch.argmax(predict).numpy()
  51.     #
  52.     # print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
  53.     #                                              predict[predict_cla].numpy())
  54.     # plt.title(print_res)
  55.     # for i in range(len(predict)):
  56.     #     print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
  57.     #                                               predict[i].numpy()))
  58.     # plt.show()


  59. if __name__ == '__main__':
  60.     main()
复制代码
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2023-4-22 17:15:56 | 显示全部楼层
黎明丿晓小 发表于 2023-4-22 16:16
运行以下代码时,在训练过程中报出警告“UndefinedMetricWarning: Precision is ill-defined and being ...

在运行此代码的过程中,hamming,  precision, recall, fbeta的输出分别一直为0.250,0.000,0.000,0.000。请说明原因并给出解决方法
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2023-4-22 18:37:20 | 显示全部楼层
isdkz 发表于 2023-4-22 16:45
这个错误是因为您试图将预训练模型的权重(输出层有4个神经元)加载到具有不同输出层结构的模型中(输出 ...

请修改以下代码的数据读取方式以满足多标签分类任务的数据读取方式:
  1. import os
  2. import sys
  3. import json

  4. import torch
  5. import torch.nn as nn
  6. from torchvision import transforms, datasets
  7. import torch.optim as optim
  8. from tqdm import tqdm
  9. from sklearn.metrics import hamming_loss,accuracy_score,precision_score,recall_score,fbeta_score
  10. import numpy as np
  11. import matplotlib.pyplot as plt
  12. import warnings

  13. from model import vgg

  14. # Define the evaluation function
  15. def evaluate(y_true, y_pred):
  16.     hamming = hamming_loss(y_true, y_pred)
  17.     accuracy = accuracy_score(y_true, y_pred)
  18.     precision = precision_score(y_true, y_pred, average='micro')
  19.     recall = recall_score(y_true, y_pred, average='micro')
  20.     fbeta = fbeta_score(y_true, y_pred, beta=1, average='micro')

  21.     return hamming, accuracy, precision, recall, fbeta

  22. def main():
  23.     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  24.     print("using {} device.".format(device))

  25.     data_transform = {
  26.         "train": transforms.Compose([transforms.RandomResizedCrop(224),
  27.                                      transforms.RandomHorizontalFlip(),
  28.                                      transforms.ToTensor(),
  29.                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
  30.         "val": transforms.Compose([transforms.Resize((224, 224)),
  31.                                    transforms.ToTensor(),
  32.                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

  33.     data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
  34.     image_path = os.path.join(data_root, "data_set", "plantvillage_demo1")  # flower data set path
  35.     assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
  36.     train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
  37.                                          transform=data_transform["train"])
  38.     train_num = len(train_dataset)

  39.     # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
  40.     flower_list = train_dataset.class_to_idx
  41.     cla_dict = dict((val, key) for key, val in flower_list.items())
  42.     # write dict into json file
  43.     json_str = json.dumps(cla_dict, indent=4)
  44.     with open('class_indices.json', 'w') as json_file:
  45.         json_file.write(json_str)

  46.     batch_size = 32
  47.     nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
  48.     print('Using {} dataloader workers every process'.format(nw))

  49.     def label_to_onehot(labels, num_classes):
  50.         batch_size = len(labels)
  51.         one_hot = torch.zeros(batch_size, num_classes)
  52.         one_hot[torch.arange(batch_size), labels] = 1
  53.         return one_hot

  54.     train_loader = torch.utils.data.DataLoader(train_dataset,
  55.                                                batch_size=batch_size, shuffle=True,
  56.                                                num_workers=nw)

  57.     validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
  58.                                             transform=data_transform["val"])
  59.     val_num = len(validate_dataset)
  60.     validate_loader = torch.utils.data.DataLoader(validate_dataset,
  61.                                                   batch_size=batch_size, shuffle=False,
  62.                                                   num_workers=nw)
  63.     print("using {} images for training, {} images for validation.".format(train_num,
  64.                                                                            val_num))

  65.     # test_data_iter = iter(validate_loader)
  66.     # test_image, test_label = test_data_iter.next()

  67.     model_name = "vgg16"
  68.     net = vgg(model_name=model_name, num_classes=4, init_weights=True)
  69.     net.to(device)
  70.     # Change the loss function to BCELoss
  71.     loss_function = nn.BCEWithLogitsLoss()
  72.     optimizer = optim.Adam(net.parameters(), lr=0.0001)

  73.     train_losses = []
  74.     #val_losses = []

  75.     # Modify the training and validation loop to calculate evaluation metrics
  76.     warnings.filterwarnings("ignore")
  77.     epochs = 30
  78.     best_acc = 0.0
  79.     save_path = './{}Net.pth'.format(model_name)
  80.     train_steps = len(train_loader)
  81.     for epoch in range(epochs):
  82.         # train
  83.         net.train()
  84.         running_loss = 0.0
  85.         train_bar = tqdm(train_loader, file=sys.stdout)
  86.         for step, data in enumerate(train_bar):
  87.             images, labels = data
  88.             labels = label_to_onehot(labels, num_classes=4).to(device)
  89.             optimizer.zero_grad()
  90.             outputs = net(images.to(device))
  91.             loss = loss_function(outputs, labels.to(device))
  92.             loss.backward()
  93.             optimizer.step()

  94.             # print statistics
  95.             running_loss += loss.item()

  96.             train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
  97.                                                                      epochs,
  98.                                                                      loss)
  99.             train_losses.append(running_loss / train_steps)

  100.         # validate, Modify the validation part
  101.         net.eval()
  102.         acc = 0.0  # accumulate accurate number / epoch
  103.         val_true_labels = []
  104.         val_pred_labels = []
  105.         with torch.no_grad():
  106.             val_bar = tqdm(validate_loader, file=sys.stdout)
  107.             for val_data in val_bar:
  108.                 val_images, val_labels = val_data
  109.                 #val_labels = label_to_onehot(val_labels, num_classes=4).to(device)
  110.                 outputs = net(val_images.to(device))
  111.                 val_true_labels.append(label_to_onehot(val_labels,num_classes=4).cpu().numpy())
  112.                 val_pred_labels.append((outputs > 0.5).float().cpu().numpy())

  113.                 predict_y = torch.max(outputs, dim=1)[1]
  114.                 acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
  115.         val_true_labels = np.vstack(val_true_labels)
  116.         val_pred_labels = np.vstack(val_pred_labels)
  117.         hamming, accuracy, precision, recall, fbeta = evaluate(val_true_labels, val_pred_labels)
  118.         val_accurate = acc / val_num
  119.         print('[epoch %d] train_loss: %.3f val_hamming: %.3f val_accuracy: %.3f val_precision: %.3f val_recall: %.3f fbeta: %.3f' %
  120.               (epoch + 1, running_loss / train_steps, hamming, val_accurate, precision, recall, fbeta))


  121.         # print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
  122.         #       (epoch + 1, running_loss / train_steps, val_accurate))

  123.         if val_accurate > best_acc:
  124.             best_acc = val_accurate
  125.             torch.save(net.state_dict(), save_path)
  126.     plt.figure(figsize=(10, 5))
  127.     plt.plot(train_losses, label='Training Loss')
  128.     plt.plot(val_accurate, label='Validation Accuracy')
  129.     plt.xlabel('Epochs')
  130.     plt.ylabel('Loss/Accuracy')
  131.     plt.legend()
  132.     plt.show()

  133.     print('Finished Training')


  134. if __name__ == '__main__':
  135.     main()
复制代码
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-10-21 03:17

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表