Spaces:
Sleeping
Sleeping
File size: 6,202 Bytes
24c4def 55d9644 24c4def 55d9644 24c4def 55d9644 24c4def 55d9644 24c4def 55d9644 24c4def 55d9644 24c4def |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import json
import yaml
import copy
import asyncio
from nltk.corpus import wordnet
class QueryGenerator:
    """Generate verification questions for claims by prompting a chat model.

    Prompt templates are read from a YAML file and looked up as
    ``prompt[type][section]["system" | "user"]``; the sections used here are
    ``object``, ``attribute``, ``scene-text`` and ``fact``.
    """

    def __init__(self, prompt_path, chat, type=None):
        """Load the prompt templates and remember the chat backend.

        Args:
            prompt_path: Path of the YAML file holding the prompt templates.
            chat: Object exposing an awaitable ``get_response(messages=...)``
                that returns one reply string per conversation.
            type: Optional task type (top-level key of the prompt file).
                ``get_response`` overwrites it on every call.
        """
        # None until a task type is supplied; never the ``type`` builtin.
        self.type = type
        with open(prompt_path, "r", encoding="utf-8") as file:
            # NOTE(review): if prompt files can come from an untrusted source,
            # prefer yaml.safe_load here.
            self.prompt = yaml.load(file, yaml.FullLoader)
        self.chat = chat

    def _run_chat(self, messages):
        """Run ``self.chat.get_response(messages=...)`` on a fresh event loop.

        The new loop is installed as the current loop and is left open after
        the call; it is not closed here.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop.run_until_complete(self.chat.get_response(messages=messages))

    def _conversation(self, section, **fields):
        """Build the system/user message pair for one prompt section.

        ``fields`` are substituted into the section's user template.
        """
        templates = self.prompt[self.type][section]
        return [
            {"role": "system", "content": templates["system"]},
            {"role": "user", "content": templates["user"].format(**fields)},
        ]

    def _object_system_prompt(self):
        """Return the system prompt for object extraction.

        A top-level ``object`` section wins when the prompt file has one;
        otherwise the section scoped under the current task type is used,
        which is where every other prompt lookup in this class reads from.
        """
        shared = self.prompt.get("object")
        if isinstance(shared, dict) and "system" in shared:
            return shared["system"]
        return self.prompt[self.type]["object"]["system"]

    def objects_extract(self, claim_list, use_attribue=False, response=None):
        """Parse (and optionally request) the per-claim object lists.

        Args:
            claim_list: Claims substituted into the object prompt.
            use_attribue: When true, ask the chat model for the objects;
                otherwise ``response`` must already hold the reply.
            response: Chat replies whose first element is a JSON object
                mapping each claim key to a "."-separated object string.
                An already-parsed dict is accepted as well.

        Returns:
            ``(per_claim, objects)`` where ``per_claim`` maps each claim key
            to its list of object names and ``objects`` is the "."-joined
            string of distinct names (excluding ``"none"``) in first-seen
            order.

        Raises:
            ValueError: If the reply is missing or is not a JSON object.
        """
        if use_attribue:
            user_prompt = self.prompt[self.type]["object"]["user"].format(claims=claim_list)
            message = [[
                {"role": "system", "content": self._object_system_prompt()},
                {"role": "user", "content": user_prompt},
            ]]
            response = self._run_chat(message)

        if isinstance(response, dict):
            per_claim = response
        else:
            try:
                per_claim = json.loads(response[0])
            except (TypeError, IndexError, ValueError) as err:
                raise ValueError(
                    "object extraction reply is missing or is not valid JSON"
                ) from err
            if not isinstance(per_claim, dict):
                raise ValueError("object extraction reply must be a JSON object")

        # The spaCy-style extraction is not considered for now (it did not
        # seem useful); entities of fact claims are dropped later in filter().
        distinct = {}  # dict used as an insertion-ordered set
        for key, joined_names in per_claim.items():
            names = [name.strip() for name in joined_names.split(".")]
            per_claim[key] = names
            for name in names:
                if name and name != "none":
                    distinct[name] = None
        return per_claim, ".".join(distinct)

    def get_hypernyms(self, word: str) -> str:
        """Return the hypernym lemma names of *word*, joined with ".".

        Collects the lemma names of the direct hypernyms of every WordNet
        synset of *word*; names are de-duplicated and sorted.
        """
        lemmas = {
            lemma
            for synset in wordnet.synsets(word)
            for hypernym in synset.hypernyms()
            for lemma in hypernym.lemma_names()
        }
        return ".".join(sorted(lemmas))

    def remove_hypernyms(self, objects) -> str:
        """Drop every object that is a hypernym of another object.

        Args:
            objects: Iterable of object names.

        Returns:
            The remaining names joined with ".", in iteration order.
        """
        candidates = list(objects)
        # Exact lemma sets per object: matching against the joined string
        # would treat "car" as a hypernym because of "carnivore".
        hypernyms_by_object = {
            name: {lemma for lemma in self.get_hypernyms(name).split(".") if lemma}
            for name in dict.fromkeys(candidates)
        }
        kept = []
        for name in candidates:
            is_hypernym_of_other = any(
                name in lemmas
                for other, lemmas in hypernyms_by_object.items()
                if other != name
            )
            if not is_hypernym_of_other:
                kept.append(name)
        return ".".join(kept)

    def filter(self, res, object_list, use_attribue=False):
        """Blank out visual questions for fact claims and collect objects.

        Args:
            res: Chat replies; ``res[1]`` and ``res[2]`` are the JSON
                scene-text and fact question maps, and ``res[0]`` is the
                attribute question map when ``use_attribue`` is true.
            object_list: Mapping of claim key to its list of object names.
                Entries of claims that have a fact question are overwritten
                with ``["none"]`` in place.
            use_attribue: Whether attribute questions are present in ``res``.

        Returns:
            ``(attribute, scene_text, fact, objects)`` when ``use_attribue``
            is true, otherwise ``(scene_text, fact, objects)``; ``objects``
            is the "."-joined string produced by ``remove_hypernyms``.
        """
        attribute_ques_list = json.loads(res[0]) if use_attribue else None
        scenetext_ques_list = json.loads(res[1])
        fact_ques_list = json.loads(res[2])

        visual_objects = {}  # dict used as an insertion-ordered set
        for key, fact_questions in fact_ques_list.items():
            if fact_questions and fact_questions[0] != "none":
                # A claim checked as a fact gets no object/visual questions.
                object_list[key] = ["none"]
                if use_attribue:
                    attribute_ques_list[key] = ["none"]
                scenetext_ques_list[key] = ["none"]
            else:
                for name in object_list[key]:
                    if name != "none":
                        visual_objects[name] = None

        objects = self.remove_hypernyms(list(visual_objects))
        if use_attribue:
            return attribute_ques_list, scenetext_ques_list, fact_ques_list, objects
        return scenetext_ques_list, fact_ques_list, objects

    def get_response(self, claim_list, type, use_attribute=False):
        """Ask the chat model for the questions belonging to ``claim_list``.

        Args:
            claim_list: Claims substituted into every prompt.
            type: Task type; selects the prompt group and is stored on
                ``self.type``. For ``"image-to-text"`` the replies are
                post-processed by ``filter``.
            use_attribute: Also generate attribute questions. Objects are
                then extracted with a separate chat call first, because the
                attribute prompt needs them.

        Returns:
            ``(objects, attribute, scene_text, fact)`` when ``use_attribute``
            is true, otherwise ``(objects, scene_text, fact)``.
        """
        self.type = type
        needs_filter = self.type == "image-to-text"

        if use_attribute:
            object_list, objects = self.objects_extract(claim_list=claim_list, use_attribue=True)
            self.message_list = [
                self._conversation("attribute", objects=objects, claims=claim_list),
                self._conversation("scene-text", claims=claim_list),
                self._conversation("fact", claims=claim_list),
            ]
            res = self._run_chat(self.message_list)
            if needs_filter:
                # use_attribue must be forwarded so that filter() returns the
                # four values unpacked here.
                attribute_ques_list, scenetext_ques_list, fact_ques_list, objects = self.filter(
                    res, object_list, use_attribue=True
                )
            else:
                attribute_ques_list = json.loads(res[0])
                scenetext_ques_list = json.loads(res[1])
                fact_ques_list = json.loads(res[2])
            return objects, attribute_ques_list, scenetext_ques_list, fact_ques_list

        self.message_list = [
            self._conversation("object", claims=claim_list),
            self._conversation("scene-text", claims=claim_list),
            self._conversation("fact", claims=claim_list),
        ]
        res = self._run_chat(self.message_list)
        object_list, objects = self.objects_extract(claim_list=claim_list, response=res)
        if needs_filter:
            scenetext_ques_list, fact_ques_list, objects = self.filter(res, object_list)
        else:
            scenetext_ques_list = json.loads(res[1])
            fact_ques_list = json.loads(res[2])
        return objects, scenetext_ques_list, fact_ques_list
|