EasyDetect / pipeline /query_generate.py
sunnychenxiwang's picture
update all
24c4def
raw
history blame
4.44 kB
import json
import yaml
import copy
import asyncio
from nltk.corpus import wordnet
class QueryGenerator:
    """Generate verification questions for a list of claims with a chat model.

    Prompt templates come from one section of a YAML file, selected by
    ``type``. For ``type == "image-to-text"`` the attribute and scene-text
    questions of claims that received fact questions are suppressed, and the
    collected object names are pruned of WordNet hypernyms.
    """

    def __init__(self, prompt_path, chat, type):
        """Load the prompt templates and keep the chat client.

        Args:
            prompt_path: Path to a UTF-8 YAML file. Its top-level section
                ``type`` is read; the methods below use its ``object``,
                ``attribute``, ``scene-text`` and ``fact`` entries, each with
                ``system`` and ``user`` templates.
            chat: Client whose ``get_response(messages=...)`` returns an
                awaitable. Presumably it resolves to one response string per
                message list (inferred from how results are indexed here) --
                confirm against the chat implementation.
            type: Pipeline type (e.g. ``"image-to-text"``). It shadows the
                builtin ``type``; the name is kept so keyword callers work.
        """
        self.type = type
        with open(prompt_path, "r", encoding="utf-8") as file:
            # The prompts are only read as nested mappings of strings, so the
            # safe loader suffices and cannot construct arbitrary objects.
            self.prompt = yaml.safe_load(file)[type]
        self.chat = chat

    @staticmethod
    def _run_sync(coro):
        """Run *coro* to completion on a fresh event loop and return its result."""
        # The fresh loop is installed as the thread's current loop and is
        # deliberately left open. NOTE(review): closing it (or switching to
        # asyncio.run) would release its resources, but self.chat may hold
        # async clients bound to a loop -- confirm before changing.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop.run_until_complete(coro)

    @staticmethod
    def _split_objects(value):
        """Turn one model answer into a list of stripped, non-empty object names.

        A string is split on "."; a list/tuple is taken element-wise; any other
        value yields no names. ``"none"`` entries are kept.
        """
        if isinstance(value, str):
            parts = value.split(".")
        elif isinstance(value, (list, tuple)):
            parts = [str(part) for part in value]
        else:
            parts = []
        names = (part.strip() for part in parts)
        return [name for name in names if name]

    @staticmethod
    def _normalize_object(name):
        """Normalize a name for hypernym comparison (strip, lowercase, "_" -> " ")."""
        return name.strip().lower().replace("_", " ")

    @staticmethod
    def _hypernym_lemmas(word):
        """Return the lemma names of the direct hypernyms of every WordNet synset of *word*."""
        lemmas = set()
        for synset in wordnet.synsets(word):
            for hypernym in synset.hypernyms():
                lemmas.update(hypernym.lemma_names())
        return lemmas

    def objects_extract(self, claim_list):
        """Ask the chat model which objects each claim mentions.

        Args:
            claim_list: Claims, interpolated into the ``object`` user template
                via ``str.format(claims=...)``.

        Returns:
            ``(per_claim, objects)``. ``per_claim`` maps each key of the model's
            JSON answer to the object names parsed from its value (see
            ``_split_objects``; ``"none"`` entries are kept). ``objects`` joins
            every name other than ``"none"`` with ".", deduplicated in
            first-seen order. If the answer is not a JSON object, the problem
            is printed and ``({}, "")`` is returned.
        """
        user_prompt = self.prompt["object"]["user"].format(claims=claim_list)
        message = [[
            {"role": "system", "content": self.prompt["object"]["system"]},
            {"role": "user", "content": user_prompt},
        ]]
        response = self._run_sync(self.chat.get_response(messages=message))
        try:
            parsed = json.loads(response[0])
        except (json.JSONDecodeError, TypeError, IndexError) as err:
            # Best effort: continue with no objects rather than crashing later
            # while iterating a response that is not a mapping.
            print(err)
            parsed = {}
        if not isinstance(parsed, dict):
            print(f"objects_extract: expected a JSON object, got {parsed.__class__.__name__}")
            parsed = {}
        # spaCy-style entity extraction is deliberately not used here. For
        # "image-to-text", objects of fact claims are discarded in filter().
        per_claim = {key: self._split_objects(value) for key, value in parsed.items()}
        unique = dict.fromkeys(
            name for names in per_claim.values() for name in names if name != "none"
        )
        return per_claim, ".".join(unique)

    def get_hypernyms(self, word):
        """Return the sorted, "."-joined lemma names of the direct WordNet hypernyms of *word*."""
        return ".".join(sorted(self._hypernym_lemmas(word)))

    def remove_hypernyms(self, objects):
        """Drop object names that are WordNet hypernyms of another object.

        A name is removed when its normalized form (stripped, lowercased,
        "_" -> " ") equals a normalized lemma of a direct hypernym of some
        *other* name -- e.g. "canine" is dropped when "dog" is present. Whole
        lemma names are compared, not substrings.

        Args:
            objects: Iterable of object names. Surrounding whitespace is
                stripped; blank names and duplicates are ignored.

        Returns:
            The surviving names joined with ".", in first-seen order.
        """
        stripped = (obj.strip() for obj in objects)
        candidates = list(dict.fromkeys(name for name in stripped if name))
        if len(candidates) < 2:
            # A lone name has no other name to be a hypernym of.
            return ".".join(candidates)
        hypernyms_of = {
            # WordNet spells multi-word lemmas with underscores.
            name: {
                self._normalize_object(lemma)
                for lemma in self._hypernym_lemmas(name.replace(" ", "_"))
            }
            for name in candidates
        }
        kept = []
        for name in candidates:
            normalized = self._normalize_object(name)
            is_hypernym = any(
                normalized in hypernyms
                for other, hypernyms in hypernyms_of.items()
                if other != name
            )
            if not is_hypernym:
                kept.append(name)
        return ".".join(kept)

    def filter(self, res, object_list):
        """Suppress non-fact questions for fact claims and collect the remaining objects.

        Args:
            res: Raw model outputs; ``res[0]``, ``res[1]`` and ``res[2]`` are JSON
                objects mapping claim keys to attribute, scene-text and fact
                question lists respectively.
            object_list: Per-claim object lists (as from ``objects_extract``).
                Modified in place: entries of fact claims become ``["none"]``.

        Returns:
            ``(attribute_ques_list, scenetext_ques_list, fact_ques_list, objects)``
            where ``objects`` is the hypernym-pruned, "."-joined set of objects
            of the non-fact claims.
        """
        attribute_ques_list = json.loads(res[0])
        scenetext_ques_list = json.loads(res[1])
        fact_ques_list = json.loads(res[2])
        collected = {}  # insertion-ordered set of object names
        for key, fact_questions in fact_ques_list.items():
            if fact_questions and fact_questions[0] != "none":
                # A fact claim is checked via its fact questions only, so its
                # objects and attribute/scene-text questions are reset.
                object_list[key] = ["none"]
                attribute_ques_list[key] = ["none"]
                scenetext_ques_list[key] = ["none"]
            else:
                for obj in object_list.get(key, []):
                    if obj != "none":
                        collected[obj] = None
        objects = self.remove_hypernyms(collected)
        return attribute_ques_list, scenetext_ques_list, fact_ques_list, objects

    def get_response(self, claim_list):
        """Generate attribute, scene-text and fact questions for *claim_list*.

        The prompt batch is also stored on ``self.message_list``.

        Returns:
            ``(objects, attribute_ques_list, scenetext_ques_list, fact_ques_list)``.
            For ``"image-to-text"`` the lists are post-processed by ``filter``
            and ``objects`` is its pruned object string; otherwise the parsed
            answers are returned as-is with the objects from ``objects_extract``.
        """
        object_list, objects = self.objects_extract(claim_list=claim_list)
        self.message_list = [
            [{"role": "system", "content": self.prompt["attribute"]["system"]}, {"role": "user", "content": self.prompt["attribute"]["user"].format(objects=objects, claims=claim_list)}],
            [{"role": "system", "content": self.prompt["scene-text"]["system"]}, {"role": "user", "content": self.prompt["scene-text"]["user"].format(claims=claim_list)}],
            [{"role": "system", "content": self.prompt["fact"]["system"]}, {"role": "user", "content": self.prompt["fact"]["user"].format(claims=claim_list)}],
        ]
        res = self._run_sync(self.chat.get_response(messages=self.message_list))
        if self.type == "image-to-text":
            attribute_ques_list, scenetext_ques_list, fact_ques_list, objects = self.filter(res, object_list)
        else:
            # Without this branch the question lists would be unbound for any
            # other type. NOTE(review): returns the unfiltered answers --
            # confirm this is what non-image-to-text configs expect.
            attribute_ques_list, scenetext_ques_list, fact_ques_list = (
                json.loads(answer) for answer in res[:3]
            )
        return objects, attribute_ques_list, scenetext_ques_list, fact_ques_list