yesidcanoc committed
Commit 5fc5f75
Parent: 123d7a4

Create README.md

Files changed (1): README.md (+90 −0)
# Image captioning model

## How to use this model

Adapt the code below to your needs.

```python
import os

from PIL import Image
import torchvision.transforms as transforms
from transformers import GPT2TokenizerFast, VisionEncoderDecoderModel


class DataProcessing:
    def __init__(self):
        # GPT-2 tokenizer
        self.tokenizer = GPT2TokenizerFast.from_pretrained('distilgpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # Transforms applied to each image before it is fed to the encoder
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])


class GenerateCaptions(DataProcessing):
    NUM_BEAMS = 3
    MAX_LENGTH = 15
    EARLY_STOPPING = True
    DO_SAMPLE = True
    TOP_K = 10
    NUM_RETURN_SEQUENCES = 2  # number of captions to generate per image

    def __init__(self, captioning_model):
        self.captioning_model = captioning_model
        super().__init__()

    def read_img_predict(self, path):
        try:
            with Image.open(path) as img:
                if img.mode != "RGB":
                    img = img.convert('RGB')
                img_transformed = self.transform(img).unsqueeze(0)
                # Tensor of shape num_return_sequences x max_length, where each entry is a token id
                model_output = self.captioning_model.generate(
                    img_transformed,
                    num_beams=self.NUM_BEAMS,
                    max_length=self.MAX_LENGTH,
                    early_stopping=self.EARLY_STOPPING,
                    do_sample=self.DO_SAMPLE,
                    top_k=self.TOP_K,
                    num_return_sequences=self.NUM_RETURN_SEQUENCES,
                )
                # Each g is a tensor of token ids, e.g. tensor([50256, 13, 198, 198, 198, 198,
                # 198, 198, 198, 50256, 50256, 50256, 50256, 50256, 50256])
                captions = [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in model_output]

                return captions
        except FileNotFoundError:
            raise FileNotFoundError(f"File not found: {path}")

    def generate_caption(self, path):
        """
        Generate captions for a single image or a directory of images.
        :param path: path to an image or to a directory of images
        :return: captions
        """
        if os.path.isdir(path):
            self.decoded_predictions = []
            # NOTE: every file under `path` is assumed to be an image
            for root, dirs, files in os.walk(path):
                for file in files:
                    self.decoded_predictions.append(self.read_img_predict(os.path.join(root, file)))
            return self.decoded_predictions
        elif os.path.isfile(path):
            return self.read_img_predict(path)
        else:
            raise ValueError(f"Invalid path: {path}")


image_captioning_model = VisionEncoderDecoderModel.from_pretrained("yesidcanoc/image-captioning-swin-tiny-distilgpt2")

generate_captions = GenerateCaptions(image_captioning_model)

captions = generate_captions.generate_caption('../data/test_data/images')

print(captions)
```
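
`generate_caption` also accepts a single image file. The snippet below is a minimal sketch of that, plus one optional tweak; `test.jpg` is a placeholder path, not a file shipped with the model. Because `DO_SAMPLE` is a class attribute, setting it to `False` on the class switches generation to plain beam search, which makes the captions reproducible across runs.

```python
# Minimal sketch: caption a single image ('test.jpg' is a placeholder path).
single_image_captions = generate_captions.generate_caption('test.jpg')
print(single_image_captions)

# Optional: disable sampling to fall back to plain beam search,
# so repeated runs return the same captions for the same image.
GenerateCaptions.DO_SAMPLE = False
deterministic_captions = generate_captions.generate_caption('test.jpg')
print(deterministic_captions)
```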