import itertools
from PIL import Image
from PIL import ImageSequence
# NLTK needs the 'punkt' tokenizer and 'averaged_perceptron_tagger' data
# packages (install via nltk.download) for word_tokenize/pos_tag to work.
from nltk import pos_tag, word_tokenize

from LLaMA2_Accessory.SPHINX import SPHINXModel
from gpt_combinator import caption_summary

class CaptionRefiner():
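    """Refine a short video caption into a detailed one.

    Frames are sampled from a GIF/MP4 clip, a SPHINX vision-language model is
    queried for scene, object-attribute, and spatial-relation descriptions,
    and the pieces are merged into one long caption by a GPT-based summarizer.
    """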
    def __init__(self, sample_num, add_detect=True, add_pos=True, add_attr=True,
                 openai_api_key=None, openai_api_base=None):
        self.sample_num = sample_num
        self.ADD_DETECTION_OBJ = add_detect
        self.ADD_POS = add_pos
        self.ADD_ATTR = add_attr
        self.openai_api_key = openai_api_key
        self.openai_api_base = openai_api_base

    def video_load_split(self, video_path=None):
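        """Load a .gif/.mp4 video and uniformly sample `sample_num` frames as RGB PIL images."""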
        frame_img_list, sampled_img_list = [], []

        if video_path.endswith(".gif"):
            img = Image.open(video_path)
            # Convert every GIF frame to a standalone RGB PIL image.
            for frame in ImageSequence.Iterator(img):
                frame_img_list.append(frame.copy().convert('RGB'))
        elif video_path.endswith(".mp4"):
            # Decode MP4 frames with OpenCV; assumes the opencv-python package is installed.
            import cv2
            cap = cv2.VideoCapture(video_path)
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV yields BGR arrays; convert to RGB PIL images.
                frame_img_list.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            cap.release()

        # Uniformly sample frames; clamp the step to 1 so clips with fewer
        # frames than sample_num do not produce a zero range step.
        step = max(1, len(frame_img_list) // self.sample_num)
        for i in range(0, len(frame_img_list), step):
            sampled_img_list.append(frame_img_list[i])

        return sampled_img_list  # [<PIL.Image.Image>, ...]

    def caption_refine(self, video_path, org_caption, model_path):
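        """Query the SPHINX VLM for scene, object-attribute, and spatial-relation captions."""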
        sampled_imgs = self.video_load_split(video_path)

        model = SPHINXModel.from_pretrained(
            pretrained_path=model_path, 
            with_visual=True
        )
        
        # Extract candidate object nouns from the original caption.
        scene_description = []
        text = word_tokenize(org_caption)
        existing_objects = [word for word, tag in pos_tag(text) if tag in ["NN", "NNS", "NNP"]]
        if self.ADD_DETECTION_OBJ:
            # Ask where the scene most likely takes place, using the first sampled frame.
            qas = [["Where is this scene in the picture most likely to take place?", None]]
            sc_response = model.generate_response(qas, sampled_imgs[0], max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0)
            scene_description.append(sc_response)

            # Per-frame object detection, currently disabled for lacking accuracy:
            # for img in sampled_imgs:
            #     qas = [["Please detect the objects in the image.", None]]
            #     response = model.generate_response(qas, img, max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0)
            #     print(response)
                
        object_attrs = []
        if self.ADD_ATTR:
            # Describe the attributes of each detected object across the sampled frames.
            for obj in existing_objects:
                obj_attr = []
                for img in sampled_imgs:
                    qas = [["Please describe the attributes of the {}, including color, position, etc.".format(obj), None]]
                    response = model.generate_response(qas, img, max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0)
                    obj_attr.append(response)
                object_attrs.append({obj: obj_attr})

        space_relations = []
        if self.ADD_POS:
            obj_pairs = list(itertools.combinations(existing_objects, 2))
            # Describe the spatial relationship between each pair of objects,
            # queried on the first sampled frame.
            for obj_pair in obj_pairs:
                qas = [["What is the spatial relationship between the {} and the {}? Please describe in less than twenty words.".format(obj_pair[0], obj_pair[1]), None]]
                response = model.generate_response(qas, sampled_imgs[0], max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0)
                space_relations.append(response)
        
        return dict(
            org_caption=org_caption,
            scene_description=scene_description,
            existing_objects=existing_objects,
            object_attrs=object_attrs,
            space_relations=space_relations,
        )
    
    def gpt_summary(self, total_captions):
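        """Merge the partial captions into one long caption via the GPT summarizer."""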
        # Combine all partial captions into one detailed long caption.
        detailed_caption = ""

        if total_captions.get("org_caption"):
            detailed_caption += "In summary, " + total_captions['org_caption'] + " "

        if total_captions.get("scene_description"):
            detailed_caption += "We first describe the whole scene. " + total_captions['scene_description'][-1] + " "

        if total_captions.get("existing_objects"):
            detailed_caption += "There are multiple objects in the video, including " + ", ".join(total_captions['existing_objects']) + ". "
        
        # if "object_attrs" in total_captions.keys():
        #     caption_summary(
        #         caption_list="", 
        #         api_key=self.openai_api_key, 
        #         api_base=self.openai_api_base,
        #     )
        
        if "space_relations" in total_captions.keys():
            tmp_sentence = "As for the spatial relationship. "
            for sentence in total_captions['space_relations']: tmp_sentence += sentence
            detailed_caption += tmp_sentence
        
        detailed_caption = caption_summary(detailed_caption, self.openai_api_key, self.openai_api_base)
        
        return detailed_caption
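

# A minimal usage sketch. The checkpoint path, video path, API key, and caption
# below are illustrative placeholders, not values shipped with this repository.
if __name__ == "__main__":
    refiner = CaptionRefiner(
        sample_num=8,
        openai_api_key="sk-xxxx",  # placeholder key
        openai_api_base=None,
    )
    captions = refiner.caption_refine(
        video_path="examples/demo.gif",             # placeholder path
        org_caption="A dog runs across the lawn.",  # placeholder caption
        model_path="checkpoints/SPHINX",            # placeholder SPHINX checkpoint dir
    )
    print(refiner.gpt_summary(captions))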