import torch
import torch.nn.functional as F
import numpy as np
import json
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import skimage.transform
import argparse
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3):
    """
    Reads an image and captions it with beam search.

    :param encoder: encoder model
    :param decoder: decoder model
    :param image_path: path to image
    :param word_map: word map
    :param beam_size: number of sequences to consider at each decode-step
    :return: caption, weights for visualization
    """
    k = beam_size
    vocab_size = len(word_map)
    # Read image and process
    img = Image.open(image_path)
    img = np.array(img)

    # If the image is single-channel (grayscale), replicate it into three channels.
    # This must happen before unpacking the shape, which assumes three dimensions.
    if len(img.shape) == 2:
        img = img[:, :, np.newaxis]                    # add a channel dimension
        img = np.concatenate([img, img, img], axis=2)  # stack into 3 channels

    height, width, channels = img.shape

    # Crop to 256x256: the bottom 256 rows, horizontally centered
    top = height - 256
    left = (width - 256) // 2
    img = img[top:top + 256, left:left + 256, :]

    img = img.transpose(2, 0, 1)  # HWC -> CHW: channels first for PyTorch
    img = img / 255.
    img = torch.FloatTensor(img).to(device)
    # Normalize with ImageNet statistics, matching the encoder's pretraining
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([normalize])
    image = transform(img)  # (3, 256, 256)
    # Encode
    image = image.unsqueeze(0)  # (1, 3, 256, 256)
    encoder_out = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim), e.g. (1, 14, 14, 2048)
    enc_image_size = encoder_out.size(1)
    print('enc_image_size:', enc_image_size)
    encoder_dim = encoder_out.size(3)
    print('encoder_dim:', encoder_dim)

    # Flatten encoding to (1, num_pixels, encoder_dim), e.g. (1, 196, 2048):
    # one feature vector for each of the 196 image regions
    encoder_out = encoder_out.view(1, -1, encoder_dim)
    num_pixels = encoder_out.size(1)  # e.g. 196

    # We'll treat the problem as having a batch size of k:
    # expand the batch dimension from 1 to k, one copy of the features per beam
    encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)
    # Tensor to store top k previous words at each step; now they're just <start>
    k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device)  # (k, 1)
    # Tensor to store top k sequences; now they're just <start>
    seqs = k_prev_words  # (k, 1)
    # Tensor to store top k sequences' scores; now they're just 0
    top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)
    # Tensor to store top k sequences' alphas; now they're just 1s.
    # These record, for each word, the attended image region on the 14x14 grid.
    seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device)  # (k, 1, enc_image_size, enc_image_size)
    # Lists to store completed sequences, their alphas and scores
    complete_seqs = list()
    complete_seqs_alpha = list()
    complete_seqs_scores = list()

    # Start decoding
    step = 1
    h, c = decoder.init_hidden_state(encoder_out)  # initial hidden and cell states
    print('h, c', h.size(), c.size())
    # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
    while True:
        embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)
        print('embeddings', embeddings.size())

        # Attend over the encoded image regions, conditioned on the hidden state:
        # awe is the attention-weighted encoding (s, encoder_dim);
        # alpha holds each region's attention weight for this word (s, num_pixels)
        awe, alpha = decoder.attention(encoder_out, h)
        print('awe, alpha', awe.size(), alpha.size())

        alpha = alpha.view(-1, enc_image_size, enc_image_size)  # (s, enc_image_size, enc_image_size)

        gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
        awe = gate * awe  # gate the attended features

        # One LSTM step on the word embedding concatenated with the gated
        # attention-weighted features, using the previous hidden and cell states
        h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)

        scores = decoder.fc(h)  # (s, vocab_size)
        scores = F.log_softmax(scores, dim=1)
        print('scores', scores.size())

        # Add the running sequence scores: log-probabilities accumulate per word
        scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)
        print('top_k_scores, scores', top_k_scores.size(), scores.size())

        # For the first step, all k points will have the same scores (since same k previous words, h, c)
        if step == 1:
            top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
        else:
            # Unroll and find top scores, and their unrolled indices
            top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)
        print('top_k_scores, top_k_words', top_k_scores.size(), top_k_words.size())

        # Convert unrolled indices to actual indices of scores:
        # which beam each winner extends, and which word it appends
        prev_word_inds = torch.floor_divide(top_k_words, vocab_size)  # (s)
        next_word_inds = top_k_words % vocab_size  # (s)
        print('top_k_scores, top_k_words, prev_word_inds, next_word_inds',
              top_k_words, top_k_scores, prev_word_inds, next_word_inds)
        # Add new words to sequences, alphas
        seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)
        seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
                               dim=1)  # (s, step+1, enc_image_size, enc_image_size)

        # Which sequences are incomplete (didn't reach <end>)?
        incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                           next_word != word_map['<end>']]
        complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

        # Set aside complete sequences
        if len(complete_inds) > 0:
            complete_seqs.extend(seqs[complete_inds].tolist())
            complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
            complete_seqs_scores.extend(top_k_scores[complete_inds])
        k -= len(complete_inds)  # reduce beam length accordingly

        # Proceed with incomplete sequences
        if k == 0:
            break
        # Keep only the state belonging to the still-incomplete sequences
        seqs = seqs[incomplete_inds]
        seqs_alpha = seqs_alpha[incomplete_inds]
        h = h[prev_word_inds[incomplete_inds]]
        c = c[prev_word_inds[incomplete_inds]]
        encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
        top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
        k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

        # Break if things have been going on too long
        if step > 50:
            break
        step += 1
    # Return the sequence with the highest score (this assumes at least one
    # sequence reached <end> within the 50-step limit)
    i = complete_seqs_scores.index(max(complete_seqs_scores))
    seq = complete_seqs[i]
    alphas = complete_seqs_alpha[i]
    return seq, alphas
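
# The trickiest bookkeeping above is mapping indices of the flattened
# (s * vocab_size) score tensor back to (beam, word) pairs. Below is a minimal
# sketch with a toy 3-beam, 5-word vocabulary (made-up numbers, not the real
# word map) showing the same floor-divide / modulo split; call it manually to
# see the output.
def _toy_unrolled_topk_demo():
    toy_vocab_size = 5
    toy_scores = torch.tensor([[0.1, 0.2, 0.9, 0.0, 0.3],   # beam 0
                               [0.4, 0.8, 0.1, 0.2, 0.0],   # beam 1
                               [0.7, 0.1, 0.2, 0.6, 0.1]])  # beam 2
    top_scores, top_words = toy_scores.view(-1).topk(3, 0, True, True)
    prev_inds = torch.floor_divide(top_words, toy_vocab_size)  # beam each winner extends
    next_inds = top_words % toy_vocab_size                     # word each winner appends
    print(top_words)   # tensor([ 2,  6, 10]) - indices into the flattened scores
    print(prev_inds)   # tensor([0, 1, 2])
    print(next_inds)   # tensor([2, 1, 0])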
def visualize_att(image_path, seq, alphas, rev_word_map, smooth=True):
    """
    Visualizes caption with weights at every word.

    Adapted from paper authors' repo: https://github.com/kelvinxu/arct ... visualization.ipynb

    :param image_path: path to image that has been captioned
    :param seq: caption
    :param alphas: weights
    :param rev_word_map: reverse word mapping, i.e. ix2word
    :param smooth: smooth weights?
    """
    image = Image.open(image_path)
    image = image.resize([14 * 12, 14 * 12], Image.LANCZOS)
    words = [rev_word_map[ind] for ind in seq]

    # Start at t=1 to skip <start>; stop before the last word to skip <end>
    for t in range(1, len(words) - 1):
        if t > 50:
            break
        plt.subplot(int(np.ceil(len(words) / 5.)), 6, t)
        plt.text(0, 1, '%s' % (words[t]), color='black', backgroundcolor='white', fontsize=12)
        plt.imshow(image)
        current_alpha = alphas[t, :]
        if smooth:
            # Upsample the 14x14 attention map to 168x168 with Gaussian smoothing
            alpha = skimage.transform.pyramid_expand(current_alpha.numpy(), upscale=12, sigma=8)
        else:
            alpha = skimage.transform.resize(current_alpha.numpy(), [14 * 12, 14 * 12])
        # Overlay the attention map semi-transparently on the image
        plt.imshow(alpha, alpha=0.8)
        plt.set_cmap(cm.Greys_r)
        plt.axis('off')
    plt.show()
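
# For intuition on the smoothing branch above: pyramid_expand grows the 14x14
# attention grid by upscale=12 to the 168x168 display resolution with Gaussian
# smoothing, whereas resize only interpolates. A quick shape sanity check on a
# random stand-in alpha (not a real attention map); call it manually:
def _toy_alpha_upsample_demo():
    toy_alpha = np.random.rand(14, 14)
    smoothed = skimage.transform.pyramid_expand(toy_alpha, upscale=12, sigma=8)
    resized = skimage.transform.resize(toy_alpha, [14 * 12, 14 * 12])
    print(smoothed.shape, resized.shape)  # both (168, 168), matching the resized image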
import scipy
print(scipy.__version__)
#checkpoint = torch.load('./BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar', map_location=str(device))
checkpoint = torch.load(r'C:/Users/Ternence/PycharmProjects/pythonProject/tuxiang/BEST_checkpoint_flickr8k_5_cap_per_img_5_min_word_freq.pth.tar', map_location=str(device))
decoder = checkpoint['decoder']
decoder = decoder.to(device)
decoder.eval()
encoder = checkpoint['encoder']
encoder = encoder.to(device)
encoder.eval()
# Load word map (word2ix)
with open(r'C:/Users/Ternence/kind2/Flickr8k/data/WORDMAP_flickr8k_5_cap_per_img_5_min_word_freq.json', 'r') as j:
    word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}  # ix2word
# Encode, decode with attention and beam search
seq, alphas = caption_image_beam_search(encoder, decoder, r'C:/Users/Ternence/kind2/Flickr8k/tupian/ren2.jpg', word_map, 5)
alphas = torch.FloatTensor(alphas)
# Visualize caption and attention of best sequence
visualize_att(r'C:/Users/Ternence/kind2/Flickr8k/tupian/ren2.jpg', seq, alphas, rev_word_map, True)
words = [rev_word_map[ind] for ind in seq]
print(words)
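
Since argparse is imported above but never used, here is a minimal sketch (the flag names are my own hypothetical choice, not from the original script) of how the hard-coded checkpoint, word-map and image paths could be turned into command-line arguments:

parser = argparse.ArgumentParser(description='image caption visualization (b.py)')
parser.add_argument('--img', required=True, help='path to the image to caption')
parser.add_argument('--model', required=True, help='path to the BEST_checkpoint .pth.tar file')
parser.add_argument('--word_map', required=True, help='path to the WORDMAP json')
parser.add_argument('--beam_size', type=int, default=5, help='beam size for decoding')
args = parser.parse_args()
# then pass args.model to torch.load, args.word_map to open(), and
# args.img / args.beam_size to caption_image_beam_search above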
This is the visualization feature for the image captioning project, i.e. b.py.