import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from model import VGG16
import matplotlib.pyplot as plt
import os
from PIL import Image
# Set random seed
torch.manual_seed(0)
# Set parameters and data paths
lr = 0.0001
num_epochs = 20
batch_size = 8
data_path = 'D:/Backup/datasets/insect_multi-label'
# Define image loader function
def load_image_information(path):
# Image root directory
image_root_dir = data_path
# Get image path
image_dir = os.path.join(image_root_dir, path)
# Open image in RGB format
# PyTorch DataLoader expects images to be opened using PIL
# It's recommended to use this method to read images
# When reading grayscale images, use convert('L')
return Image.open(image_dir).convert('RGB')
# Define custom dataset class
class CustomDataset(nn.Module):
def __init__(self, txt_file, transform=None, loader=None):
super(CustomDataset, self).__init__()
self.transform = transform
self.loader = loader
self.images = []
self.labels = []
with open(txt_file, 'r') as file:
for line in file:
#line = line.strip().split()
line.strip('\n')
line.rstrip()
line = line.split()
image_path = line[0]
label = [float(x) for x in line[1:len(line)]]
self.images.append(image_path)
self.labels.append(label)
def __len__(self):
return len(self.images)
def __getitem__(self, index):
image_path = self.images[index]
label = self.labels[index]
image = self.loader(image_path)
if self.transform is not None:
image = self.transform(image)
return image, torch.FloatTensor(label)
def main():
# Data processing
transform_train = transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()])
transform_val = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor()])
train_dataset = CustomDataset(r'D:/Backup/datasets/insect_multi-label/train/train.txt',
transform=transform_train, loader=load_image_information)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_dataset = CustomDataset(r'D:/Backup/datasets/insect_multi-label/val/val.txt',
transform=transform_val, loader=load_image_information)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VGG16(num_classes=len(train_dataset.labels[0])).to(device)
# Define loss function and optimizer
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Train model
train_losses = []
val_losses = []
for epoch in range(num_epochs):
# Training
train_loss = 0.0
model.train()
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward pass and optimize
loss.backward()
optimizer.step()
train_loss += loss.item() * images.size(0)
train_loss = train_loss / len(train_loader.dataset)
train_losses.append(train_loss)
# Validation
val_loss = 0.0
model.eval()
with torch.no_grad():
for images, labels in val_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
val_loss += loss.item() * images.size(0)
val_loss = val_loss / len(val_loader.dataset)
val_losses.append(val_loss)
print('Epoch [{}/{}], Train Loss: {:.4f}, Val Loss: {:.4f}'.
format(epoch + 1, num_epochs, train_loss, val_loss))
# Save the trained model
torch.save(model.state_dict(), 'vgg16_multi_label_classification.pth')
# Visualize the training process
plt.plot(range(num_epochs), train_losses, '-b', label='train')
plt.plot(range(num_epochs), val_losses, '-r', label='validation')
plt.legend(loc='lower right')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()
if __name__ == '__main__':
main()
运行此代码时报错:
Traceback (most recent call last):
File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 238, in <module>
main()
File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 172, in main
transform=transform_train, loader=load_image_information)
File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 144, in __init__
label = [float(x) for x in line[1:len(line)]]
File "E:/PycharmProjects/Multi-label classification/VGGNet_ML_test/train.py", line 144, in <listcomp>
label = [float(x) for x in line[1:len(line)]]
ValueError: could not convert string to float: 'photos.jpg'
当选择断点为label时,debug后结果如图:可以看到label的输出结果为[1.0 0.0 1.0],转换float形式的原因是BCELoss计算时用float形式。
但是,当选择断点为transform=transform_train时,debug结果表明label = [float(x) for x in line[1:len(line)]]这行代码中的x读取为'photo.jpg',如图,看代码x从1.0开始读入,但是代码把图像文件名也读进去了,所以转换不了。
比较奇怪,不知道怎么解决,求大神指点 ,孩子快急死了