# File size: 7,764 Bytes
# 0319a9a
def data2prompt(caption_response, ref_words):
    """Build the chain-of-thought prompt that asks the LLM for the major object.

    The prompt is made of three parts: a one-sentence task statement, the
    caption response under a ``# Caption Response:`` header, and a list of
    JSON fields the model must fill in step by step.

    Args:
        caption_response: Caption text (str) to embed in the prompt.
        ref_words: Sequence of reference words; only the first five non-empty
            entries are embedded in the ``echo_2`` instruction. ``None`` is
            treated as an empty sequence.

    Returns:
        The complete prompt as a single string.
    """
    # Skip empty/None entries so ",".join() cannot fail and the prompt does
    # not contain blank candidates.
    usable_words = [str(word) for word in (ref_words or []) if word]
    ref_word_str = ",".join(usable_words[:5])
    task_prompt = "Based on the following Caption Response, you will output a description of the Major Object's name."
    input_str = "# Caption Response:\n" + caption_response + "\n"
    CoT_prompt = f"""
Let's think it step by step. Output each field in JSON format. Include the following fields:
- major_object: From the caption response, identify the major_object. If not present, extract it again from the detailed_description or caption_response.
- better_major_object: Reread the description in the caption response to see if there's a more suitable word for the major object. If not, still output major_object.
- echo_1: "I will generate a simple description in about 200 words in English for the better_major_object, introducing what the input object is."
- description: Generate a WIKI description for the better_major_object (explain what is the better_major_object).
- major_object_chinese: Translate the better_major_object into Chinese.
- echo_2: "I will check whether there is synonym of the major_object_chinese in the '{ref_word_str}'."
- synonym: If present, output the synonym directly; otherwise, output "NOT_INCLUDED."
- recheck: Based on the content of the Caption Response, determine whether the synonym is accurate. If accurate, output "ACCURATE"; otherwise, output "NOT_ACCURATE."
"""
    # The newline keeps the "# Caption Response:" header at the start of its
    # own line instead of being glued to the end of the task sentence.
    return task_prompt + "\n" + input_str + CoT_prompt
# def data2prompt(caption_response , ref_words ):
# ref_word_str = ",".join(ref_words[:5])
# ref_str = "# Reference Word:\n"+ref_word_str+"\n\n"
# task_prompt = "你将根据下面的Caption Response,输出Major Object的名称描述"
# input_str = "# Caption Response:\n"+caption_response+"\n"
# CoT_prompt = \
# """
# Let's think it step by step,以json形式输出逐个字段。包含以下字段
# - major_object: 从caption response中,确认major_object,如果没有,则从detailed_description或者caption_response中重新抽取
# - better_major_object: 重新阅读caption response中的描述,看看是否有更合适的major object的词语,如果没有则仍然输出major_object
# - echo_1: "I will generate a simple description in about 200 words in English for the input word, introducing what the input object is"
# - description: generate the description for the input object
# - major_object_chinese: 将major_object翻译为中文
# - echo_2: "我将判断reference word中,是否存在major_object的同义词"
# - 同义词: 如果存在,则直接输出同义词,否则输出"NOT_INCLUDED"
# - recheck: 结合Caption Response的内容,判断同义词是否准确,如果准确,则输出"ACCURATE",否则输出"NOT_ACCURATE"
# """
# return ref_str+task_prompt+input_str+CoT_prompt
# Support both package-style (``src.``) and flat, same-directory imports.
# Only a failed import triggers the fallback; any other error raised while
# loading the module is allowed to surface.
try:
    from src.ZhipuClient import ZhipuClient
except ImportError:
    from ZhipuClient import ZhipuClient

