Image-to-Text
Chinese
English
File size: 4,067 Bytes
6e6d6a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import json
import torch
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from glob import glob
from data.utils import pre_caption


class facecaption_train(Dataset):
    """Training split of the face-caption dataset.

    Each item is ``(image, caption, image_id)``:
      * ``image``    -- the RGB image at ``image_root/ann['image']`` passed through ``transform``;
      * ``caption``  -- ``prompt`` followed by ``pre_caption(<caption text>, max_words)``;
      * ``image_id`` -- a dense integer index assigned to each distinct ``ann['image_id']``
        in the order it first appears in the loaded annotations.
    """

    def __init__(self, transform, image_root, ann_root, max_words=65, prompt='',
                 ann_files=slice(0, 1)):
        '''
        transform: callable applied to every loaded RGB PIL image.
        image_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory containing the ``*.json`` annotation files
        max_words (int): forwarded to ``pre_caption`` as its second argument.
        prompt (string): text prepended to every processed caption.
        ann_files (slice): which of the *sorted* ``*.json`` files in ``ann_root`` to load.
            Defaults to ``slice(0, 1)`` -- only the first file -- which reproduces the
            previous hard-coded behaviour. NOTE(review): that default looks like a
            debugging restriction (an earlier variant loaded ``[:-1]``, i.e. every file
            except the last, which ``facecaption_test`` loads by default); pass
            ``slice(None, -1)`` to get that behaviour.

        Raises:
            FileNotFoundError: if ``ann_files`` selects no JSON file in ``ann_root``.
        '''
        self.annotation = self._load_annotations(ann_root, ann_files)

        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        self.prompt = prompt

        # Dense 0..K-1 index per distinct raw image_id, in first-seen order.
        self.img_ids = {}
        for ann in self.annotation:
            self.img_ids.setdefault(ann['image_id'], len(self.img_ids))

    @staticmethod
    def _load_annotations(ann_root, ann_files):
        """Return the concatenated contents of the selected, sorted ``*.json`` files."""
        json_paths = sorted(glob(os.path.join(ann_root, '*.json')))[ann_files]
        if not json_paths:
            # An empty dataset is never useful; fail here instead of deep inside a DataLoader.
            raise FileNotFoundError(f"no annotation .json files selected in {ann_root!r}")
        annotation = []
        for json_path in json_paths:
            print("loading " + json_path)
            # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
            with open(json_path, 'r', encoding='utf-8') as json_file:
                # presumably each file holds a list of annotation dicts -- TODO confirm
                annotation.extend(json.load(json_file))
        return annotation

    @staticmethod
    def _caption_text(ann):
        """Return the single raw caption string of ``ann``.

        ``ann['caption']`` may be a plain string or a list of strings; for a list the
        first entry is used (the LAION variant of this loader likewise took ``[0]``).
        The previous star-unpacking only worked for one-element lists.
        """
        caption = ann['caption']
        if isinstance(caption, str):
            return caption
        if not caption:
            raise ValueError(f"annotation for image {ann.get('image')!r} has no caption")
        return caption[0]

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]

        # Face images live at ann['image']; the LAION variant used ann['image'][:21] + '.jpg'.
        image_path = os.path.join(self.image_root, ann['image'])
        # convert() returns a fully loaded copy, so the file can be closed right away.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image = self.transform(image)

        # Face caption from captionV3; the LAION variant read ann['laion_caption'][0].
        caption = self.prompt + pre_caption(self._caption_text(ann), self.max_words)
        image_id = self.img_ids[ann['image_id']]
        return image, caption, image_id
    

class facecaption_test(Dataset):
    """Evaluation split of the face-caption dataset.

    Besides the per-image annotations it flattens every caption into ``self.text``
    (each passed through ``pre_caption``) and records the correspondence:
      * ``img2txt`` -- image index -> list of indices into ``self.text``;
      * ``txt2img`` -- text index  -> image index.
    ``self.image`` holds the relative image path of each annotation.
    Each item is ``(image, index)``.
    """

    def __init__(self, transform, image_root, ann_root, max_words=65,
                 ann_files=slice(-1, None), max_samples=5000):
        '''
        transform: callable applied to every loaded RGB PIL image.
        image_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory containing the ``*.json`` annotation files
        max_words (int): forwarded to ``pre_caption`` as its second argument.
        ann_files (slice): which of the *sorted* ``*.json`` files to load; defaults to
            the last file only (the previous hard-coded behaviour).
        max_samples (int or None): keep only the first ``max_samples`` annotations
            (default 5000, as before); ``None`` keeps all of them.

        Raises:
            FileNotFoundError: if ``ann_files`` selects no JSON file in ``ann_root``.
        '''
        self.annotation = self._load_annotations(ann_root, ann_files)[:max_samples]

        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        for img_id, ann in enumerate(self.annotation):
            # Face images live at ann['image']; the LAION variant used ann['image'][:21] + '.jpg'.
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            # Face captions from captionV3; the LAION variant iterated ann['laion_caption'].
            for caption in self._caption_list(ann):
                txt_id = len(self.text)  # text ids are consecutive across all images
                self.text.append(pre_caption(caption, max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id

    @staticmethod
    def _load_annotations(ann_root, ann_files):
        """Return the concatenated contents of the selected, sorted ``*.json`` files."""
        json_paths = sorted(glob(os.path.join(ann_root, '*.json')))[ann_files]
        if not json_paths:
            # An empty evaluation set is never useful; fail loudly instead.
            raise FileNotFoundError(f"no annotation .json files selected in {ann_root!r}")
        annotation = []
        for json_path in json_paths:
            # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
            with open(json_path, 'r', encoding='utf-8') as json_file:
                # presumably each file holds a list of annotation dicts -- TODO confirm
                annotation.extend(json.load(json_file))
        return annotation

    @staticmethod
    def _caption_list(ann):
        """Return ``ann['caption']`` as a list of caption strings.

        A bare string is wrapped in a one-element list; iterating it directly would
        treat every character as a separate caption.
        """
        captions = ann['caption']
        return [captions] if isinstance(captions, str) else list(captions)

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]

        image_path = os.path.join(self.image_root, ann['image'])
        # convert() returns a fully loaded copy, so the file can be closed right away.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image = self.transform(image)
        return image, index