File size: 5,467 Bytes
24c4def
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import yaml
import json
import shortuuid
import base64
from PIL import Image
import os
from tqdm import tqdm
from PIL import Image
from openai import OpenAI
# Shared OpenAI-compatible client used by get_openai_reply().
# The API key is read from the OPENAI_API_KEY environment variable so that no
# secret is stored in the source tree. The key that was previously hard-coded
# on this line has been exposed and should be revoked.
client = OpenAI(
    base_url="https://oneapi.xty.app/v1",
    api_key=os.environ.get("OPENAI_API_KEY"),
)
# import sys
# sys.path.append("/home/wcx/wcx/EasyDetect/tool")
from pipeline.tool.object_detetction_model import *
from pipeline.tool.google_serper import *



def get_openai_reply(image_path, text):
    """Send an image plus a text prompt to the vision chat model and return the reply.

    Args:
        image_path: Path of the image file to attach. The file is read in
            binary mode and embedded in the request as a base64 data URL.
        text: Prompt text sent alongside the image.

    Returns:
        The ``content`` of the first choice of the chat completion response.

    Raises:
        OSError: If ``image_path`` cannot be opened or read.
    """
    import mimetypes  # function-scope import keeps this block self-contained

    with open(image_path, "rb") as image_file:
        encoded_image = base64.b64encode(image_file.read()).decode("utf-8")

    # Label the payload with the type suggested by the file name instead of
    # always claiming JPEG; fall back to JPEG (the previous behaviour) when the
    # extension is unknown or not an image type.
    mime_type, _ = mimetypes.guess_type(image_path)
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"
    data_url = f"data:{mime_type};base64,{encoded_image}"

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                # The chat completions API documents image_url as an object
                # carrying a "url" field rather than a bare string.
                {"type": "image_url", "image_url": {"url": data_url}},
            ],
        }
    ]

    resp = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=messages,
        max_tokens=1024,
    )
    return resp.choices[0].message.content



class Tool:
    """Collect evidence for a set of claims using a detector and a web search wrapper."""

    def __init__(self):
        """Load the YAML config and build the detector and the search wrapper."""
        # NOTE(review): absolute, machine-specific path; consider making it configurable.
        # The context manager closes the config file once it has been parsed.
        with open("/home/wcx/wcx/GroundingDINO/LVLM/config/config.yaml", "r") as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
        self.detector = GroundingDINO(config=config)
        self.search = GoogleSerperAPIWrapper()

    def execute(self, image_path, new_path, objects, attribute_list, scenetext_list, fact_list):
        """Run the evidence-gathering tools for one image.

        Args:
            image_path: Path of the image, forwarded to the detector (and to
                ``get_openai_reply`` when attribute queries are produced).
            new_path: Forwarded unchanged to the detector as ``new_path``.
            objects: Forwarded to the detector as ``content`` for the object pass.
            attribute_list: Mapping whose values are lists of attribute
                questions; a first element of ``"none"`` marks an entry as empty.
            scenetext_list: Mapping with the same ``"none"`` convention; any
                non-empty entry triggers the detector's text-recognition pass.
            fact_list: Mapping with the same ``"none"`` convention; each
                non-empty entry is stringified and sent to the search wrapper.

        Returns:
            A tuple ``(object_res, attribute_res, text_res, fact_res)``.
            ``text_res`` is ``None`` when no text-recognition pass ran;
            ``attribute_res`` and ``fact_res`` are the string
            ``"none information"`` when nothing was gathered for them.
        """
        # NOTE(review): attribute querying is hard-wired off, so the query
        # builder below never runs and get_openai_reply is never called from
        # here. Confirm whether this is intentional before enabling it.
        use_attribute = False

        use_text_rec = any(scenetext_list[key][0] != "none" for key in scenetext_list)

        # The text-recognition pass (when needed) runs before the object pass.
        text_res = None
        if use_text_rec:
            text_res = self.detector.execute(
                image_path=image_path,
                content="word.number",
                new_path=new_path,
                use_text_rec=True,
            )
        object_res = self.detector.execute(
            image_path=image_path,
            content=objects,
            new_path=new_path,
            use_text_rec=False,
        )

        # Build one numbered line per attribute question: "1.<query>\n2.<query>\n..."
        query_lines = []
        if use_attribute:
            for key in attribute_list:
                if attribute_list[key][0] != "none":
                    for query in attribute_list[key]:
                        query_lines.append(str(len(query_lines) + 1) + "." + query + "\n")
        queries = "".join(query_lines)

        if queries == "":
            attribute_res = "none information"
        else:
            attribute_res = get_openai_reply(image_path, queries)

        # Number the evidence strings continuously across all fact entries.
        evidence_lines = []
        for key in fact_list:
            if fact_list[key][0] != "none":
                evidences = self.search.execute(input="", content=str(fact_list[key]))
                for evidence in evidences:
                    evidence_lines.append(str(len(evidence_lines) + 1) + "." + evidence + "\n")
        fact_res = "".join(evidence_lines)
        if fact_res == "":
            fact_res = "none information"

        return object_res, attribute_res, text_res, fact_res




# if __name__ == '__main__':
#     tool = Tool()
#     extractor = Extractor(model="gpt-4-1106-preview", config_path= "/home/wcx/wcx/GroundingDINO/LVLM/prompt-v2-multi-claim/object_extract.yaml", type="image-to-text")
#     # "/home/wcx/wcx/LVLMHall-test/text-to-image/labeled.json"
#     query = Query(config_path="/home/wcx/wcx/GroundingDINO/LVLM/prompt-v2-multi-claim/query.yaml",type="image-to-text")
#     path = "/home/wcx/wcx/LVLMHall-test/MSCOCO/caption/labeled/minigpt4-100-cx-revise-v1.json"
#     with open(path, "r", encoding="utf-8") as f:
#         for idx, line in tqdm(enumerate(f.readlines()), total=250):
#             # if idx < 189:
#             #     continue
#             data = data2
#             #data = json.loads(line)
#             image_path = data["image_path"]#"/newdisk3/wcx" + data["image_path"]
#             claim_list = ""
#             cnt = 1
#             for seg in data["segments"]:
#                 for cla in seg["claims"]:
#                     claim_list += "claim" + str(cnt) + ": " + cla["claim"] + "\n"
#                     cnt += 1
#             object_list, objects = extractor.get_response(claims=claim_list)
#             print("pre:" + objects)
#             attribute_list, scenetext_list, fact_list, objects = query.get_response(claim_list, objects, object_list)
#             print("after:" + objects)
#             print(object_list)
#             print(attribute_list)
#             print(scenetext_list)
#             print(fact_list)
#             object_res, attribue_res, text_res, fact_res = tool.execute(image_path=image_path, 
#                                                                         new_path="/newdisk3/wcx/MLLM/image-to-text/minigpt4/",
#                                                                         attribute_list=attribute_list, 
#                                                                         scenetext_list=scenetext_list,
#                                                                         fact_list=fact_list, 
#                                                                         objects=objects)
#             # print(object_res)
#             # print(attribue_res)
#             # print(text_res)
#             #print(fact_res[:50])
#             print("=============================")
#             break