Spaces:
Sleeping
Sleeping
import yaml | |
import json | |
import shortuuid | |
import base64 | |
from PIL import Image | |
import os | |
from tqdm import tqdm | |
from PIL import Image | |
from openai import OpenAI | |
# Shared OpenAI-compatible client used by get_openai_reply.
# Credentials come from the environment instead of being committed to source;
# OPENAI_BASE_URL may override the proxy endpoint (defaults to the original one).
# NOTE(review): the key previously hard-coded on this line is exposed in the
# repository history and should be revoked.
client = OpenAI(
    base_url=os.environ.get("OPENAI_BASE_URL", "https://oneapi.xty.app/v1"),
    api_key=os.environ.get("OPENAI_API_KEY"),
)
# import sys | |
# sys.path.append("/home/wcx/wcx/EasyDetect/tool") | |
from pipeline.tool.object_detetction_model import * | |
from pipeline.tool.google_serper import * | |
def get_openai_reply(image_path, text, model="gpt-4-vision-preview", max_tokens=1024):
    """Ask a vision-capable chat model about a local image and return its answer.

    The image file is read from disk, base64-encoded and sent inline as a
    ``data:`` URL together with ``text`` in a single user message, using the
    module-level ``client``.

    Args:
        image_path: Path of the image file to send.
        text: Prompt text accompanying the image.
        model: Chat-completions model name (defaults to the original
            ``"gpt-4-vision-preview"``).
        max_tokens: Maximum number of tokens in the reply.

    Returns:
        The ``content`` of the first choice's message in the response.
    """
    import mimetypes

    with open(image_path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode("utf-8")

    # Label the data URL with the file's real image type instead of always
    # claiming JPEG; fall back to JPEG when the extension is unknown/non-image.
    mime_type = mimetypes.guess_type(image_path)[0]
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    content = [
        {"type": "text", "text": text},
        # The Chat Completions API expects image_url to be an object with a
        # "url" key, not a bare string.
        {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{encoded}"}},
    ]
    resp = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": content}],
        max_tokens=max_tokens,
    )
    return resp.choices[0].message.content
class Tool:
    """Gather evidence for claim-derived queries about an image.

    Holds a ``GroundingDINO`` detector, used both for object grounding and,
    with ``use_text_rec=True``, for a text-recognition pass. It also holds a
    ``GoogleSerperAPIWrapper`` search client for fact lookups.
    """

    def __init__(self):
        self.detector = GroundingDINO()
        self.search = GoogleSerperAPIWrapper()

    @staticmethod
    def _is_active(values):
        """Return True if a query entry is non-empty and not the "none" placeholder."""
        return bool(values) and values[0] != "none"

    @staticmethod
    def _numbered(items):
        """Render items as consecutive "<n>.<item>" lines starting at 1 ("" if none)."""
        return "".join(f"{idx}.{item}\n" for idx, item in enumerate(items, start=1))

    def execute(self, image_path, new_path, objects, attribute_list, scenetext_list,
                fact_list, *, use_attribute=False):
        """Run every applicable tool and return their raw results.

        ``attribute_list``, ``scenetext_list`` and ``fact_list`` are mappings
        whose values are lists of query strings. An entry whose first element
        is ``"none"``, or which is empty, means "nothing to check" and is
        skipped.

        Args:
            image_path: Path of the image under inspection.
            new_path: Forwarded unchanged to ``self.detector.execute``
                (presumably an output location for annotated images - TODO confirm).
            objects: Detection prompt forwarded as ``content`` to the detector.
            attribute_list: Attribute questions for the vision LLM.
            scenetext_list: Scene-text queries; any active entry triggers one
                text-recognition pass over the image.
            fact_list: Fact queries; each active entry is sent to web search.
            use_attribute: When True, active attribute questions are sent to
                ``get_openai_reply``. Defaults to False (step skipped).

        Returns:
            ``(object_res, attribute_res, text_res, fact_res)``:

            - ``object_res``: the detector's object result.
            - ``attribute_res``: the LLM answer, or ``"none information"``.
            - ``text_res``: the text-recognition result, or ``None`` when no
              scene-text query is active.
            - ``fact_res``: the numbered search evidence, or
              ``"none information"``.
        """
        # Text recognition only runs if at least one scene-text query is active;
        # it runs before object detection, as in the original call order.
        text_res = None
        if any(self._is_active(values) for values in scenetext_list.values()):
            text_res = self.detector.execute(image_path=image_path, content="word.number",
                                             new_path=new_path, use_text_rec=True)
        object_res = self.detector.execute(image_path=image_path, content=objects,
                                           new_path=new_path, use_text_rec=False)

        attribute_res = "none information"
        if use_attribute:
            # Number all active questions consecutively across keys.
            queries = self._numbered(
                query
                for values in attribute_list.values()
                if self._is_active(values)
                for query in values
            )
            if queries:
                attribute_res = get_openai_reply(image_path, queries)

        evidence = []
        for values in fact_list.values():
            if self._is_active(values):
                evidence.extend(self.search.execute(input="", content=str(values)))
        fact_res = self._numbered(evidence) or "none information"

        return object_res, attribute_res, text_res, fact_res
# if __name__ == '__main__': | |
# tool = Tool() | |
# extractor = Extractor(model="gpt-4-1106-preview", config_path= "/home/wcx/wcx/GroundingDINO/LVLM/prompt-v2-multi-claim/object_extract.yaml", type="image-to-text") | |
# # "/home/wcx/wcx/LVLMHall-test/text-to-image/labeled.json" | |
# query = Query(config_path="/home/wcx/wcx/GroundingDINO/LVLM/prompt-v2-multi-claim/query.yaml",type="image-to-text") | |
# path = "/home/wcx/wcx/LVLMHall-test/MSCOCO/caption/labeled/minigpt4-100-cx-revise-v1.json" | |
# with open(path, "r", encoding="utf-8") as f: | |
# for idx, line in tqdm(enumerate(f.readlines()), total=250): | |
# # if idx < 189: | |
# # continue | |
# data = data2 | |
# #data = json.loads(line) | |
# image_path = data["image_path"]#"/newdisk3/wcx" + data["image_path"] | |
# claim_list = "" | |
# cnt = 1 | |
# for seg in data["segments"]: | |
# for cla in seg["claims"]: | |
# claim_list += "claim" + str(cnt) + ": " + cla["claim"] + "\n" | |
# cnt += 1 | |
# object_list, objects = extractor.get_response(claims=claim_list) | |
# print("pre:" + objects) | |
# attribute_list, scenetext_list, fact_list, objects = query.get_response(claim_list, objects, object_list) | |
# print("after:" + objects) | |
# print(object_list) | |
# print(attribute_list) | |
# print(scenetext_list) | |
# print(fact_list) | |
# object_res, attribue_res, text_res, fact_res = tool.execute(image_path=image_path, | |
# new_path="/newdisk3/wcx/MLLM/image-to-text/minigpt4/", | |
# attribute_list=attribute_list, | |
# scenetext_list=scenetext_list, | |
# fact_list=fact_list, | |
# objects=objects) | |
# # print(object_res) | |
# # print(attribue_res) | |
# # print(text_res) | |
# #print(fact_res[:50]) | |
# print("=============================") | |
# break | |