|
def data2prompt(caption_response, ref_words, max_ref_words=5):
    """Build the LLM prompt that extracts the major object from a caption.

    The prompt consists of a one-line task statement, the caption text under a
    ``# Caption Response:`` header, and step-by-step instructions asking the
    model to emit a JSON object with the fields major_object,
    better_major_object, echo_1, description, major_object_chinese, echo_2,
    synonym and recheck.

    Args:
        caption_response: Caption text produced for an image.
        ref_words: Candidate reference words offered to the model as possible
            synonyms. ``None`` is treated as an empty list.
        max_ref_words: How many leading entries of ``ref_words`` to include
            (default 5, the previously hard-coded limit).

    Returns:
        The complete prompt string.
    """
    ref_word_str = ",".join((ref_words or [])[:max_ref_words])

    # Ends with a newline so the task statement is not glued onto the
    # "# Caption Response:" header that follows.
    task_prompt = (
        "Based on the following Caption Response, "
        "you will output a description of the Major Object's name.\n"
    )

    input_str = "# Caption Response:\n" + caption_response + "\n"

    # Sentinel values are quoted WITHOUT a trailing period inside the quotes:
    # the caller compares them literally ("NOT_INCLUDED", "ACCURATE").
    cot_lines = [
        "",
        "Let's think step by step. Output each field in JSON format. Include the following fields:",
        "- major_object: From the caption response, identify the major_object. If not present, extract it again from the detailed_description or caption_response.",
        "- better_major_object: Reread the description in the caption response to see if there's a more suitable word for the major object. If not, still output major_object.",
        '- echo_1: "I will generate a simple description in about 200 words in English for the better_major_object, introducing what the input object is."',
        "- description: Generate a WIKI description for the better_major_object (explain what is the better_major_object).",
        "- major_object_chinese: Translate the better_major_object into Chinese.",
        f"- echo_2: \"I will check whether there is a synonym of the major_object_chinese in the '{ref_word_str}'.\"",
        '- synonym: If present, output the synonym directly; otherwise, output "NOT_INCLUDED".',
        '- recheck: Based on the content of the Caption Response, determine whether the synonym is accurate. If accurate, output "ACCURATE"; otherwise, output "NOT_ACCURATE".',
        "",
    ]
    cot_prompt = "\n".join(cot_lines)

    return task_prompt + input_str + cot_prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Prefer the package-qualified import and fall back to a flat import when the
# module is run from inside the source directory (presumably ``src/`` — confirm).
# Only ImportError triggers the fallback, so a genuine error raised while
# loading src.ZhipuClient is no longer silently masked.
try:
    from src.ZhipuClient import ZhipuClient
except ImportError:
    from ZhipuClient import ZhipuClient

# Shared client instance; get_major_object() creates it on first use.
zhipu_client = None
|
|
|
import json |
|
|
|
def markdown_to_json(markdown_str):
    """Decode JSON from an LLM reply, optionally wrapped in a Markdown code fence.

    Accepts ```json ... ``` fences, bare ``` ... ``` fences and unfenced JSON.
    Whitespace around the reply is ignored, including a trailing newline after
    the closing fence. A missing closing fence is also tolerated.

    Args:
        markdown_str: Raw reply text.

    Returns:
        The decoded JSON value (a dict when the model followed the prompt).

    Raises:
        json.JSONDecodeError: If the text left after removing the fence is not
            valid JSON (a subclass of ValueError).
    """
    text = markdown_str.strip()

    for opener in ("```json", "```"):
        if text.startswith(opener):
            text = text[len(opener):].rstrip()
            # Only strip a closing fence that is actually present; blind
            # [:-3] slicing used to cut into the JSON itself.
            if text.endswith("```"):
                text = text[:-3]
            break

    return json.loads(text.strip())
|
|
|
import re |
|
|
|
def forced_extract(input_str, keywords):
    """Extract string fields from JSON-like text by regex, as a best effort.

    Used when the reply is not parseable JSON. For each key, the first
    occurrence of ``"key": "value"`` is located. Whitespace is allowed around
    the colon. Backslash-escaped characters such as ``\\"`` inside the value
    are kept as part of the value, and JSON escapes are decoded when possible.

    Args:
        input_str: The raw (possibly malformed) reply text.
        keywords: Field names to look for.

    Returns:
        A dict mapping every key in ``keywords`` to its extracted string, or
        to ``""`` when the key was not found.
    """
    result = {key: "" for key in keywords}

    for key in keywords:
        # Raw f-string: avoids the invalid "\s" escape warning. The value group
        # consumes escaped characters so an embedded \" does not end the match.
        pattern = rf'"{re.escape(key)}"\s*:\s*"((?:[^"\\]|\\.)*)"'
        match = re.search(pattern, input_str)
        if match is None:
            continue

        raw_value = match.group(1)
        try:
            # Decode JSON escapes (\", \n, \uXXXX). strict=False tolerates raw
            # control characters that LLMs sometimes emit inside strings.
            result[key] = json.loads(f'"{raw_value}"', strict=False)
        except ValueError:
            # Malformed escape sequence: keep the text exactly as matched.
            result[key] = raw_value

    return result
