# This file contains CaptionModel, the base class for the captioning models
# (e.g. ShowAttendTell, from Show, Attend and Tell: Neural Image Caption Generation
# with Visual Attention, https://arxiv.org/abs/1502.03044, and AllImg, where the
# img feature is concatenated with the word embedding at every time step as the
# input of the lstm). It implements (diverse) beam search and next-word sampling.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# from ..utils import misc as utils
from captioning.utils import misc as utils
from . import utils as model_utils
# torch.manual_seed(42)
# if torch.cuda.is_available():
# torch.cuda.manual_seed(42)
class CaptionModel(nn.Module):
def __init__(self):
super(CaptionModel, self).__init__()
    # beam_search (below) implements beam search: it calls beam_step repeatedly and
    # returns the final set of beams, augmenting the log-probabilities with diversity
    # terms when the number of groups > 1 (i.e. diverse beam search)
def forward(self, *args, **kwargs):
mode = kwargs.get('mode', 'forward')
if 'mode' in kwargs:
del kwargs['mode']
return getattr(self, '_'+mode)(*args, **kwargs)
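    # typical dispatch (argument names illustrative, from AttModel-style subclasses):
    # model(fc_feats, att_feats, opt=opt, mode='sample') calls model._sample(...);
    # the default mode='forward' calls model._forward(...)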
def beam_search(self, init_state, init_logprobs, *args, **kwargs):
        # applies the diversity penalty: subtracts diversity_lambda from the logprobs
        # of words that earlier groups already chose at this time step, and returns
        # both the penalized and the unaugmented logprobs
def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
local_time = t - divm
unaug_logprobs = logprobs.clone()
batch_size = beam_seq_table[0].shape[0]
if divm > 0:
change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
for prev_choice in range(divm):
prev_decisions = beam_seq_table[prev_choice][:, :, local_time] # Nxb
for prev_labels in range(bdash):
change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1), change.new_ones(batch_size, 1))
if local_time == 0:
logprobs = logprobs - change * diversity_lambda
else:
logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda
return logprobs, unaug_logprobs
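        # e.g. with diversity_lambda=0.5, a word already chosen twice across the
        # earlier groups' beams at this position has its logprob reduced by 1.0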
# does one step of classical beam search
def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
            #INPUTS:
            #logprobs: log-probabilities augmented by the diversity penalty, (N*b)xV (NxV at t == 0)
            #beam_size: number of beams kept per image (bdash when grouping)
            #t        : time instant
            #beam_seq : tensor containing the beams
            #beam_seq_logprobs: tensor containing the beam logprobs
            #beam_logprobs_sum: tensor containing joint logprobs
            #OUTPUTS:
            #beam_seq : tensor containing the word indices of the decoded captions, Nxbxl
            #beam_seq_logprobs : log-probability of each decision made, NxbxlxV
            #beam_logprobs_sum : joint log-probability of each beam, Nxb
batch_size = beam_logprobs_sum.shape[0]
vocab_size = logprobs.shape[-1]
logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV
if t == 0:
assert logprobs.shape[1] == 1
beam_logprobs_sum = beam_logprobs_sum[:, :1]
candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV
ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True)
ys, ix = ys[:,:beam_size], ix[:,:beam_size]
beam_ix = ix // vocab_size # Nxb which beam
            selected_ix = ix % vocab_size # Nxb which word was selected
state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(-1) # N*b which in Nxb beams
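            # state tensors are flat over images and beams (rows ordered image-major:
            # row n * num_beams + k is beam k of image n), so state_ix turns the
            # (image, chosen beam) pairs into flat row indices; at t == 0 num_beams
            # is 1, which is how the single seed row gets replicated to bdash beams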
if t > 0:
# gather according to beam_ix
assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq))
beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(beam_seq_logprobs))
beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl
beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
logprobs.reshape(batch_size, -1).gather(1, ix)
assert (beam_logprobs_sum == ys).all()
_tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size)
beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size)) # NxbxV
assert (_tmp_beam_logprobs == beam_logprobs).all()
beam_seq_logprobs = torch.cat([
beam_seq_logprobs,
beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2)
new_state = [None for _ in state]
for _ix in range(len(new_state)):
# copy over state in previous beam q to new beam at vix
new_state[_ix] = state[_ix][:, state_ix]
state = new_state
return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state
# Start diverse_beam_search
opt = kwargs['opt']
        temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect diverse beam search (dbs)
beam_size = opt.get('beam_size', 10)
group_size = opt.get('group_size', 1)
diversity_lambda = opt.get('diversity_lambda', 0.5)
decoding_constraint = opt.get('decoding_constraint', 0)
remove_bad_endings = opt.get('remove_bad_endings', 1)
suppress_UNK = opt.get('suppress_UNK', 1)
length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
        bdash = beam_size // group_size # beams per group
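        # e.g. beam_size=10, group_size=2 gives bdash=5 beams in each of 2 groups;
        # with group_size=1 (plain beam search) bdash == beam_size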
batch_size = init_logprobs.shape[0]
device = init_logprobs.device
# INITIALIZATIONS
beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)]
beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in range(group_size)]
beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)]
        # logprobs: logprobs predicted in the last time step, shape (beam_size, vocab_size+1)
done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)]
# state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
# state_table = list(zip(*[_.reshape(-1, batch_size * bdash, group_size, *_.shape[2:]).chunk(group_size, 2) for _ in init_state]))
state_table = [[_.clone() for _ in init_state] for _ in range(group_size)]
# logprobs_table = list(init_logprobs.reshape(batch_size * bdash, group_size, -1).chunk(group_size, 0))
logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
# END INIT
# Chunk elements in the args
args = list(args)
args = model_utils.split_tensors(group_size, args) # For each arg, turn (Bbg)x... to (Bb)x(g)x...
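        # e.g. an att_feats arg of shape (N*b*g)xD comes out as g tensors of shape
        # (N*b)xD, one per diversity group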
if self.__class__.__name__ == 'AttEnsemble':
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
else:
args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
for t in range(self.seq_length + group_size - 1):
for divm in range(group_size):
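                # groups are staggered in time: group divm takes its first step at
                # t == divm and its last at t == seq_length + divm - 1, so each group
                # trails the previous one by one word and add_diversity can see the
                # words earlier groups chose at the current position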
if t >= divm and t <= self.seq_length + divm - 1:
# add diversity
logprobs = logprobs_table[divm]
# suppress previous word
if decoding_constraint and t-divm > 0:
logprobs.scatter_(1, beam_seq_table[divm][:, :, t-divm-1].reshape(-1, 1).to(device), float('-inf'))
if remove_bad_endings and t-divm > 0:
logprobs[torch.from_numpy(np.isin(beam_seq_table[divm][:, :, t-divm-1].cpu().numpy(), self.bad_endings_ix)).reshape(-1), 0] = float('-inf')
# suppress UNK tokens in the decoding
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK':
logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000
                    # the diversity penalty is applied here; add_diversity modifies
                    # logprobs in place, so it also returns the unaugmented values,
                    # which are what gets recorded in beam_seq_logprobs
logprobs, unaug_logprobs = add_diversity(beam_seq_table,logprobs,t,divm,diversity_lambda,bdash)
# infer new beams
beam_seq_table[divm],\
beam_seq_logprobs_table[divm],\
beam_logprobs_sum_table[divm],\
state_table[divm] = beam_step(logprobs,
unaug_logprobs,
bdash,
t-divm,
beam_seq_table[divm],
beam_seq_logprobs_table[divm],
beam_logprobs_sum_table[divm],
state_table[divm])
# if time's up... or if end token is reached then copy beams
for b in range(batch_size):
is_end = beam_seq_table[divm][b, :, t-divm] == self.eos_idx
assert beam_seq_table[divm].shape[-1] == t-divm+1
if t == self.seq_length + divm - 1:
is_end.fill_(1)
for vix in range(bdash):
if is_end[vix]:
final_beam = {
'seq': beam_seq_table[divm][b, vix].clone(),
'logps': beam_seq_logprobs_table[divm][b, vix].clone(),
'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(),
'p': beam_logprobs_sum_table[divm][b, vix].item()
}
final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
done_beams_table[b][divm].append(final_beam)
                        beam_logprobs_sum_table[divm][b, is_end] -= 1000 # don't continue beams from finished sequences
# move the current group one step forward in time
it = beam_seq_table[divm][:, :, t-divm].reshape(-1).to(logprobs.device)
logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
        # all beams are sorted by their log-probabilities; done_beams[n] holds
        # group_size * bdash == beam_size dicts with keys 'seq', 'logps', 'unaug_p', 'p'
done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] for b in range(batch_size)]
done_beams = [sum(_, []) for _ in done_beams_table]
return done_beams
def old_beam_search(self, init_state, init_logprobs, *args, **kwargs):
        # legacy, per-image (unbatched) version of beam_search above, kept for reference;
        # add_diversity applies the diversity penalty: it subtracts diversity_lambda from
        # the logprobs of words earlier groups already chose, returning the unaugmented copy
def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
local_time = t - divm
unaug_logprobsf = logprobsf.clone()
for prev_choice in range(divm):
prev_decisions = beam_seq_table[prev_choice][local_time]
for sub_beam in range(bdash):
for prev_labels in range(bdash):
logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
return unaug_logprobsf
# does one step of classical beam search
def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
            #INPUTS:
            #logprobsf: log-probabilities augmented by the diversity penalty
            #beam_size: number of beams kept (bdash when grouping)
            #t        : time instant
            #beam_seq : tensor containing the beams
            #beam_seq_logprobs: tensor containing the beam logprobs
            #beam_logprobs_sum: tensor containing joint logprobs
            #OUTPUTS:
            #beam_seq : tensor containing the word indices of the decoded captions
            #beam_seq_logprobs : log-probability of each decision made, (seq_length)x(bdash)x(vocab_size+1)
            #beam_logprobs_sum : joint log-probability of each beam
ys,ix = torch.sort(logprobsf,1,True)
candidates = []
cols = min(beam_size, ys.size(1))
rows = beam_size
if t == 0:
rows = 1
for c in range(cols): # for each column (word, essentially)
for q in range(rows): # for each beam expansion
#compute logprob of expanding beam q with word in (sorted) position c
local_logprob = ys[q,c].item()
candidate_logprob = beam_logprobs_sum[q] + local_logprob
# local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':unaug_logprobsf[q]})
candidates = sorted(candidates, key=lambda x: -x['p'])
new_state = [_.clone() for _ in state]
#beam_seq_prev, beam_seq_logprobs_prev
if t >= 1:
                # we'll need these as reference when we fork beams around
beam_seq_prev = beam_seq[:t].clone()
beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
for vix in range(beam_size):
v = candidates[vix]
#fork beam index q into index vix
if t >= 1:
beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
#rearrange recurrent states
for state_ix in range(len(new_state)):
# copy over state in previous beam q to new beam at vix
new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step
#append new end terminal at the end of this beam
beam_seq[t, vix] = v['c'] # c'th word is the continuation
beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
state = new_state
return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates
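        # note: unlike the batched beam_step in beam_search above, this version builds
        # rows*cols candidate dicts and sorts them in Python, one image at a time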
# Start diverse_beam_search
opt = kwargs['opt']
temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
beam_size = opt.get('beam_size', 10)
group_size = opt.get('group_size', 1)
diversity_lambda = opt.get('diversity_lambda', 0.5)
decoding_constraint = opt.get('decoding_constraint', 0)
remove_bad_endings = opt.get('remove_bad_endings', 1)
suppress_UNK = opt.get('suppress_UNK', 1)
length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
bdash = beam_size // group_size # beam per group
# INITIALIZATIONS
beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash, self.vocab_size + 1).zero_() for _ in range(group_size)]
beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
        # logprobs: logprobs predicted in the last time step, shape (beam_size, vocab_size+1)
done_beams_table = [[] for _ in range(group_size)]
# state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
logprobs_table = list(init_logprobs.chunk(group_size, 0))
# END INIT
# Chunk elements in the args
args = list(args)
if self.__class__.__name__ == 'AttEnsemble':
args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
else:
args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args]
args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
for t in range(self.seq_length + group_size - 1):
for divm in range(group_size):
if t >= divm and t <= self.seq_length + divm - 1:
# add diversity
logprobsf = logprobs_table[divm]
# suppress previous word
if decoding_constraint and t-divm > 0:
logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).to(logprobsf.device), float('-inf'))
if remove_bad_endings and t-divm > 0:
logprobsf[torch.from_numpy(np.isin(beam_seq_table[divm][t-divm-1].cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf')
# suppress UNK tokens in the decoding
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1)-1)] == 'UNK':
logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000
                    # the diversity penalty is applied here; add_diversity modifies
                    # logprobsf in place, so it also returns the unaugmented values,
                    # which are what gets recorded in beam_seq_logprobs
unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash)
# infer new beams
beam_seq_table[divm],\
beam_seq_logprobs_table[divm],\
beam_logprobs_sum_table[divm],\
state_table[divm],\
candidates_divm = beam_step(logprobsf,
unaug_logprobsf,
bdash,
t-divm,
beam_seq_table[divm],
beam_seq_logprobs_table[divm],
beam_logprobs_sum_table[divm],
state_table[divm])
# if time's up... or if end token is reached then copy beams
for vix in range(bdash):
if beam_seq_table[divm][t-divm,vix] == self.eos_idx or t == self.seq_length + divm - 1:
final_beam = {
'seq': beam_seq_table[divm][:, vix].clone(),
'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
'p': beam_logprobs_sum_table[divm][vix].item()
}
final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
done_beams_table[divm].append(final_beam)
# don't continue beams from finished sequences
beam_logprobs_sum_table[divm][vix] = -1000
# move the current group one step forward in time
it = beam_seq_table[divm][t-divm].to(logprobsf.device)
logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
# all beams are sorted by their log-probabilities
done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
done_beams = sum(done_beams_table, [])
return done_beams
def sample_next_word(self, logprobs, sample_method, temperature):
if sample_method == 'greedy':
sampleLogprobs, it = torch.max(logprobs.data, 1)
it = it.view(-1).long()
elif sample_method == 'gumbel': # gumbel softmax
# ref: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
def sample_gumbel(shape, eps=1e-20):
U = torch.rand(shape).to(logprobs.device)
return -torch.log(-torch.log(U + eps) + eps)
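            # Gumbel-max trick: argmax(logits + Gumbel(0,1) noise) is an exact sample
            # from Categorical(softmax(logits)); -log(-log(U)) with U ~ Uniform(0,1)
            # is Gumbel(0,1), and eps guards against log(0)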
def gumbel_softmax_sample(logits, temperature):
y = logits + sample_gumbel(logits.size())
return F.log_softmax(y / temperature, dim=-1)
_logprobs = gumbel_softmax_sample(logprobs, temperature)
_, it = torch.max(_logprobs.data, 1)
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
else:
logprobs = logprobs / temperature
if sample_method.startswith('top'): # topk sampling
top_num = float(sample_method[3:])
if 0 < top_num < 1:
                    # nucleus sampling from The Curious Case of Neural Text Degeneration
                    # (https://arxiv.org/abs/1904.09751)
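                    # e.g. top_num=0.9 with sorted probs (0.5, 0.3, 0.15, ...):
                    # cumsum=(0.5, 0.8, 0.95, ...), so the shifted mask keeps the
                    # first three words (the shift also guarantees the top word
                    # always survives) and the rest get probability zero before
                    # renormalizing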
probs = F.softmax(logprobs, dim=1)
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
_cumsum = sorted_probs.cumsum(1)
mask = _cumsum < top_num
mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
sorted_probs = sorted_probs * mask.to(sorted_probs)
sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
logprobs.scatter_(1, sorted_indices, sorted_probs.log())
else:
the_k = int(top_num)
tmp = torch.empty_like(logprobs).fill_(float('-inf'))
topk, indices = torch.topk(logprobs, the_k, dim=1)
tmp = tmp.scatter(1, indices, topk)
logprobs = tmp
it = torch.distributions.Categorical(logits=logprobs.detach()).sample()
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
return it, sampleLogprobs
def decode_sequence(self, seq):
return utils.decode_sequence(self.vocab, seq)
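

# A minimal usage sketch (not part of the original file). ToyModel below is
# hypothetical: it stands in for real subclasses such as ShowAttendTell, which
# supply learned get_logprobs_state implementations and the vocab. Assumes the
# captioning package is importable, e.g. run as:
#   python -m captioning.models.CaptionModel
if __name__ == '__main__':
    class ToyModel(CaptionModel):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.vocab_size = 9          # words 1..9; index 0 doubles as <eos>
            self.seq_length = 5
            self.eos_idx = 0
            self.bad_endings_ix = []
            self.logit = nn.Linear(4, self.vocab_size + 1)

        def get_logprobs_state(self, it, state):
            # state is a list holding one (1, num_rows, 4) hidden tensor; a real
            # model would also condition on `it` (the previous word) and features
            h = torch.tanh(state[0])
            return F.log_softmax(self.logit(h[0]), dim=-1), state

    N, b = 2, 3
    model = ToyModel()
    # the initial state has one row per image; beam_step expands it to N*b rows
    # when it selects the first b words at t == 0
    init_state = [torch.zeros(1, N, 4)]
    init_logprobs = F.log_softmax(torch.randn(N, model.vocab_size + 1), dim=-1)
    done_beams = model.beam_search(init_state, init_logprobs,
                                   opt={'beam_size': b, 'group_size': 1})
    for n, beams in enumerate(done_beams):
        print('image %d:' % n, [beam['seq'].tolist() for beam in beams])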