File size: 6,202 Bytes
24c4def
 
 
 
 
 
 
55d9644
24c4def
 
55d9644
24c4def
 
55d9644
 
 
 
 
 
 
 
 
 
24c4def
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55d9644
 
 
24c4def
 
 
 
 
 
55d9644
 
24c4def
 
 
 
 
 
 
55d9644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24c4def
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import json
import yaml
import copy
import asyncio
from nltk.corpus import wordnet

class QueryGenerator:
    """Turn claims into verification queries by prompting a chat model.

    Prompt templates are read from a YAML file and indexed as
    ``prompt[task_type][section]["system" | "user"]``; the sections used here
    are ``object``, ``attribute``, ``scene-text`` and ``fact``.  The chat client
    must expose an awaitable ``get_response(messages=...)``; its result is
    indexed by position and every entry used here is parsed as JSON.
    """

    # Marker this class compares against / writes for "no object, no question".
    _NONE = "none"

    def __init__(self, prompt_path, chat):
        """Load the prompt templates and keep a reference to the chat client.

        Args:
            prompt_path: Path of the UTF-8 YAML file holding the templates.
            chat: Client whose ``get_response(messages=...)`` coroutine is run
                by this class.
        """
        # The task type is only known once get_response() is called; there is
        # no ``type`` argument here, so nothing meaningful can be stored yet.
        self.type = None
        with open(prompt_path, "r", encoding="utf-8") as file:
            # NOTE(review): yaml.load/FullLoader is not meant for untrusted
            # input; prefer yaml.safe_load if this file can come from outside.
            self.prompt = yaml.load(file, yaml.FullLoader)
        self.chat = chat

    def _run_chat(self, messages):
        """Run ``self.chat.get_response(messages=...)`` to completion.

        A fresh event loop is created and installed as the current loop on
        every call.
        """
        # NOTE(review): the loop is deliberately left open and installed; code
        # outside this class may pick up the current loop afterwards, so
        # confirm that nothing relies on it before closing it here.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop.run_until_complete(self.chat.get_response(messages=messages))

    def _conversation(self, section, **fields):
        """Build the system+user message pair for one section of ``self.type``."""
        templates = self.prompt[self.type][section]
        return [
            {"role": "system", "content": templates["system"]},
            {"role": "user", "content": templates["user"].format(**fields)},
        ]

    @classmethod
    def _has_question(cls, questions):
        """Return True when a per-claim entry holds something other than "none"."""
        if isinstance(questions, str):
            return questions != cls._NONE
        # An empty (or null) entry carries no question either.
        return bool(questions) and questions[0] != cls._NONE

    def objects_extract(self, claim_list, use_attribue=False, response=None):
        """Extract the objects mentioned by each claim.

        Args:
            claim_list: Claims substituted into the ``object`` user prompt.
            use_attribue: When true, query the chat model with the ``object``
                prompt of ``self.type`` (so ``self.type`` must be set);
                otherwise ``response`` must already hold the model output.
            response: Sequence whose first element is a JSON object mapping a
                claim key to a "."-separated string of object names.

        Returns:
            ``(per_claim, objects)``: ``per_claim`` maps each claim key to its
            list of names, ``objects`` is the "."-joined string of the distinct
            names in first-seen order, without "none" and empty names.

        Raises:
            ValueError: If the response is missing or is not a JSON object.
        """
        if use_attribue:
            # Both prompts come from the task-specific "object" section.
            response = self._run_chat(
                [self._conversation("object", claims=claim_list)]
            )

        try:
            parsed = json.loads(response[0])
        except (TypeError, IndexError, KeyError, ValueError) as err:
            # Fail here with a clear message instead of continuing with an
            # unparsed response and crashing on an unrelated error below.
            raise ValueError(
                f"could not parse the object-extraction response: {response!r}"
            ) from err
        if not isinstance(parsed, dict):
            raise ValueError(
                f"object-extraction response is not a JSON object: {parsed!r}"
            )

        # spaCy-style extraction is not considered for now (it seemed of little
        # use); entities of fact claims have to be ignored -- see filter().
        per_claim = {key: raw.split(".") for key, raw in parsed.items()}

        # A dict keeps the names unique while preserving first-seen order, so
        # the joined string is reproducible between runs.
        distinct = {}
        for names in per_claim.values():
            for name in names:
                if name and name != self._NONE:
                    distinct.setdefault(name, None)
        return per_claim, ".".join(distinct)

    def get_hypernyms(self, word):
        """Return the "."-joined lemma names of all direct hypernyms of ``word``.

        Every WordNet synset of ``word`` is considered; the names are
        de-duplicated and sorted.
        """
        lemma_names = set()
        for synset in wordnet.synsets(word):
            for hypernym in synset.hypernyms():
                lemma_names.update(hypernym.lemma_names())
        return ".".join(sorted(lemma_names))

    def remove_hypernyms(self, objects):
        """Drop every object that is a hypernym of another listed object.

        Args:
            objects: Iterable of object names (a set or a list).

        Returns:
            The remaining names joined with ".", in the iteration order of
            ``objects``.  Empty names are dropped as well.
        """
        names = list(objects)
        # Whole lemma names are compared; a substring test on the joined
        # string would drop e.g. "animal" because of "domestic_animal".
        hypernyms_of = {
            name: {lemma for lemma in self.get_hypernyms(name).split(".") if lemma}
            for name in dict.fromkeys(names)
        }

        kept = []
        for name in names:
            covered = any(
                name in lemmas
                for other, lemmas in hypernyms_of.items()
                if other != name
            )
            if name and not covered:
                kept.append(name)
        return ".".join(kept)

    def filter(self, res, object_list, use_attribue=False):
        """Blank out the object-related queries of claims that need a fact check.

        Args:
            res: Sequence of JSON strings.  ``res[1]`` holds the scene-text
                questions and ``res[2]`` the fact questions; ``res[0]`` holds
                the attribute questions and is read only if ``use_attribue``.
            object_list: Dict mapping each claim key to its list of objects.
                Entries of fact claims are overwritten in place with
                ``["none"]``.
            use_attribue: Also parse, blank and return the attribute questions.

        Returns:
            ``(scenetext, fact, objects)``, preceded by the attribute questions
            when ``use_attribue`` is true.  ``objects`` is the "."-joined string
            of the objects of all non-fact claims after hypernym removal.
        """
        if use_attribue:
            attribute_ques_list = json.loads(res[0])
        scenetext_ques_list = json.loads(res[1])
        fact_ques_list = json.loads(res[2])

        distinct = {}
        for key, fact_questions in fact_ques_list.items():
            if self._has_question(fact_questions):
                # A fact claim is verified by its fact question only: clear
                # its objects (by claim key) and its other questions.
                object_list[key] = [self._NONE]
                if use_attribue:
                    attribute_ques_list[key] = [self._NONE]
                scenetext_ques_list[key] = [self._NONE]
            else:
                for name in object_list[key]:
                    if name != self._NONE:
                        distinct.setdefault(name, None)

        objects = self.remove_hypernyms(list(distinct))
        if use_attribue:
            return attribute_ques_list, scenetext_ques_list, fact_ques_list, objects
        return scenetext_ques_list, fact_ques_list, objects

    def get_response(self, claim_list, type, use_attribute=False):
        """Generate the object string and the question sets for ``claim_list``.

        Args:
            claim_list: Claims substituted into the prompt templates.
            type: Key of the prompt group to use.  For "image-to-text" the
                results are additionally passed through ``filter``.
            use_attribute: Also generate attribute questions.  The objects are
                then extracted by a separate chat call beforehand.

        Returns:
            ``(objects, scenetext, fact)``, or
            ``(objects, attribute, scenetext, fact)`` when ``use_attribute``
            is true.
        """
        self.type = type
        needs_filter = self.type == "image-to-text"

        if use_attribute:
            object_list, objects = self.objects_extract(
                claim_list=claim_list, use_attribue=True
            )
            self.message_list = [
                self._conversation("attribute", objects=objects, claims=claim_list),
                self._conversation("scene-text", claims=claim_list),
                self._conversation("fact", claims=claim_list),
            ]
            res = self._run_chat(self.message_list)
            if needs_filter:
                # The flag must be forwarded so that filter() also returns the
                # attribute questions; otherwise the unpacking below fails.
                attribute_ques_list, scenetext_ques_list, fact_ques_list, objects = (
                    self.filter(res, object_list, use_attribue=True)
                )
            else:
                attribute_ques_list = json.loads(res[0])
                scenetext_ques_list = json.loads(res[1])
                fact_ques_list = json.loads(res[2])
            return objects, attribute_ques_list, scenetext_ques_list, fact_ques_list

        self.message_list = [
            self._conversation("object", claims=claim_list),
            self._conversation("scene-text", claims=claim_list),
            self._conversation("fact", claims=claim_list),
        ]
        res = self._run_chat(self.message_list)
        # Here res[0] is the object-extraction output of the same chat call.
        object_list, objects = self.objects_extract(
            claim_list=claim_list, response=res
        )
        if needs_filter:
            scenetext_ques_list, fact_ques_list, objects = self.filter(res, object_list)
        else:
            scenetext_ques_list = json.loads(res[1])
            fact_ques_list = json.loads(res[2])
        return objects, scenetext_ques_list, fact_ques_list