鱼C论坛

 找回密码
 立即注册
查看: 1003|回复: 0

读取图像文件及其对应标签信息问题

[复制链接]
发表于 2023-6-16 16:44:15 | 显示全部楼层 |阅读模式
10鱼币
代码如下:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from model import VGG16
import matplotlib.pyplot as plt
import os
from PIL import Image

# Set random seed
torch.manual_seed(0)

# Set parameters and data paths
lr = 0.0001
num_epochs = 20
batch_size = 8
data_path = 'D:/Backup/datasets/insect_multi-label'

# Define image loader function
def load_image_information(path):
    # Image root directory
    image_root_dir = data_path
    # Get image path
    image_dir = os.path.join(image_root_dir, path)
    # Open image in RGB format
    # PyTorch DataLoader expects images to be opened using PIL
    # It's recommended to use this method to read images
    # When reading grayscale images, use convert('L')
    return Image.open(image_dir).convert('RGB')

# Define custom dataset class
class CustomDataset(nn.Module):
    def __init__(self, txt_file, transform=None, loader=None):
        super(CustomDataset, self).__init__()
        self.transform = transform
        self.loader = loader

        self.images = []
        self.labels = []
        with open(txt_file, 'r') as file:
            for line in file:
                #line = line.strip().split()
                line.strip('\n')
                line.rstrip()
                line = line.split()
                image_path = line[0]
                label = [float(x) for x in line[1:len(line)]]
                self.images.append(image_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = self.images[index]
        label = self.labels[index]

        image = self.loader(image_path)

        if self.transform is not None:
            image = self.transform(image)

        return image, torch.FloatTensor(label)

def main():
    # Data processing
    transform_train = transforms.Compose([transforms.RandomResizedCrop(224),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor()])
    transform_val = transforms.Compose([transforms.Resize(256),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor()])

    train_dataset = CustomDataset(r'D:/Backup/datasets/insect_multi-label/train/train.txt',
                                  transform=transform_train, loader=load_image_information)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

    val_dataset = CustomDataset(r'D:/Backup/datasets/insect_multi-label/val/val.txt',
                                transform=transform_val, loader=load_image_information)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = VGG16(num_classes=len(train_dataset.labels[0])).to(device)

    # Define loss function and optimizer
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Train model
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        # Training
        train_loss = 0.0
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
        train_loss = train_loss / len(train_loader.dataset)
        train_losses.append(train_loss)

        # Validation
        val_loss = 0.0
        model.eval()
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
            val_loss = val_loss / len(val_loader.dataset)
            val_losses.append(val_loss)

        print('Epoch [{}/{}], Train Loss: {:.4f}, Val Loss: {:.4f}'.
              format(epoch + 1, num_epochs, train_loss, val_loss))

    # Save the trained model
    torch.save(model.state_dict(), 'vgg16_multi_label_classification.pth')

    # Visualize the training process
    plt.plot(range(num_epochs), train_losses, '-b', label='train')
    plt.plot(range(num_epochs), val_losses, '-r', label='validation')
    plt.legend(loc='lower right')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()


if __name__ == '__main__':
    main()

运行此代码时报错:
Traceback (most recent call last):
  File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 238, in <module>
    main()
  File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 172, in main
    transform=transform_train, loader=load_image_information)
  File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 144, in __init__
    label = [float(x) for x in line[1:len(line)]]
  File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 144, in <listcomp>
    label = [float(x) for x in line[1:len(line)]]
ValueError: could not convert string to float: 'photos.jpg'


当选择断点为label时,debug后结果如图:可以看到label的输出结果为[1.0 0.0 1.0],转换float形式的原因是BCELoss计算时用float形式。
但是,当选择断点为transform=transform_train时,debug结果表明label = [float(x) for x in line[1:len(line)]]这行代码中的x读取为'photo.jpg',如图,看代码x从1.0开始读入,但是代码把图像文件名也读进去了,所以转换不了。
比较奇怪,不知道怎么解决,求大神指点 ,孩子快急死了


附txt文件内容,如图所示,其中以.jpg结尾的为图像文件,后续数字表示:1代表含此标签,0表示不含此标签;一共三类标签。
附图像数据集格式如下:
dataset
----train
     ----ants
          ----1.jpg
          ----2.jpg
          ...
     ----bees
          ----1.jpg
          ----2.jpg
          ...
     ----disease
          ----1.jpg
          ----2.jpg
          ...
----train.txt
----val
     ----ants
          ----1.jpg
          ----2.jpg
          ...
     ----bees
          ----1.jpg
          ----2.jpg
          ...
     ----disease
          ----1.jpg
          ----2.jpg
          ...
----val.txt

label

label

x读取为'photo.jpg'

x读取为'photo.jpg'

标签信息

标签信息
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-9-22 19:43

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表