Vageesh1 committed on
Commit
c44b109
1 Parent(s): 0225049

Upload model.py

Files changed (1)
  1. model.py +221 -0
model.py ADDED
@@ -0,0 +1,221 @@
+ import clip
+ import os
+ from torch import nn
+ import numpy as np
+ import torch
+ import torch.nn.functional as nnf
+ import sys
+ from typing import Tuple, List, Union, Optional
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
+ from tqdm import tqdm, trange
+ import skimage.io as io
+ import PIL.Image
+ from IPython.display import Image
+
+
+ # Shorthand aliases used in the type annotations below.
+ N = type(None)
+ V = np.array
+ ARRAY = np.ndarray
+ ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
+ VS = Union[Tuple[V, ...], List[V]]
+ VN = Union[V, N]
+ VNS = Union[VS, N]
+ T = torch.Tensor
+ TS = Union[Tuple[T, ...], List[T]]
+ TN = Optional[T]
+ TNS = Union[Tuple[TN, ...], List[TN]]
+ TSN = Optional[TS]
+ TA = Union[T, ARRAY]
+
+
+ D = torch.device
+ CPU = torch.device('cpu')
+
+ def get_device(device_id: int) -> D:
+     # Return the requested CUDA device (clamped to the available range),
+     # or fall back to the CPU when CUDA is not available.
+     if not torch.cuda.is_available():
+         return CPU
+     device_id = min(torch.cuda.device_count() - 1, device_id)
+     return torch.device(f'cuda:{device_id}')
+
+
+ CUDA = get_device
+
+ # Directory that holds the pretrained caption-model weights.
+ current_directory = os.getcwd()
+ save_path = os.path.join(os.path.dirname(current_directory), "pretrained_models")
+ os.makedirs(save_path, exist_ok=True)
+ model_path = os.path.join(save_path, 'model_wieghts.pt')
+
+
+ class MLP(nn.Module):
+     """A small feed-forward network used to project CLIP features into GPT-2's embedding space."""
+
+     def forward(self, x: T) -> T:
+         return self.model(x)
+
+     def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+         super(MLP, self).__init__()
+         layers = []
+         for i in range(len(sizes) - 1):
+             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+             if i < len(sizes) - 2:
+                 layers.append(act())
+         self.model = nn.Sequential(*layers)
+
+ class ClipCaptionModel(nn.Module):
+     """Projects a CLIP embedding to a sequence of prefix embeddings and feeds them, together with the caption tokens, through GPT-2."""
+
+     #@functools.lru_cache #FIXME
+     def get_dummy_token(self, batch_size: int, device: D) -> T:
+         return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+     def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
+         embedding_text = self.gpt.transformer.wte(tokens)
+         prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
+         #print(embedding_text.size()) #torch.Size([5, 67, 768])
+         #print(prefix_projections.size()) #torch.Size([5, 1, 768])
+         embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+         if labels is not None:
+             dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+             labels = torch.cat((dummy_token, tokens), dim=1)
+         out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+         return out
+
+     def __init__(self, prefix_length: int, prefix_size: int = 512):
+         super(ClipCaptionModel, self).__init__()
+         self.prefix_length = prefix_length
+         self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+         if prefix_length > 10:  # not enough memory
+             self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
+         else:
+             self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))
+
+
+ class ClipCaptionPrefix(ClipCaptionModel):
+     """Variant that trains only the projection network while keeping GPT-2 frozen."""
+
+     def parameters(self, recurse: bool = True):
+         return self.clip_project.parameters()
+
+     def train(self, mode: bool = True):
+         super(ClipCaptionPrefix, self).train(mode)
+         self.gpt.eval()
+         return self
+
+ def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
+                   entry_length=67, temperature=1., stop_token: str = '.'):
+     """Beam-search decoding with GPT-2, starting from either a prefix embedding (`embed`) or a text `prompt`."""
+     model.eval()
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     tokens = None
+     scores = None
+     device = next(model.parameters()).device
+     seq_lengths = torch.ones(beam_size, device=device)
+     is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+     with torch.no_grad():
+         if embed is not None:
+             generated = embed
+         else:
+             if tokens is None:
+                 tokens = torch.tensor(tokenizer.encode(prompt))
+                 tokens = tokens.unsqueeze(0).to(device)
+                 generated = model.gpt.transformer.wte(tokens)
+         for i in range(entry_length):
+             outputs = model.gpt(inputs_embeds=generated)
+             logits = outputs.logits
+             logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+             logits = logits.softmax(-1).log()
+             if scores is None:
+                 # First step: initialise the beams from the top-k tokens.
+                 scores, next_tokens = logits.topk(beam_size, -1)
+                 generated = generated.expand(beam_size, *generated.shape[1:])
+                 next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+                 if tokens is None:
+                     tokens = next_tokens
+                 else:
+                     tokens = tokens.expand(beam_size, *tokens.shape[1:])
+                     tokens = torch.cat((tokens, next_tokens), dim=1)
+             else:
+                 # Later steps: extend every beam and keep the best length-normalised candidates.
+                 logits[is_stopped] = -float(np.inf)
+                 logits[is_stopped, 0] = 0
+                 scores_sum = scores[:, None] + logits
+                 seq_lengths[~is_stopped] += 1
+                 scores_sum_average = scores_sum / seq_lengths[:, None]
+                 scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
+                 next_tokens_source = next_tokens // scores_sum.shape[1]
+                 seq_lengths = seq_lengths[next_tokens_source]
+                 next_tokens = next_tokens % scores_sum.shape[1]
+                 next_tokens = next_tokens.unsqueeze(1)
+                 tokens = tokens[next_tokens_source]
+                 tokens = torch.cat((tokens, next_tokens), dim=1)
+                 generated = generated[next_tokens_source]
+                 scores = scores_sum_average * seq_lengths
+                 is_stopped = is_stopped[next_tokens_source]
+             next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
+             generated = torch.cat((generated, next_token_embed), dim=1)
+             is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+             if is_stopped.all():
+                 break
+     scores = scores / seq_lengths
+     output_list = tokens.cpu().numpy()
+     output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
+     order = scores.argsort(descending=True)
+     output_texts = [output_texts[i] for i in order]
+     return output_texts
+
+ def generate2(
+     model,
+     tokenizer,
+     tokens=None,
+     prompt=None,
+     embed=None,
+     entry_count=1,
+     entry_length=67,  # maximum number of tokens to generate
+     top_p=0.8,
+     temperature=1.,
+     stop_token: str = '.',
+ ):
+     """Greedy decoding with top-p filtering of the logits; returns the first generated caption."""
+     model.eval()
+     generated_num = 0
+     generated_list = []
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     filter_value = -float("Inf")
+     device = next(model.parameters()).device
+
+     with torch.no_grad():
+
+         for entry_idx in trange(entry_count):
+             if embed is not None:
+                 generated = embed
+             else:
+                 if tokens is None:
+                     tokens = torch.tensor(tokenizer.encode(prompt))
+                     tokens = tokens.unsqueeze(0).to(device)
+
+                 generated = model.gpt.transformer.wte(tokens)
+
+             for i in range(entry_length):
+
+                 outputs = model.gpt(inputs_embeds=generated)
+                 logits = outputs.logits
+                 logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+                 # Mask out everything outside the top-p nucleus before picking the next token.
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                 cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+                     ..., :-1
+                 ].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+
+                 indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                 logits[:, indices_to_remove] = filter_value
+                 next_token = torch.argmax(logits, -1).unsqueeze(0)
+                 next_token_embed = model.gpt.transformer.wte(next_token)
+                 if tokens is None:
+                     tokens = next_token
+                 else:
+                     tokens = torch.cat((tokens, next_token), dim=1)
+                 generated = torch.cat((generated, next_token_embed), dim=1)
+                 if stop_token_index == next_token.item():
+                     break
+
+             output_list = list(tokens.squeeze().cpu().numpy())
+             output_text = tokenizer.decode(output_list)
+             generated_list.append(output_text)
+
+     return generated_list[0]
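
For reference, a minimal inference sketch built on the file above (not part of the commit). It assumes the definitions are importable from model.py, a prefix length of 10 that matches the checkpoint at model_path, a ViT-B/32 CLIP encoder, and a hypothetical image path example.jpg:

import torch
import clip
import PIL.Image
from transformers import GPT2Tokenizer
from model import ClipCaptionModel, generate2, model_path, CUDA, CPU

device = CUDA(0) if torch.cuda.is_available() else CPU
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

prefix_length = 10  # assumed; must match the training configuration of the checkpoint
model = ClipCaptionModel(prefix_length)
model.load_state_dict(torch.load(model_path, map_location=CPU), strict=False)
model = model.eval().to(device)

image = preprocess(PIL.Image.open("example.jpg")).unsqueeze(0).to(device)  # hypothetical image path
with torch.no_grad():
    # CLIP image features become the GPT-2 prefix via the learned projection.
    prefix = clip_model.encode_image(image).float()
    prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
print(generate2(model, tokenizer, embed=prefix_embed))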