# EasyDetect/pipeline/tool_execute.py
import os
import json
import yaml
import base64
import shortuuid
from PIL import Image
from tqdm import tqdm
from openai import OpenAI

# The API key is read from the environment rather than hard-coded in the source.
client = OpenAI(
    base_url="https://oneapi.xty.app/v1",
    api_key=os.environ.get("OPENAI_API_KEY", ""),
)
# import sys
# sys.path.append("/home/wcx/wcx/EasyDetect/tool")
from pipeline.tool.object_detetction_model import *
from pipeline.tool.google_serper import *
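# Note: the two wildcard imports above are assumed to provide GroundingDINO and
# GoogleSerperAPIWrapper. Their interfaces, as inferred from the calls below
# (not verified against the tool modules themselves), look roughly like:
#   GroundingDINO.execute(image_path, content, new_path, use_text_rec) -> detection summary string
#   GoogleSerperAPIWrapper.execute(input, content) -> list of evidence strings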
def get_openai_reply(image_path, text):
    """Send an image plus a text prompt to the vision model and return its reply."""

    def encode_image(image_path):
        # Base64-encode the image so it can be embedded directly in the request payload.
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    img = encode_image(image_path)
    content = [
        {"type": "text", "text": text},
        # The chat completions vision API expects the image part as an object
        # with a "url" field; a data URL carries the base64-encoded image.
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}},
    ]
    messages = [
        {
            "role": "user",
            "content": content,
        }
    ]
    resp = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=messages,
        max_tokens=1024,
    )
    return resp.choices[0].message.content
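# Example usage (a minimal sketch; "example.jpg" and the question text are
# placeholders, and OPENAI_API_KEY must be set in the environment):
#
#   reply = get_openai_reply(
#       image_path="example.jpg",
#       text="1. Is the cat black?\n2. Is the table made of wood?\n",
#   )
#   print(reply)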
class Tool:
    def __init__(self):
        # Open-vocabulary detector used for object grounding and scene-text boxes.
        self.detector = GroundingDINO()
        # Web-search wrapper used to collect evidence for factual claims.
        self.search = GoogleSerperAPIWrapper()
    def execute(self, image_path, new_path, objects, attribute_list, scenetext_list, fact_list):
        # An entry of "none" means there is nothing to verify for that key, so the
        # corresponding step is skipped entirely.
        use_text_rec = False
        use_attribute = False
        for key in scenetext_list:
            if scenetext_list[key][0] != "none":
                use_text_rec = True
        for key in attribute_list:
            if attribute_list[key][0] != "none":
                use_attribute = True

        # Scene-text recognition, only when some claim refers to text in the image.
        text_res = None
        if use_text_rec:
            text_res = self.detector.execute(image_path=image_path, content="word.number", new_path=new_path, use_text_rec=True)

        # Object grounding for every extracted object.
        object_res = self.detector.execute(image_path=image_path, content=objects, new_path=new_path, use_text_rec=False)

        # Attribute questions are numbered, batched into one prompt, and answered by the vision model.
        queries = ""
        if use_attribute:
            cnt = 1
            for key in attribute_list:
                if attribute_list[key][0] != "none":
                    for query in attribute_list[key]:
                        queries += str(cnt) + "." + query + "\n"
                        cnt += 1
        if queries == "":
            attribute_res = "none information"
        else:
            attribute_res = get_openai_reply(image_path, queries)

        # Factual claims are checked against web-search evidence.
        fact_res = ""
        cnt = 1
        for key in fact_list:
            if fact_list[key][0] != "none":
                evidences = self.search.execute(input="", content=str(fact_list[key]))
                for evidence in evidences:
                    fact_res += str(cnt) + "." + evidence + "\n"
                    cnt += 1
        if fact_res == "":
            fact_res = "none information"

        return object_res, attribute_res, text_res, fact_res
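# Example usage (a minimal sketch; the *_list arguments are assumed to be dicts
# mapping a key to a list of question/claim strings, with ["none"] meaning
# "nothing to check" -- the paths and values below are placeholders):
#
#   tool = Tool()
#   object_res, attribute_res, text_res, fact_res = tool.execute(
#       image_path="example.jpg",
#       new_path="./annotated/",
#       objects="cat.table",
#       attribute_list={"cat": ["Is the cat black?"]},
#       scenetext_list={"sign": ["none"]},
#       fact_list={"claim1": ["Cats are mammals."]},
#   )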
# if __name__ == '__main__':
# tool = Tool()
# extractor = Extractor(model="gpt-4-1106-preview", config_path= "/home/wcx/wcx/GroundingDINO/LVLM/prompt-v2-multi-claim/object_extract.yaml", type="image-to-text")
# # "/home/wcx/wcx/LVLMHall-test/text-to-image/labeled.json"
# query = Query(config_path="/home/wcx/wcx/GroundingDINO/LVLM/prompt-v2-multi-claim/query.yaml",type="image-to-text")
# path = "/home/wcx/wcx/LVLMHall-test/MSCOCO/caption/labeled/minigpt4-100-cx-revise-v1.json"
# with open(path, "r", encoding="utf-8") as f:
# for idx, line in tqdm(enumerate(f.readlines()), total=250):
# # if idx < 189:
# # continue
# data = data2
# #data = json.loads(line)
# image_path = data["image_path"]#"/newdisk3/wcx" + data["image_path"]
# claim_list = ""
# cnt = 1
# for seg in data["segments"]:
# for cla in seg["claims"]:
# claim_list += "claim" + str(cnt) + ": " + cla["claim"] + "\n"
# cnt += 1
# object_list, objects = extractor.get_response(claims=claim_list)
# print("pre:" + objects)
# attribute_list, scenetext_list, fact_list, objects = query.get_response(claim_list, objects, object_list)
# print("after:" + objects)
# print(object_list)
# print(attribute_list)
# print(scenetext_list)
# print(fact_list)
# object_res, attribue_res, text_res, fact_res = tool.execute(image_path=image_path,
# new_path="/newdisk3/wcx/MLLM/image-to-text/minigpt4/",
# attribute_list=attribute_list,
# scenetext_list=scenetext_list,
# fact_list=fact_list,
# objects=objects)
# # print(object_res)
# # print(attribue_res)
# # print(text_res)
# #print(fact_res[:50])
# print("=============================")
# break