|
|
|
def get_major_object(caption_response, ref_words):
    """Ask the LLM for the major object in a caption and return its fields.

    Args:
        caption_response: Caption text produced for an image.
        ref_words: Candidate reference words passed to data2prompt().

    Returns:
        A dict of the fields requested by the prompt. When the reply is not a
        JSON object, the dict comes from forced_extract(): it contains the keys
        major_object, better_major_object, description, major_object_chinese,
        synonym and recheck, with ``""`` for any field that was not found.
    """
    global zhipu_client
    if zhipu_client is None:
        # Created on first call rather than at import time.
        zhipu_client = ZhipuClient()

    prompt = data2prompt(caption_response, ref_words)
    response = zhipu_client.prompt2response(prompt)

    try:
        json_response = markdown_to_json(response)
    except ValueError:
        # json.JSONDecodeError subclasses ValueError: the reply is not clean JSON.
        json_response = None

    if not isinstance(json_response, dict):
        # Unparseable, or valid JSON that isn't an object (e.g. a list):
        # salvage the individual fields with the regex fallback.
        keyword_list = [
            "major_object",
            "better_major_object",
            "description",
            "major_object_chinese",
            "synonym",
            "recheck",
        ]
        json_response = forced_extract(response, keyword_list)

    return json_response
|
|
|
def verify_keyword_in_base(json_response, database):
    """Look up the LLM-extracted object names in ``database``.

    Candidate English keywords are tried in this order: better_major_object,
    then major_object, then synonym. The synonym is tried only when recheck is
    "ACCURATE" and synonym is not "NOT_INCLUDED". Candidates are lowercased,
    stripped and de-duplicated. Empty or non-string fields are ignored.
    Sentinel comparisons tolerate case differences and a trailing period.

    Args:
        json_response: Dict of fields as returned by get_major_object().
        database: Object exposing ``search_by_en_keyword(keyword)``, which
            returns a record or ``None`` when there is no match.

    Returns:
        ``(record, None)`` for the first candidate that matches.
        ``(None, data)`` when no candidate matches; ``data`` is a dict with
        keys "keyword" (the Chinese name, or the first candidate),
        "translated_word" (the first candidate) and "description".
        ``(None, None)`` when there is no usable candidate at all.
    """
    if not isinstance(json_response, dict):
        return None, None

    def _text(field):
        # Missing, null or non-string fields all count as empty.
        value = json_response.get(field)
        return value.strip() if isinstance(value, str) else ""

    def _sentinel(field):
        # The model may echo sentinels as "NOT_INCLUDED." or in another case.
        return _text(field).rstrip(".").strip().upper()

    candidates = [_text("better_major_object"), _text("major_object")]
    if _sentinel("recheck") == "ACCURATE" and _sentinel("synonym") != "NOT_INCLUDED":
        candidates.append(_text("synonym"))

    keyword2verify = []
    for candidate in candidates:
        candidate = candidate.lower()
        if candidate and candidate not in keyword2verify:
            keyword2verify.append(candidate)

    for keyword in keyword2verify:
        record = database.search_by_en_keyword(keyword)
        if record is not None:
            return record, None

    if not keyword2verify:
        return None, None

    # Nothing in the database: hand back enough to create a new entry.
    translated_word = keyword2verify[0]
    data = {
        "keyword": _text("major_object_chinese") or translated_word,
        "translated_word": translated_word,
        "description": _text("description") or translated_word,
    }
    return None, data
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: caption one image, ask the LLM for its major object,
    # then look that object up in the database.
    import os
    import sys

    # Prefer package-qualified imports; fall back to flat imports when run from
    # inside the source directory. Only ImportError triggers the fallback.
    try:
        from src.Database import Database
    except ImportError:
        from Database import Database

    db = Database()

    try:
        from src.Captioner import Captioner
    except ImportError:
        from Captioner import Captioner

    # Route HTTP(S) traffic through the local proxy, but keep a proxy the user
    # already configured instead of silently overriding it.
    os.environ.setdefault('HTTP_PROXY', 'http://localhost:8234')
    os.environ.setdefault('HTTPS_PROXY', 'http://localhost:8234')

    captioner = Captioner()

    # The image path may be given on the command line; defaults to the sample.
    test_image = sys.argv[1] if len(sys.argv) > 1 else "temp_images/3or47vg0.jpg"
    caption_response = captioner.caption(test_image)

    search_result = db.search_with_image_name(test_image)

    # Unique translated words, in first-seen order.
    keywords = list(dict.fromkeys(res['translated_word'] for res in search_result))

    json_response = get_major_object(caption_response, keywords)
    print(json_response)
    print()

    in_base_data, alt_data = verify_keyword_in_base(json_response, db)
    if in_base_data is not None:
        print(in_base_data)
    if alt_data is not None:
        print(alt_data)
|
|
|
|
|
|