# Shared client instance; get_major_object() creates it on first use.
zhipu_client = None
import json
def markdown_to_json(markdown_str):
    """Parse JSON text that may be wrapped in a Markdown code fence.

    Surrounding whitespace is ignored. If the text starts with a fence
    (three backticks, optionally followed by ``json``), the opening fence is
    removed, and the closing fence is removed only when it is actually
    present, so a trailing newline or a missing closing fence no longer
    corrupts the payload.

    Args:
        markdown_str: Raw text, e.g. an LLM response.

    Returns:
        The decoded JSON value (a dict for a JSON object).

    Raises:
        json.JSONDecodeError: If the remaining text is not valid JSON.
    """
    text = markdown_str.strip()
    if text.startswith("```"):
        text = text[3:]
        # Drop the optional language tag that follows the opening fence.
        if text[:4].lower() == "json":
            text = text[4:]
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()
    # Convert the remaining string into a JSON value.
    return json.loads(text)
import re
def forced_extract(input_str, keywords):
    """Pull ``"key": "value"`` string pairs out of text with a regex.

    Intended as a fallback for text that is not parseable as JSON. Every
    requested keyword appears in the result; keywords that are not found map
    to an empty string.

    Args:
        input_str: Text to scan.
        keywords: Iterable of key names to look for.

    Returns:
        Dict mapping each keyword to its extracted string value, or ``""``.
    """
    result = {key: "" for key in keywords}
    for key in result:
        # re.escape() keeps regex metacharacters in the key literal. The value
        # group accepts backslash escapes, so an escaped quote (\") inside the
        # value no longer ends the match early.
        pattern = '"' + re.escape(key) + r'"\s*:\s*"((?:[^"\\]|\\.)*)"'
        match = re.search(pattern, input_str, re.DOTALL)
        if match is None:
            continue
        raw_value = match.group(1)
        try:
            # Decode JSON escapes (\" \n \uXXXX ...); strict=False tolerates
            # literal control characters such as raw newlines.
            result[key] = json.loads('"' + raw_value + '"', strict=False)
        except ValueError:
            # Not a valid JSON string body: keep the raw captured text.
            result[key] = raw_value
    return result
def get_major_object(caption_response, ref_words):
    """Ask the LLM to identify the major object described by a caption.

    Builds the prompt with ``data2prompt``, sends it through the shared
    ``zhipu_client`` (created on first use), and parses the reply. If the
    reply is not a JSON object, the expected fields are recovered with
    ``forced_extract`` instead.

    Args:
        caption_response: Caption text to analyse.
        ref_words: Reference words passed on to ``data2prompt``.

    Returns:
        A dict of the model's answer fields. On the fallback path every
        expected key is present, and a key that could not be found maps to ``""``.
    """
    global zhipu_client
    if zhipu_client is None:
        zhipu_client = ZhipuClient()

    prompt = data2prompt(caption_response, ref_words)
    response = zhipu_client.prompt2response(prompt)

    try:
        json_response = markdown_to_json(response)
    except (ValueError, TypeError, AttributeError):
        # Invalid JSON or a non-string response. A bare ``except`` would also
        # swallow KeyboardInterrupt / SystemExit.
        json_response = None

    # Valid JSON that is not an object (list, string, number) is unusable for
    # callers that look fields up by key, so use the regex fallback for it too.
    if not isinstance(json_response, dict):
        keyword_list = [
            "major_object",
            "better_major_object",
            "description",
            "major_object_chinese",
            "synonym",
            "recheck",
        ]
        response_text = response if isinstance(response, str) else ""
        json_response = forced_extract(response_text, keyword_list)
    return json_response
