黎明丿晓小 发表于 2023-6-16 16:44:15

读取图像文件及其对应标签信息问题

代码如下:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from model import VGG16
import matplotlib.pyplot as plt
import os
from PIL import Image

# Set random seed
torch.manual_seed(0)

# Set parameters and data paths
lr = 0.0001
num_epochs = 20
batch_size = 8
data_path = 'D:/Backup/datasets/insect_multi-label'

# Define image loader function
def load_image_information(path):
    # Image root directory
    image_root_dir = data_path
    # Get image path
    image_dir = os.path.join(image_root_dir, path)
    # Open image in RGB format
    # PyTorch DataLoader expects images to be opened using PIL
    # It's recommended to use this method to read images
    # When reading grayscale images, use convert('L')
    return Image.open(image_dir).convert('RGB')

# Define custom dataset class
class CustomDataset(nn.Module):
    def __init__(self, txt_file, transform=None, loader=None):
      super(CustomDataset, self).__init__()
      self.transform = transform
      self.loader = loader

      self.images = []
      self.labels = []
      with open(txt_file, 'r') as file:
            for line in file:
                #line = line.strip().split()
                line.strip('\n')
                line.rstrip()
                line = line.split()
                image_path = line
                label = ]
                self.images.append(image_path)
                self.labels.append(label)

    def __len__(self):
      return len(self.images)

    def __getitem__(self, index):
      image_path = self.images
      label = self.labels

      image = self.loader(image_path)

      if self.transform is not None:
            image = self.transform(image)

      return image, torch.FloatTensor(label)

def main():
    # Data processing
    transform_train = transforms.Compose([transforms.RandomResizedCrop(224),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor()])
    transform_val = transforms.Compose([transforms.Resize(256),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor()])

    train_dataset = CustomDataset(r'D:/Backup/datasets/insect_multi-label/train/train.txt',
                                  transform=transform_train, loader=load_image_information)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

    val_dataset = CustomDataset(r'D:/Backup/datasets/insect_multi-label/val/val.txt',
                              transform=transform_val, loader=load_image_information)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = VGG16(num_classes=len(train_dataset.labels)).to(device)

    # Define loss function and optimizer
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Train model
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
      # Training
      train_loss = 0.0
      model.train()
      for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
      train_loss = train_loss / len(train_loader.dataset)
      train_losses.append(train_loss)

      # Validation
      val_loss = 0.0
      model.eval()
      with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
            val_loss = val_loss / len(val_loader.dataset)
            val_losses.append(val_loss)

      print('Epoch [{}/{}], Train Loss: {:.4f}, Val Loss: {:.4f}'.
            format(epoch + 1, num_epochs, train_loss, val_loss))

    # Save the trained model
    torch.save(model.state_dict(), 'vgg16_multi_label_classification.pth')

    # Visualize the training process
    plt.plot(range(num_epochs), train_losses, '-b', label='train')
    plt.plot(range(num_epochs), val_losses, '-r', label='validation')
    plt.legend(loc='lower right')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()


if __name__ == '__main__':
    main()

运行此代码时报错:
Traceback (most recent call last):
File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 238, in <module>
    main()
File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 172, in main
    transform=transform_train, loader=load_image_information)
File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 144, in __init__
    label = ]
File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 144, in <listcomp>
    label = ]
ValueError: could not convert string to float: 'photos.jpg'

当选择断点为label时,debug后结果如图:可以看到label的输出结果为,转换float形式的原因是BCELoss计算时用float形式。
但是,当选择断点为transform=transform_train时,debug结果表明label = ]这行代码中的x读取为'photo.jpg',如图,看代码x从1.0开始读入,但是代码把图像文件名也读进去了,所以转换不了。
比较奇怪,不知道怎么解决,求大神指点{:10_266:} ,孩子快急死了

附txt文件内容,如图所示,其中以.jpg结尾的为图像文件,后续数字表示:1代表含此标签,0表示不含此标签;一共三类标签。
附图像数据集格式如下:
dataset
----train
   ----ants
          ----1.jpg
          ----2.jpg
          ...
   ----bees
          ----1.jpg
          ----2.jpg
          ...
   ----disease
          ----1.jpg
          ----2.jpg
          ...
----train.txt
----val
   ----ants
          ----1.jpg
          ----2.jpg
          ...
   ----bees
          ----1.jpg
          ----2.jpg
          ...
   ----disease
          ----1.jpg
          ----2.jpg
          ...
----val.txt

页: [1]
查看完整版本: 读取图像文件及其对应标签信息问题