鱼C论坛

 找回密码
 立即注册
查看: 1874|回复: 0

读取图像文件及其对应标签信息问题

[复制链接]
发表于 2023-6-16 16:44:15 | 显示全部楼层 |阅读模式
10鱼币
代码如下:
  1. import torch
  2. import torch.nn as nn
  3. from torch.utils.data import DataLoader
  4. from torchvision import transforms
  5. from model import VGG16
  6. import matplotlib.pyplot as plt
  7. import os
  8. from PIL import Image

  9. # Seed torch's random number generator with a fixed value
  10. torch.manual_seed(0)

  11. # Training hyperparameters and dataset root used by the code below
  12. lr = 0.0001  # learning rate passed to the Adam optimizer in main()
  13. num_epochs = 20  # number of train/validation passes run by main()
  14. batch_size = 8  # batch size of both the train and validation DataLoader
  15. data_path = 'D:/Backup/datasets/insect_multi-label'  # root that load_image_information() joins image paths onto

  16. # Define image loader function
  17. def load_image_information(path):
  18.     # Image root directory
  19.     image_root_dir = data_path
  20.     # Get image path
  21.     image_dir = os.path.join(image_root_dir, path)
  22.     # Open image in RGB format
  23.     # PyTorch DataLoader expects images to be opened using PIL
  24.     # It's recommended to use this method to read images
  25.     # When reading grayscale images, use convert('L')
  26.     return Image.open(image_dir).convert('RGB')

  27. # Define custom dataset class
  28. class CustomDataset(nn.Module):
  29.     def __init__(self, txt_file, transform=None, loader=None):
  30.         super(CustomDataset, self).__init__()
  31.         self.transform = transform
  32.         self.loader = loader

  33.         self.images = []
  34.         self.labels = []
  35.         with open(txt_file, 'r') as file:
  36.             for line in file:
  37.                 #line = line.strip().split()
  38.                 line.strip('\n')
  39.                 line.rstrip()
  40.                 line = line.split()
  41.                 image_path = line[0]
  42.                 label = [float(x) for x in line[1:len(line)]]
  43.                 self.images.append(image_path)
  44.                 self.labels.append(label)

  45.     def __len__(self):
  46.         return len(self.images)

  47.     def __getitem__(self, index):
  48.         image_path = self.images[index]
  49.         label = self.labels[index]

  50.         image = self.loader(image_path)

  51.         if self.transform is not None:
  52.             image = self.transform(image)

  53.         return image, torch.FloatTensor(label)

  54. def main():
  55.     # Data processing
  56.     transform_train = transforms.Compose([transforms.RandomResizedCrop(224),
  57.                                           transforms.RandomHorizontalFlip(),
  58.                                           transforms.ToTensor()])
  59.     transform_val = transforms.Compose([transforms.Resize(256),
  60.                                         transforms.CenterCrop(224),
  61.                                         transforms.ToTensor()])

  62.     train_dataset = CustomDataset(r'D:/Backup/datasets/insect_multi-label/train/train.txt',
  63.                                   transform=transform_train, loader=load_image_information)
  64.     train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

  65.     val_dataset = CustomDataset(r'D:/Backup/datasets/insect_multi-label/val/val.txt',
  66.                                 transform=transform_val, loader=load_image_information)
  67.     val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

  68.     # Initialize model
  69.     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  70.     model = VGG16(num_classes=len(train_dataset.labels[0])).to(device)

  71.     # Define loss function and optimizer
  72.     criterion = nn.BCELoss()
  73.     optimizer = torch.optim.Adam(model.parameters(), lr=lr)

  74.     # Train model
  75.     train_losses = []
  76.     val_losses = []
  77.     for epoch in range(num_epochs):
  78.         # Training
  79.         train_loss = 0.0
  80.         model.train()
  81.         for images, labels in train_loader:
  82.             images, labels = images.to(device), labels.to(device)
  83.             optimizer.zero_grad()

  84.             # Forward pass
  85.             outputs = model(images)
  86.             loss = criterion(outputs, labels)

  87.             # Backward pass and optimize
  88.             loss.backward()
  89.             optimizer.step()

  90.             train_loss += loss.item() * images.size(0)
  91.         train_loss = train_loss / len(train_loader.dataset)
  92.         train_losses.append(train_loss)

  93.         # Validation
  94.         val_loss = 0.0
  95.         model.eval()
  96.         with torch.no_grad():
  97.             for images, labels in val_loader:
  98.                 images, labels = images.to(device), labels.to(device)
  99.                 outputs = model(images)
  100.                 loss = criterion(outputs, labels)
  101.                 val_loss += loss.item() * images.size(0)
  102.             val_loss = val_loss / len(val_loader.dataset)
  103.             val_losses.append(val_loss)

  104.         print('Epoch [{}/{}], Train Loss: {:.4f}, Val Loss: {:.4f}'.
  105.               format(epoch + 1, num_epochs, train_loss, val_loss))

  106.     # Save the trained model
  107.     torch.save(model.state_dict(), 'vgg16_multi_label_classification.pth')

  108.     # Visualize the training process
  109.     plt.plot(range(num_epochs), train_losses, '-b', label='train')
  110.     plt.plot(range(num_epochs), val_losses, '-r', label='validation')
  111.     plt.legend(loc='lower right')
  112.     plt.xlabel('epoch')
  113.     plt.ylabel('loss')
  114.     plt.show()


  115. if __name__ == '__main__':  # run training only when executed as a script
  116.     main()
复制代码


运行此代码时报错:
Traceback (most recent call last):
  File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 238, in <module>
    main()
  File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 172, in main
    transform=transform_train, loader=load_image_information)
  File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 144, in __init__
    label = [float(x) for x in line[1:len(line)]]
  File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 144, in <listcomp>
    label = [float(x) for x in line[1:len(line)]]
ValueError: could not convert string to float: 'photos.jpg'


当选择断点为label时,debug后结果如图:可以看到label的输出结果为[1.0 0.0 1.0],转换float形式的原因是BCELoss计算时用float形式。
但是,当选择断点为transform=transform_train时,debug结果表明label = [float(x) for x in line[1:len(line)]]这行代码中的x读取为'photos.jpg'(如图)。按代码逻辑,x应该从第一个标签1.0开始读入,但实际上图像文件名的一部分也被读了进去,所以无法转换为float。
比较奇怪,不知道怎么解决,求大神指点 ,孩子快急死了


附txt文件内容,如图所示,其中以.jpg结尾的为图像文件,后续数字表示:1代表含此标签,0表示不含此标签;一共三类标签。
附图像数据集格式如下:
dataset
----train
     ----ants
          ----1.jpg
          ----2.jpg
          ...
     ----bees
          ----1.jpg
          ----2.jpg
          ...
     ----disease
          ----1.jpg
          ----2.jpg
          ...
----train.txt
----val
     ----ants
          ----1.jpg
          ----2.jpg
          ...
     ----bees
          ----1.jpg
          ----2.jpg
          ...
     ----disease
          ----1.jpg
          ----2.jpg
          ...
----val.txt

label

label

x读取为'photo.jpg'

x读取为'photo.jpg'

标签信息

标签信息
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-4-23 22:24

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表