def verify_keyword_in_base(json_response, database):
    """Look the model's object names up in the database, or draft a new entry.

    Candidate English keywords are collected in priority order:
    ``better_major_object``, ``major_object``, and ``synonym`` (the latter
    only when ``recheck`` is ``ACCURATE`` and the synonym is not the
    ``NOT_INCLUDED`` marker). Each candidate is lower-cased and passed to
    ``database.search_by_en_keyword``; the first result that is not ``None``
    is returned.

    Args:
        json_response: Dict of fields as returned by ``get_major_object``.
            Missing, empty or non-string fields are ignored.
        database: Object exposing ``search_by_en_keyword(keyword)``, which
            returns ``None`` when nothing matches.

    Returns:
        A 2-tuple ``(found, new_data)``:
        * ``(result, None)`` when a candidate is found in the database;
        * ``(None, data)`` when there are candidates but none is found, where
          ``data`` has ``keyword`` (Chinese name), ``translated_word``
          (English name) and ``description`` (English description);
        * ``(None, None)`` when there is no usable candidate at all.
    """
    if not isinstance(json_response, dict):
        return None, None

    def _text(field):
        # Non-string values (e.g. JSON null) count as missing rather than
        # crashing on .lower()/.strip().
        value = json_response.get(field)
        return value.strip() if isinstance(value, str) else ""

    candidates = [_text("better_major_object"), _text("major_object")]

    # The prompt spells the markers with a trailing period ("NOT_INCLUDED."),
    # so tolerate it when comparing.
    synonym = _text("synonym")
    if _text("recheck").rstrip(".") == "ACCURATE" and synonym.rstrip(".") != "NOT_INCLUDED":
        candidates.append(synonym)

    # Lower-case, drop empties, and de-duplicate while preserving priority.
    keyword2verify = list(dict.fromkeys(c.lower() for c in candidates if c))
    if not keyword2verify:
        return None, None

    for keyword in keyword2verify:
        res = database.search_by_en_keyword(keyword)
        if res is not None:
            return res, None

    # Nothing matched: build a new record. ``keyword`` is the Chinese name,
    # ``translated_word`` the English name, ``description`` the English
    # description; both fall back to the English name when absent or empty.
    translated_word = keyword2verify[0]
    data = {
        "keyword": _text("major_object_chinese") or translated_word,
        "translated_word": translated_word,
        "description": _text("description") or translated_word,
    }
    return None, data
if __name__ == '__main__':
    # Manual smoke test: caption one image, ask the LLM for its major object,
    # then check that object against the database.
    import os

    # Support both package-style (``src.``) and flat imports; only a failed
    # import triggers the fallback.
    try:
        from src.Database import Database
    except ImportError:
        from Database import Database

    db = Database()

    try:
        from src.Captioner import Captioner
    except ImportError:
        from Captioner import Captioner

    # Point the HTTP(S) proxy environment variables at a local proxy on port
    # 8234 before the captioner is created.
    # NOTE(review): presumably needed for the captioner's outbound requests;
    # confirm and adjust for your environment.
    os.environ['HTTP_PROXY'] = 'http://localhost:8234'
    os.environ['HTTPS_PROXY'] = 'http://localhost:8234'

    captioner = Captioner()

    test_image = "temp_images/3or47vg0.jpg"
    caption_response = captioner.caption(test_image)

    search_result = db.search_with_image_name(test_image)
    # Order-preserving de-duplication of the English names of the hits.
    keywords = list(dict.fromkeys(res['translated_word'] for res in search_result))

    json_response = get_major_object(caption_response, keywords)
    print(json_response)
    print()

    in_base_data, alt_data = verify_keyword_in_base(json_response, db)

    if in_base_data is not None:
        print(in_base_data)

    if alt_data is not None:
        print(alt_data)
# Example output:
# {'keyword': '埃菲尔铁塔', 'translated_word': 'eiffel tower', 'description': "The Eiffel Tower is an iconic symbol of Paris and one of the most recognizable structures in the world. Designed and constructed by the engineer Gustave Eiffel and his company for the 1889 Exposition Universelle (World's Fair) to celebrate the 100th anniversary of the French Revolution, the tower was initially criticized by some of France's leading artists and intellectuals. However, it quickly became a beloved landmark and a symbol of French pride. Standing 324 meters tall, the tower is made of wrought iron and consists of thousands of metal parts, including over 18,000 individual iron rivets. It is renowned for its architectural and engineering design, and it is visited by millions of people each year, making it one of the most visited paid monuments in the world."}