Duplicate from Vageesh1/clip_gpt2
Co-authored-by: vageesh <[email protected]>
- .gitattributes +34 -0
- COCO_model.h5 +3 -0
- README.md +13 -0
- app.py +84 -0
- engine.py +42 -0
- model.h5 +3 -0
- model.py +220 -0
- model_2.py +0 -0
- model_trained.pth +3 -0
- neuralnet/dataset.py +139 -0
- neuralnet/model.py +71 -0
- neuralnet/train.py +130 -0
- neuralnet/utils.py +42 -0
- requirements.txt +19 -0
- vocab.json +0 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
COCO_model.h5
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:35200360d19ea02ce5c8f007c8bf6d8297e3c16ae3b3fb4b6eeb24ec1c07f8e6
+size 636283447
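Note: COCO_model.h5, model.h5, and model_trained.pth are Git LFS pointer files; the actual payloads are PyTorch checkpoints. Despite the .h5 extension they are not HDF5 files — app.py below loads them with torch.load and load_state_dict. A minimal sketch, assuming the LFS objects have been pulled locally:

```python
import torch

# The .h5 checkpoints in this repo are plain PyTorch state dicts,
# which is why app.py calls torch.load on them directly.
state_dict = torch.load("COCO_model.h5", map_location="cpu")
print(len(state_dict))  # number of parameter tensors in the captioning model
```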
README.md
ADDED
@@ -0,0 +1,13 @@
+---
+title: Clip Gpt2
+emoji: 🐨
+colorFrom: pink
+colorTo: indigo
+sdk: streamlit
+sdk_version: 1.19.0
+app_file: app.py
+pinned: false
+duplicated_from: Vageesh1/clip_gpt2
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,84 @@
+import torch
+import clip
+import PIL.Image
+from PIL import Image
+import skimage.io as io
+import streamlit as st
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
+from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel
+from model import generate2, ClipCaptionModel
+from engine import inference
+
+
+model_trained = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+model_trained.load_state_dict(torch.load('model_trained.pth', map_location=torch.device('cpu')), strict=False)
+image_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+tokenizer = GPT2TokenizerFast.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+
+def show_n_generate(img, model, greedy = True):
+    image = Image.open(img)
+    pixel_values = image_processor(image, return_tensors="pt").pixel_values
+
+    if greedy:
+        generated_ids = model.generate(pixel_values, max_new_tokens = 30)
+    else:
+        generated_ids = model.generate(
+            pixel_values,
+            do_sample=True,
+            max_new_tokens = 30,
+            top_k=5)
+    generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    return generated_text
+
+device = "cpu"
+clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+prefix_length = 10
+
+model = ClipCaptionModel(prefix_length)
+
+model.load_state_dict(torch.load('model.h5', map_location=torch.device('cpu')), strict=False)
+
+model = model.eval()
+
+coco_model = ClipCaptionModel(prefix_length)
+coco_model.load_state_dict(torch.load('COCO_model.h5', map_location=torch.device('cpu')), strict=False)
+coco_model = coco_model.eval()
+
+
+def ui():
+    st.markdown("# Image Captioning")
+    # st.markdown("## Done By- Vageesh and Rushil")
+    uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpeg', 'jpg'])
+
+    if uploaded_file is not None:
+        image = io.imread(uploaded_file)
+        pil_image = PIL.Image.fromarray(image)
+        image = preprocess(pil_image).unsqueeze(0).to(device)
+
+        option = st.selectbox('Please select the Model', ('Clip Captioning', 'Attention Decoder', 'VIT+GPT2'))
+
+        if option == 'Clip Captioning':
+            with torch.no_grad():
+                prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+                prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+                generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
+
+            st.image(uploaded_file, width = 500, channels = 'RGB')
+            st.markdown("**PREDICTION:** " + generated_text_prefix)
+        elif option == 'Attention Decoder':
+            out = inference(uploaded_file)
+            st.image(uploaded_file, width = 500, channels = 'RGB')
+            st.markdown("**PREDICTION:** " + out)
+
+        # elif option == 'VIT+GPT2':
+        #     out = show_n_generate(uploaded_file, greedy = False, model = model_trained)
+        #     st.image(uploaded_file, width = 500, channels = 'RGB')
+        #     st.markdown("**PREDICTION:** " + out)
+
+
+
+if __name__ == '__main__':
+    ui()
+
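The core captioning path in app.py is: CLIP encodes the image to a 512-d feature, model.clip_project maps it to prefix_length GPT-2 token embeddings, and generate2 decodes text from that prefix. A minimal sketch of the same flow outside Streamlit, assuming model.h5 from this repo is present locally (the image path is a placeholder):

```python
import clip
import PIL.Image
import torch
from transformers import GPT2Tokenizer
from model import ClipCaptionModel, generate2

device = "cpu"
prefix_length = 10

# Load CLIP, the GPT-2 tokenizer, and the prefix-projection captioner.
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = ClipCaptionModel(prefix_length)
model.load_state_dict(torch.load("model.h5", map_location="cpu"), strict=False)
model.eval()

# "example.jpg" is a placeholder path used only for illustration.
image = preprocess(PIL.Image.open("example.jpg")).unsqueeze(0).to(device)
with torch.no_grad():
    prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
    prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
    caption = generate2(model, tokenizer, embed=prefix_embed)
print(caption)
```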
engine.py
ADDED
@@ -0,0 +1,42 @@
+import os
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+import json
+from neuralnet.model import SeqToSeq
+import wget
+
+url = "https://github.com/Koushik0901/Image-Captioning/releases/download/v1.0/flickr30k.pt"
+# os.system("curl -L https://github.com/Koushik0901/Image-Captioning/releases/download/v1.0/flickr30k.pt")
+filename = wget.download(url)
+
+def inference(img_path):
+    transform = transforms.Compose(
+        [
+            transforms.Resize((299, 299)),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+        ]
+    )
+
+    vocabulary = json.load(open('./vocab.json'))
+
+    model_params = {"embed_size": 256, "hidden_size": 512, "vocab_size": 7666, "num_layers": 3, "device": "cpu"}
+    model = SeqToSeq(**model_params)
+    checkpoint = torch.load('./flickr30k.pt', map_location = 'cpu')
+    model.load_state_dict(checkpoint['state_dict'])
+
+    img = transform(Image.open(img_path).convert("RGB")).unsqueeze(0)
+
+    result_caption = []
+    model.eval()
+
+    x = model.encoder(img).unsqueeze(0)
+    states = None
+
+    out_captions = model.caption_image(img, vocabulary['itos'], 50)
+    return " ".join(out_captions[1:-1])
+
+
+if __name__ == '__main__':
+    print(inference('./test_examples/dog.png'))
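The vocab.json consumed here is assumed to follow the {"itos": ..., "stoi": ...} layout that neuralnet/dataset.py dumps in its __main__ block. After the JSON round trip the itos keys become strings, which is why neuralnet/model.py indexes the vocabulary with str(idx). A small sketch of that assumption:

```python
import json

# Expected layout (an assumption based on the dump in neuralnet/dataset.py):
# {"itos": {"0": "<PAD>", "1": "<SOS>", ...}, "stoi": {"<PAD>": 0, "<SOS>": 1, ...}}
vocab = json.load(open("./vocab.json"))
print(len(vocab["stoi"]))  # should match the hard-coded vocab_size of 7666 above
print(vocab["itos"]["2"])  # "<EOS>" -- note the string keys after JSON serialization
```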
model.h5
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a36a09076b9779de2807d3aa533d455a398d70c1250aeb24a5cc9110e3d59a4
+size 636272061
model.py
ADDED
@@ -0,0 +1,220 @@
+import clip
+import os
+from torch import nn
+import numpy as np
+import torch
+import torch.nn.functional as nnf
+import sys
+from typing import Tuple, List, Union, Optional
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
+from tqdm import tqdm, trange
+import skimage.io as io
+import PIL.Image
+
+
+N = type(None)
+V = np.array
+ARRAY = np.ndarray
+ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
+VS = Union[Tuple[V, ...], List[V]]
+VN = Union[V, N]
+VNS = Union[VS, N]
+T = torch.Tensor
+TS = Union[Tuple[T, ...], List[T]]
+TN = Optional[T]
+TNS = Union[Tuple[TN, ...], List[TN]]
+TSN = Optional[TS]
+TA = Union[T, ARRAY]
+
+
+D = torch.device
+
+def get_device(device_id: int) -> D:
+    if not torch.cuda.is_available():
+        return torch.device('cpu')
+    device_id = min(torch.cuda.device_count() - 1, device_id)
+    return torch.device(f'cuda:{device_id}')
+
+
+CUDA = get_device
+
+current_directory = os.getcwd()
+save_path = os.path.join(os.path.dirname(current_directory), "pretrained_models")
+os.makedirs(save_path, exist_ok=True)
+model_path = os.path.join(save_path, 'model_wieghts.pt')
+
+
+class MLP(nn.Module):
+
+    def forward(self, x: T) -> T:
+        return self.model(x)
+
+    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+        super(MLP, self).__init__()
+        layers = []
+        for i in range(len(sizes) - 1):
+            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+            if i < len(sizes) - 2:
+                layers.append(act())
+        self.model = nn.Sequential(*layers)
+
+class ClipCaptionModel(nn.Module):
+
+    #@functools.lru_cache #FIXME
+    def get_dummy_token(self, batch_size: int, device: D) -> T:
+        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+    def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
+        embedding_text = self.gpt.transformer.wte(tokens)
+        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
+        #print(embedding_text.size()) #torch.Size([5, 67, 768])
+        #print(prefix_projections.size()) #torch.Size([5, 1, 768])
+        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+        if labels is not None:
+            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+            labels = torch.cat((dummy_token, tokens), dim=1)
+        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+        return out
+
+    def __init__(self, prefix_length: int, prefix_size: int = 512):
+        super(ClipCaptionModel, self).__init__()
+        self.prefix_length = prefix_length
+        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+        if prefix_length > 10:  # not enough memory
+            self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
+        else:
+            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))
+
+
+class ClipCaptionPrefix(ClipCaptionModel):
+
+    def parameters(self, recurse: bool = True):
+        return self.clip_project.parameters()
+
+    def train(self, mode: bool = True):
+        super(ClipCaptionPrefix, self).train(mode)
+        self.gpt.eval()
+        return self
+
+def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
+                  entry_length=67, temperature=1., stop_token: str = '.'):
+
+    model.eval()
+    stop_token_index = tokenizer.encode(stop_token)[0]
+    tokens = None
+    scores = None
+    device = next(model.parameters()).device
+    seq_lengths = torch.ones(beam_size, device=device)
+    is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+    with torch.no_grad():
+        if embed is not None:
+            generated = embed
+        else:
+            if tokens is None:
+                tokens = torch.tensor(tokenizer.encode(prompt))
+                tokens = tokens.unsqueeze(0).to(device)
+                generated = model.gpt.transformer.wte(tokens)
+        for i in range(entry_length):
+            outputs = model.gpt(inputs_embeds=generated)
+            logits = outputs.logits
+            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+            logits = logits.softmax(-1).log()
+            if scores is None:
+                scores, next_tokens = logits.topk(beam_size, -1)
+                generated = generated.expand(beam_size, *generated.shape[1:])
+                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+                if tokens is None:
+                    tokens = next_tokens
+                else:
+                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
+                    tokens = torch.cat((tokens, next_tokens), dim=1)
+            else:
+                logits[is_stopped] = -float(np.inf)
+                logits[is_stopped, 0] = 0
+                scores_sum = scores[:, None] + logits
+                seq_lengths[~is_stopped] += 1
+                scores_sum_average = scores_sum / seq_lengths[:, None]
+                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
+                next_tokens_source = next_tokens // scores_sum.shape[1]
+                seq_lengths = seq_lengths[next_tokens_source]
+                next_tokens = next_tokens % scores_sum.shape[1]
+                next_tokens = next_tokens.unsqueeze(1)
+                tokens = tokens[next_tokens_source]
+                tokens = torch.cat((tokens, next_tokens), dim=1)
+                generated = generated[next_tokens_source]
+                scores = scores_sum_average * seq_lengths
+                is_stopped = is_stopped[next_tokens_source]
+            next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
+            generated = torch.cat((generated, next_token_embed), dim=1)
+            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+            if is_stopped.all():
+                break
+    scores = scores / seq_lengths
+    output_list = tokens.cpu().numpy()
+    output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
+    order = scores.argsort(descending=True)
+    output_texts = [output_texts[i] for i in order]
+    return output_texts
+
+def generate2(
+        model,
+        tokenizer,
+        tokens=None,
+        prompt=None,
+        embed=None,
+        entry_count=1,
+        entry_length=67,  # maximum number of words
+        top_p=0.8,
+        temperature=1.,
+        stop_token: str = '.',
+):
+    model.eval()
+    generated_num = 0
+    generated_list = []
+    stop_token_index = tokenizer.encode(stop_token)[0]
+    filter_value = -float("Inf")
+    device = next(model.parameters()).device
+
+    with torch.no_grad():
+
+        for entry_idx in trange(entry_count):
+            if embed is not None:
+                generated = embed
+            else:
+                if tokens is None:
+                    tokens = torch.tensor(tokenizer.encode(prompt))
+                    tokens = tokens.unsqueeze(0).to(device)
+
+                generated = model.gpt.transformer.wte(tokens)
+
+            for i in range(entry_length):
+
+                outputs = model.gpt(inputs_embeds=generated)
+                logits = outputs.logits
+                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
+                sorted_indices_to_remove = cumulative_probs > top_p
+                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+                    ..., :-1
+                ].clone()
+                sorted_indices_to_remove[..., 0] = 0
+
+                indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                logits[:, indices_to_remove] = filter_value
+                next_token = torch.argmax(logits, -1).unsqueeze(0)
+                next_token_embed = model.gpt.transformer.wte(next_token)
+                if tokens is None:
+                    tokens = next_token
+                else:
+                    tokens = torch.cat((tokens, next_token), dim=1)
+                generated = torch.cat((generated, next_token_embed), dim=1)
+                if stop_token_index == next_token.item():
+                    break
+
+            output_list = list(tokens.squeeze().cpu().numpy())
+            output_text = tokenizer.decode(output_list)
+            generated_list.append(output_text)
+
+    return generated_list[0]
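For reference, a shape walk-through of the prefix projection that generate2 and generate_beam consume. This is a sketch; the dummy tensor stands in for a real CLIP ViT-B/32 image feature:

```python
import torch
from model import ClipCaptionModel

prefix_length = 10
model = ClipCaptionModel(prefix_length)        # prefix_size defaults to 512 (CLIP ViT-B/32)

dummy_prefix = torch.randn(1, 512)             # stand-in for clip_model.encode_image output
projected = model.clip_project(dummy_prefix)   # MLP: 512 -> 3840 -> 7680
prefix_embed = projected.reshape(1, prefix_length, -1)
print(prefix_embed.shape)                      # torch.Size([1, 10, 768]): ten GPT-2 token embeddings
```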
model_2.py
ADDED
File without changes
model_trained.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f44c397a407f1687578a0346cbe19262b4ba6954c3256ec656ade873ac57d07
+size 982140285
neuralnet/dataset.py
ADDED
@@ -0,0 +1,139 @@
+import os  # when loading file paths
+import pandas as pd  # for lookup in annotation file
+import spacy  # for tokenizer
+import torch
+from torch.nn.utils.rnn import pad_sequence  # pad batch
+from torch.utils.data import DataLoader, Dataset
+from PIL import Image  # Load img
+import torchvision.transforms as transforms
+import json
+
+# Download with: python -m spacy download en_core_web_sm
+spacy_eng = spacy.load("en_core_web_sm")
+
+
+class Vocabulary:
+    def __init__(self, freq_threshold):
+        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
+        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
+        self.freq_threshold = freq_threshold
+
+    def __len__(self):
+        return len(self.stoi)
+
+    @staticmethod
+    def tokenizer_eng(text):
+        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
+
+    def build_vocabulary(self, sentence_list):
+        frequencies = {}
+        idx = 4
+
+        for sentence in sentence_list:
+            for word in self.tokenizer_eng(sentence):
+                if word not in frequencies:
+                    frequencies[word] = 1
+
+                else:
+                    frequencies[word] += 1
+
+                if frequencies[word] == self.freq_threshold:
+                    self.stoi[word] = idx
+                    self.itos[idx] = word
+                    idx += 1
+
+    def numericalize(self, text):
+        tokenized_text = self.tokenizer_eng(text)
+
+        return [
+            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
+            for token in tokenized_text
+        ]
+
+
+class FlickrDataset(Dataset):
+    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
+        self.root_dir = root_dir
+        self.df = pd.read_csv(captions_file)
+        self.transform = transform
+
+        # Get img, caption columns
+        self.imgs = self.df["image_name"]
+        self.captions = self.df["comment"]
+
+        # Initialize vocabulary and build vocab
+        self.vocab = Vocabulary(freq_threshold)
+        self.vocab.build_vocabulary(self.captions.tolist())
+
+    def __len__(self):
+        return len(self.df)
+
+    def __getitem__(self, index):
+        caption = self.captions[index]
+        img_id = self.imgs[index]
+        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        numericalized_caption = [self.vocab.stoi["<SOS>"]]
+        numericalized_caption += self.vocab.numericalize(caption)
+        numericalized_caption.append(self.vocab.stoi["<EOS>"])
+
+        return img, torch.tensor(numericalized_caption)
+
+
+class MyCollate:
+    def __init__(self, pad_idx):
+        self.pad_idx = pad_idx
+
+    def __call__(self, batch):
+        imgs = [item[0].unsqueeze(0) for item in batch]
+        imgs = torch.cat(imgs, dim=0)
+        targets = [item[1] for item in batch]
+        targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)
+
+        return imgs, targets
+
+
+def get_loader(
+    root_folder,
+    annotation_file,
+    transform,
+    batch_size=64,
+    num_workers=2,
+    shuffle=True,
+    pin_memory=True,
+):
+    dataset = FlickrDataset(root_folder, annotation_file, transform=transform)
+
+    pad_idx = dataset.vocab.stoi["<PAD>"]
+
+    loader = DataLoader(
+        dataset=dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        shuffle=shuffle,
+        pin_memory=pin_memory,
+        collate_fn=MyCollate(pad_idx=pad_idx),
+    )
+
+    return loader, dataset
+
+
+if __name__ == "__main__":
+    transform = transforms.Compose(
+        [transforms.Resize((224, 224)), transforms.ToTensor()]
+    )
+
+    loader, dataset = get_loader(
+        "/home/koushik/vscode/Projects/pytorch/img2text_v1/flickr30k/flickr30k_images/", "/home/koushik/vscode/Projects/pytorch/img2text_v1/flickr30k/results.csv", transform=transform
+    )
+
+    for idx, (imgs, captions) in enumerate(loader):
+        print(imgs.shape)
+        print(captions.shape)
+        print(len(dataset.vocab))
+        test = {"itos": dataset.vocab.itos, "stoi": dataset.vocab.stoi}
+        json.dump(test, open('test.json', 'w'))
+        break
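A small usage sketch of the Vocabulary class above, with toy sentences; it assumes spaCy's en_core_web_sm model is installed:

```python
from neuralnet.dataset import Vocabulary

vocab = Vocabulary(freq_threshold=1)
vocab.build_vocabulary(["a dog runs", "a dog jumps"])

# Known words get indices starting at 4; unseen words fall back to <UNK> (index 3).
print(vocab.numericalize("a dog flies"))  # e.g. [4, 5, 3]
```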
neuralnet/model.py
ADDED
@@ -0,0 +1,71 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+
+
+class InceptionEncoder(nn.Module):
+    def __init__(self, embed_size, train_CNN=False):
+        super(InceptionEncoder, self).__init__()
+        self.train_CNN = train_CNN
+        self.inception = models.inception_v3(pretrained=True, aux_logits=False)
+        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
+        self.relu = nn.ReLU()
+        self.bn = nn.BatchNorm1d(embed_size, momentum = 0.01)
+        self.dropout = nn.Dropout(0.5)
+
+    def forward(self, images):
+        features = self.inception(images)
+        norm_features = self.bn(features)
+        return self.dropout(self.relu(norm_features))
+
+
+class LstmDecoder(nn.Module):
+    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, device = 'cpu'):
+        super(LstmDecoder, self).__init__()
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+        self.device = device
+        self.embed = nn.Embedding(vocab_size, embed_size)
+        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers = self.num_layers)
+        self.linear = nn.Linear(hidden_size, vocab_size)
+        self.dropout = nn.Dropout(0.5)
+
+    def forward(self, encoder_out, captions):
+        h0 = torch.zeros(self.num_layers, encoder_out.shape[0], self.hidden_size).to(self.device).requires_grad_()
+        c0 = torch.zeros(self.num_layers, encoder_out.shape[0], self.hidden_size).to(self.device).requires_grad_()
+        embeddings = self.dropout(self.embed(captions))
+        embeddings = torch.cat((encoder_out.unsqueeze(0), embeddings), dim=0)
+        hiddens, (hn, cn) = self.lstm(embeddings, (h0.detach(), c0.detach()))
+        outputs = self.linear(hiddens)
+        return outputs
+
+
+class SeqToSeq(nn.Module):
+    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, device = 'cpu'):
+        super(SeqToSeq, self).__init__()
+        self.encoder = InceptionEncoder(embed_size)
+        self.decoder = LstmDecoder(embed_size, hidden_size, vocab_size, num_layers, device)
+
+    def forward(self, images, captions):
+        features = self.encoder(images)
+        outputs = self.decoder(features, captions)
+        return outputs
+
+    def caption_image(self, image, vocabulary, max_length = 50):
+        result_caption = []
+
+        with torch.no_grad():
+            x = self.encoder(image).unsqueeze(0)
+            states = None
+
+            for _ in range(max_length):
+                hiddens, states = self.decoder.lstm(x, states)
+                output = self.decoder.linear(hiddens.squeeze(0))
+                predicted = output.argmax(1)
+                result_caption.append(predicted.item())
+                x = self.decoder.embed(predicted).unsqueeze(0)
+
+                if vocabulary[str(predicted.item())] == "<EOS>":
+                    break
+
+        return [vocabulary[str(idx)] for idx in result_caption]
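caption_image decodes greedily: the encoder feature is fed in as the first LSTM input, and each predicted token's embedding is fed back in. A single decoding step at the shape level (a sketch with toy sizes, no pretrained weights needed):

```python
import torch
from neuralnet.model import LstmDecoder

decoder = LstmDecoder(embed_size=256, hidden_size=512, vocab_size=100, num_layers=3)

feature = torch.randn(1, 1, 256)             # encoder output, unsqueezed as in caption_image
hiddens, states = decoder.lstm(feature, None)
logits = decoder.linear(hiddens.squeeze(0))  # (1, vocab_size)
print(logits.argmax(1))                      # greedy next-token id, fed back via decoder.embed
```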
neuralnet/train.py
ADDED
@@ -0,0 +1,130 @@
+import torch
+from tqdm import tqdm
+import torch.nn as nn
+import torch.optim as optim
+import torchvision.transforms as transforms
+from torch.utils.tensorboard import SummaryWriter  # For TensorBoard
+from utils import save_checkpoint, load_checkpoint, print_examples
+from dataset import get_loader
+from model import SeqToSeq
+from tabulate import tabulate  # To tabulate loss and epoch
+import argparse
+import json
+
+def main(args):
+    transform = transforms.Compose(
+        [
+            transforms.Resize((356, 356)),
+            transforms.RandomCrop((299, 299)),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+        ]
+    )
+
+    train_loader, _ = get_loader(
+        root_folder = args.root_dir,
+        annotation_file = args.csv_file,
+        transform=transform,
+        batch_size = 64,
+        num_workers=2,
+    )
+    vocab = json.load(open('vocab.json'))
+
+    torch.backends.cudnn.benchmark = True
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    load_model = False
+    save_model = True
+    train_CNN = False
+
+    # Hyperparameters
+    embed_size = args.embed_size
+    hidden_size = args.hidden_size
+    vocab_size = len(vocab['stoi'])
+    num_layers = args.num_layers
+    learning_rate = args.lr
+    num_epochs = args.num_epochs
+    # for tensorboard
+
+
+    writer = SummaryWriter(args.log_dir)
+    step = 0
+    model_params = {'embed_size': embed_size, 'hidden_size': hidden_size, 'vocab_size': vocab_size, 'num_layers': num_layers}
+    # initialize model, loss etc
+    model = SeqToSeq(**model_params, device = device).to(device)
+    criterion = nn.CrossEntropyLoss(ignore_index = vocab['stoi']["<PAD>"])
+    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+
+    # Only finetune the CNN's final fc layer unless train_CNN is set
+    for name, param in model.encoder.inception.named_parameters():
+        if "fc.weight" in name or "fc.bias" in name:
+            param.requires_grad = True
+        else:
+            param.requires_grad = train_CNN
+
+    # load from a saved checkpoint
+    if load_model:
+        step = load_checkpoint(torch.load(args.save_path), model, optimizer)
+
+    model.train()
+    best_loss, best_epoch = 10, 0
+    for epoch in range(num_epochs):
+        print_examples(model, device, vocab['itos'])
+
+        for idx, (imgs, captions) in tqdm(
+                enumerate(train_loader), total=len(train_loader), leave=False):
+            imgs = imgs.to(device)
+            captions = captions.to(device)
+
+            outputs = model(imgs, captions[:-1])
+            loss = criterion(
+                outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
+            )
+
+            writer.add_scalar("Training loss", loss.item(), global_step=step)
+            step += 1
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+        train_loss = loss.item()
+        if train_loss < best_loss:
+            best_loss = train_loss
+            best_epoch = epoch + 1
+            if save_model:
+                checkpoint = {
+                    "model_params": model_params,
+                    "state_dict": model.state_dict(),
+                    "optimizer": optimizer.state_dict(),
+                    "step": step
+                }
+                save_checkpoint(checkpoint, args.save_path)
+
+
+        table = [["Loss:", train_loss],
+                 ["Step:", step],
+                 ["Epoch:", epoch + 1],
+                 ["Best Loss:", best_loss],
+                 ["Best Epoch:", best_epoch]]
+        print(tabulate(table))
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--root_dir', type = str, default = './flickr30k/flickr30k_images', help = 'path to images folder')
+    parser.add_argument('--csv_file', type = str, default = './flickr30k/results.csv', help = 'path to captions csv file')
+    parser.add_argument('--log_dir', type = str, default = './drive/MyDrive/TensorBoard/', help = 'path to save tensorboard logs')
+    parser.add_argument('--save_path', type = str, default = './drive/MyDrive/checkpoints/Seq2Seq.pt', help = 'path to save checkpoint')
+    # Model Params
+    parser.add_argument('--batch_size', type = int, default = 64)
+    parser.add_argument('--num_epochs', type = int, default = 100)
+    parser.add_argument('--embed_size', type = int, default=256)
+    parser.add_argument('--hidden_size', type = int, default=512)
+    parser.add_argument('--lr', type = float, default= 0.001)
+    parser.add_argument('--num_layers', type = int, default = 3, help = 'number of lstm layers')
+
+    args = parser.parse_args()
+
+    main(args)
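The script is driven entirely by argparse. An equivalent programmatic invocation is sketched below; it assumes the working directory is neuralnet/ (because of the bare local imports of utils, dataset, and model), that vocab.json is present there, and that the Flickr30k paths point at a real local copy of the dataset:

```python
from argparse import Namespace
from train import main  # bare import: assumes the current directory is neuralnet/

args = Namespace(
    root_dir="./flickr30k/flickr30k_images",
    csv_file="./flickr30k/results.csv",
    log_dir="./runs",                   # placeholder log directory
    save_path="./checkpoints/Seq2Seq.pt",  # placeholder checkpoint path
    batch_size=64,
    num_epochs=100,
    embed_size=256,
    hidden_size=512,
    lr=0.001,
    num_layers=3,
)
main(args)
```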
neuralnet/utils.py
ADDED
@@ -0,0 +1,42 @@
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+
+
+def print_examples(model, device, vocab):
+    transform = transforms.Compose(
+        [transforms.Resize((299, 299)),
+         transforms.ToTensor(),
+         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+    )
+
+    model.eval()
+
+    test_img1 = transform(Image.open("./test_examples/dog.png").convert("RGB")).unsqueeze(0)
+    print("dog.png PREDICTION: " + " ".join(model.caption_image(test_img1.to(device), vocab)))
+
+    test_img2 = transform(Image.open("./test_examples/dirt_bike.png").convert("RGB")).unsqueeze(0)
+    print("dirt_bike.png PREDICTION: " + " ".join(model.caption_image(test_img2.to(device), vocab)))
+
+    test_img3 = transform(Image.open("./test_examples/surfing.png").convert("RGB")).unsqueeze(0)
+    print("wave.png PREDICTION: " + " ".join(model.caption_image(test_img3.to(device), vocab)))
+
+    test_img4 = transform(Image.open("./test_examples/horse.png").convert("RGB")).unsqueeze(0)
+    print("horse.png PREDICTION: " + " ".join(model.caption_image(test_img4.to(device), vocab)))
+
+    test_img5 = transform(Image.open("./test_examples/camera.png").convert("RGB")).unsqueeze(0)
+    print("camera.png PREDICTION: " + " ".join(model.caption_image(test_img5.to(device), vocab)))
+    model.train()
+
+
+def save_checkpoint(state, filename="/content/drive/MyDrive/checkpoints/Seq2Seq.pt"):
+    print("=> Saving checkpoint")
+    torch.save(state, filename)
+
+
+def load_checkpoint(checkpoint, model, optimizer):
+    print("=> Loading checkpoint")
+    model.load_state_dict(checkpoint["state_dict"])
+    optimizer.load_state_dict(checkpoint["optimizer"])
+    step = checkpoint["step"]
+    return step
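A round-trip sketch of the checkpoint helpers with toy sizes (tmp_ckpt.pt is a placeholder filename; LstmDecoder is used so no pretrained weights are downloaded):

```python
import torch
import torch.optim as optim
from neuralnet.model import LstmDecoder
from neuralnet.utils import save_checkpoint, load_checkpoint

model = LstmDecoder(embed_size=256, hidden_size=512, vocab_size=100, num_layers=3)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

save_checkpoint(
    {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict(), "step": 0},
    filename="tmp_ckpt.pt",
)
step = load_checkpoint(torch.load("tmp_ckpt.pt", map_location="cpu"), model, optimizer)
print(step)  # 0
```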
requirements.txt
ADDED
@@ -0,0 +1,19 @@
+torch
+torchvision
+ftfy
+git+https://github.com/openai/CLIP.git
+regex
+tqdm
+streamlit
+scikit-image
+pillow
+pandas
+transformers
+numpy
+spacy
+tqdm
+tabulate
+click==7.1.1
+gdown
+wget
+altair<5
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff