Upload 19 files
Browse files- src/CLIPExtractor.py +93 -0
- src/Captioner.py +148 -0
- src/Database.py +278 -0
- src/Founder.py +69 -0
- src/GameMaster.py +189 -0
- src/ImageBase.py +162 -0
- src/ZhipuClient.py +35 -0
- src/__pycache__/Captioner.cpython-310.pyc +0 -0
- src/__pycache__/Database.cpython-310.pyc +0 -0
- src/__pycache__/GameMaster.cpython-310.pyc +0 -0
- src/__pycache__/ImageBase.cpython-310.pyc +0 -0
- src/__pycache__/ZhipuClient.cpython-310.pyc +0 -0
- src/__pycache__/generate_cultivation.cpython-310.pyc +0 -0
- src/__pycache__/get_major_object.cpython-310.pyc +0 -0
- src/__pycache__/text_embedding.cpython-310.pyc +0 -0
- src/generate_cultivation.py +183 -0
- src/get_comments_from_level.py +29 -0
- src/get_major_object.py +195 -0
- src/text_embedding.py +289 -0
src/CLIPExtractor.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from transformers import CLIPProcessor, CLIPModel
|
4 |
+
import cv2
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
class CLIPExtractor:
    """Wrapper around a Hugging Face CLIP model that returns feature vectors.

    Each ``extract_*`` method returns the first row of the model output as a
    NumPy array, i.e. the embedding of the single image or text passed in.
    """

    def __init__(self, model_name="openai/clip-vit-large-patch14", cache_dir=None):
        """Load the CLIP model and processor and move the model to GPU if available.

        Args:
            model_name: Hugging Face model identifier to load.
            cache_dir: Directory used to cache the weights. When falsy,
                ``"models"`` is used, falling back to ``"../models"`` if only
                the latter exists.
        """
        # Proxy settings.
        # NOTE(review): these assignments are process-wide side effects and
        # assume a proxy listening on localhost:8234 - confirm this is wanted
        # for every consumer of the class.
        os.environ['HTTP_PROXY'] = 'http://localhost:8234'
        os.environ['HTTPS_PROXY'] = 'http://localhost:8234'

        # Hugging Face endpoint and home directory.
        os.environ["HF_ENDPOINT"] = "https://hf-api.gitee.com"
        os.environ["HF_HOME"] = os.path.expanduser("models/")

        if not cache_dir:
            # Default cache directory, with a fallback for runs started from
            # inside a sub-directory of the project.
            cache_dir = "models"
            if not os.path.exists(cache_dir) and os.path.exists("../models"):
                cache_dir = "../models"

        # Initialize the model and processor with the specified values.
        self.model = CLIPModel.from_pretrained(model_name, cache_dir=cache_dir)
        self.processor = CLIPProcessor.from_pretrained(model_name, cache_dir=cache_dir)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def _embed_images(self, images):
        """Run the image tower on a list of PIL images; return the first embedding."""
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.get_image_features(**inputs)
        return outputs.cpu().numpy()[0]

    def extract_image(self, frame, if_opencv=True):
        """Return the CLIP embedding of an in-memory image.

        Args:
            frame: The image. With ``if_opencv=True`` it is passed through
                ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)`` first. Otherwise it
                may be a PIL image or an array-like.
            if_opencv: Whether ``frame`` is an OpenCV (BGR) frame. Defaults to
                ``True`` so existing ``extract_image(frame)`` calls behave as
                before; the keyword exists because ``Database.search_with_image``
                passes it explicitly.

        Returns:
            The embedding as a NumPy array.
        """
        if if_opencv:
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        elif isinstance(frame, Image.Image):
            image = frame.convert("RGB")
        else:
            # NOTE(review): assumes a non-OpenCV array is already in RGB
            # channel order - confirm against callers.
            image = Image.fromarray(np.asarray(frame)).convert("RGB")
        return self._embed_images([image])

    def extract_image_from_file(self, file_name):
        """Return the CLIP embedding of the image stored at ``file_name``.

        Raises:
            FileNotFoundError: If ``file_name`` does not exist.
        """
        if not os.path.exists(file_name):
            raise FileNotFoundError(f"File {file_name} not found.")

        # ``convert`` returns a fully loaded copy, so the file handle can be
        # closed as soon as the conversion is done.
        with Image.open(file_name) as handle:
            image = handle.convert("RGB")
        return self._embed_images([image])

    def extract_text(self, text):
        """Return the CLIP embedding of ``text``.

        The text is tokenized with truncation to ``max_length=77``.

        Raises:
            ValueError: If ``text`` is not a non-empty string.
        """
        if not isinstance(text, str) or not text:
            raise ValueError("Input text should be a non-empty string.")

        # Tokenize directly with the tokenizer so truncation can be requested.
        inputs = self.processor.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.get_text_features(**inputs)

        return outputs.cpu().numpy()[0]
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == "__main__":
    extractor = CLIPExtractor()

    # Embed a sample image and a sample caption.
    image_vec = extractor.extract_image_from_file("images/狐狸.jpg")
    text_vec = extractor.extract_text("A photo of fox")

    # Cosine similarity between the image and text embeddings.
    norm_product = np.linalg.norm(image_vec) * np.linalg.norm(text_vec)
    print(np.dot(image_vec, text_vec) / norm_product)
|
src/Captioner.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import base64
|
3 |
+
from io import BytesIO
|
4 |
+
import os
|
5 |
+
from openai import OpenAI
|
6 |
+
import json
|
7 |
+
|
8 |
+
class Captioner:
    """Caption images through an OpenAI-compatible vision endpoint, with a file cache.

    Responses are cached per image path in ``self.history`` and persisted as
    JSON lines (``{"file_name": ..., "response": ...}``) in ``self.history_file``.
    """

    def __init__(self, api_key_path = None, proxy=None, api_base="https://api.lingyiwanwu.com/v1"):
        """Create the API client and load the caption history.

        Args:
            api_key_path: Path to a text file holding the API key. When
                ``None``, ``datas/01_key.txt`` and ``../datas/01_key.txt`` are
                tried in that order.
            proxy: Optional proxy URL written to ``HTTP_PROXY``/``HTTPS_PROXY``.
            api_base: Base URL of the API.

        Raises:
            ValueError: If no key path was given and no candidate file exists.
        """
        if api_key_path is None:
            cand_paths = ['datas/01_key.txt', '../datas/01_key.txt']
            api_key_path = next((path for path in cand_paths if os.path.exists(path)), None)
            if api_key_path is None:
                raise ValueError("Please provide the path to the API key file.")

        self.api_key = self.load_access_token(api_key_path)
        self.api_base = api_base
        if proxy:
            os.environ['HTTP_PROXY'] = proxy
            os.environ['HTTPS_PROXY'] = proxy
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.api_base
        )

        self.history = {}
        self.history_file = None

        self.load_history()

    def load_access_token(self, file_path):
        """Return the stripped contents of the key file at ``file_path``."""
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read().strip()

    def image2base64(self, image_path):
        """Return the image as a ``data:image/jpeg;base64,...`` URL.

        Images taller than 480 pixels are scaled down to a height of 480 with
        the aspect ratio preserved before being encoded as JPEG.
        """
        with Image.open(image_path) as img:
            if img.height > 480:
                # Keep the aspect ratio while capping the height at 480.
                aspect_ratio = img.width / img.height
                new_height = 480
                # Guard against a zero width for extremely narrow images.
                new_width = max(1, int(new_height * aspect_ratio))
                # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the
                # same filter under its current name.
                img = img.resize((new_width, new_height), Image.LANCZOS)

            # JPEG cannot store alpha or palette modes, so normalize to RGB.
            if img.mode != "RGB":
                img = img.convert("RGB")

            # Encode into an in-memory buffer rather than a temporary file.
            buffered = BytesIO()
            img.save(buffered, format="JPEG")

        encoded = base64.b64encode(buffered.getvalue()).decode('utf-8')
        return "data:image/jpeg;base64," + encoded

    def load_history(self, jsonl_file_name=None):
        """Load cached responses from a JSONL file into ``self.history``.

        ``jsonl_file_name`` defaults to ``datas/caption_history.jsonl`` and is
        remembered in ``self.history_file``. A missing file leaves the history
        unchanged; blank lines are skipped.
        """
        if jsonl_file_name is None:
            jsonl_file_name = "datas/caption_history.jsonl"

        self.history_file = jsonl_file_name

        if os.path.exists(jsonl_file_name):
            with open(jsonl_file_name, 'r', encoding='utf-8') as f:
                for line in f:
                    if not line.strip():
                        # Tolerate blank or trailing lines in the file.
                        continue
                    data = json.loads(line)
                    self.history[data['file_name']] = data['response']

    def search_from_history(self, file_name):
        """Return the cached response for ``file_name``, or ``None``."""
        return self.history.get(file_name, None)

    def save_history(self, jsonl_file_name=None):
        """Rewrite the history file with the full contents of ``self.history``.

        ``jsonl_file_name`` defaults to ``self.history_file``; nothing is
        written when both are falsy.
        """
        if jsonl_file_name is None:
            jsonl_file_name = self.history_file

        if jsonl_file_name:
            # Create the parent directory so a fresh checkout can save.
            parent = os.path.dirname(jsonl_file_name)
            if parent:
                os.makedirs(parent, exist_ok=True)
            with open(jsonl_file_name, 'w', encoding='utf-8') as f:
                for file_name, response in self.history.items():
                    json.dump({'file_name': file_name, 'response': response}, f, ensure_ascii=False)
                    f.write('\n')

    def add_to_history(self, file_name, response):
        """Record ``response`` for ``file_name`` in the in-memory history."""
        self.history[file_name] = response

    def caption(self, image_name):
        """Return the model's response for ``image_name``, using the cache first.

        On a cache miss the image is sent to the ``yi-vision`` model together
        with a prompt asking for a JSON analysis; the raw response text is
        cached, the history file is rewritten, and the text is returned.
        """
        # A truthy cached response short-circuits the API call.
        cached_response = self.search_from_history(image_name)
        if cached_response:
            return cached_response

        prompt = """Analyze the image and output in JSON format, including the following fields:
- "detailed_description": A detailed description of the image content.
- "major_object": Determine the main object/scene in the image based on the description, output with a simple word
- "Chinese_name": 判断图片中主要物体的中文名
- "real_or_composite": Determine whether this image was taken with a camera or created/modifed by a computer, output with real or composite."""

        img_base64 = self.image2base64(image_name)

        completion = self.client.chat.completions.create(
            model="yi-vision",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": img_base64
                            }
                        }
                    ]
                }
            ],
            stream=False
        )

        response = completion.choices[0].message.content

        # Cache the new response and persist the history immediately.
        self.add_to_history(image_name, response)
        self.save_history()

        return response
|
141 |
+
|
142 |
+
if __name__ == "__main__":
    import os

    # Point both proxy variables at the local proxy before creating the client.
    for proxy_var in ('HTTP_PROXY', 'HTTPS_PROXY'):
        os.environ[proxy_var] = 'http://localhost:8234'

    captioner = Captioner()
    print(captioner.caption("temp_images/3zjz9b3l.jpg"))
|
src/Database.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import os
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
7 |
+
|
8 |
+
|
9 |
+
class Database:
    """Keyword table backed by parquet files, searchable by CLIP/BGE features.

    ``self.datas`` is a DataFrame holding the main parquet file plus every
    customized parquet file that exists. ``self.last_save_table`` mirrors the
    last customized file and receives rows added through ``add_data``.
    """

    def __init__(self, parquet_path=None, customized_parquets = None):
        """Load the main and customized parquet files that exist on disk.

        Args:
            parquet_path: Main parquet file; defaults to
                ``datas/database_4000.parquet``.
            customized_parquets: List of additional parquet files; defaults to
                ``["datas/customized_database_0.parquet"]``.
        """
        self.default_parquet_path = 'datas/database_4000.parquet'
        self.parquet_path = parquet_path or self.default_parquet_path

        self.default_customized_parquets = ["datas/customized_database_0.parquet"]
        self.customized_parquets = customized_parquets or self.default_customized_parquets

        self.datas = None
        self.last_save_table = None

        if os.path.exists(self.parquet_path):
            self.load_from_parquet(self.parquet_path)

        self.load_from_customized(self.customized_parquets)

        # Extractors are created lazily by init_clip_extractor/init_bge_extractor.
        self.clip_extractor = None
        self.bge_extractor = None

        # Lower-cased translated_word -> row; built lazily on first lookup.
        self.en_keyword2data = {}

    def build_en_keyword2index(self):
        """Rebuild the lower-cased ``translated_word`` -> row index from ``self.datas``."""
        if self.datas is None:
            self.en_keyword2data = {}
            return
        # Later rows win when several rows share the same translated word.
        self.en_keyword2data = {row['translated_word'].lower(): row for i, row in self.datas.iterrows()}

    def search_by_en_keyword(self, keyword):
        """Return the row for ``keyword`` (case-insensitive) as a dict, or ``None``.

        The ``clip_feature`` and ``bge_feature`` entries are removed from the
        returned dict.
        """
        if not self.en_keyword2data:
            self.build_en_keyword2index()

        row = self.en_keyword2data.get(keyword.lower())
        if row is None:
            return None

        ans = row.to_dict()
        ans.pop("clip_feature", None)
        ans.pop("bge_feature", None)
        return ans

    def load_from_parquet(self, parquet_path):
        """Replace ``self.datas`` with the contents of ``parquet_path``."""
        self.datas = pd.read_parquet(parquet_path)

    def load_from_customized(self, customized_parquets=None):
        """Append every existing customized parquet file to ``self.datas``.

        The contents of the last file in the list, if it exists, are also kept
        as ``self.last_save_table``.
        """
        customized_parquets = customized_parquets or self.customized_parquets

        last_index = len(customized_parquets) - 1
        for index, parquet_file in enumerate(customized_parquets):
            if not os.path.exists(parquet_file):
                continue

            temp_df = pd.read_parquet(parquet_file)
            if self.datas is None:
                self.datas = temp_df
            else:
                self.datas = pd.concat([self.datas, temp_df], ignore_index=True)

            if index == last_index:
                self.last_save_table = temp_df

    def add_data(self, data, if_save=True):
        """Add one entry, computing its CLIP and BGE features.

        Args:
            data: Mapping with the keys ``keyword``, ``name_in_cultivation``,
                ``description_in_cultivation``, ``translated_word`` and
                ``description``; ``founder`` is optional. The mapping is
                modified in place (``founder``, ``clip_feature`` and
                ``bge_feature`` are set on it).
            if_save: When true, write ``self.last_save_table`` to the last
                customized parquet file.

        Raises:
            ValueError: If a required key is missing.
        """
        required_columns = ['keyword', 'name_in_cultivation', 'description_in_cultivation', 'translated_word', 'description']
        for column in required_columns:
            if column not in data:
                raise ValueError(f"Missing required field: {column}")

        # Optional field.
        if 'founder' not in data:
            data['founder'] = ""

        # Extract features.
        if self.clip_extractor is None:
            self.init_clip_extractor()
        if self.bge_extractor is None:
            self.init_bge_extractor()

        data['clip_feature'] = self.clip_extractor.extract_text(data['translated_word'] + '.' + data['description'])
        data['bge_feature'] = self.bge_extractor.extract([data['keyword']])[0].tolist()

        # Remember whether the keyword index has been built. Inserting into an
        # unbuilt (empty) index would make it look built and hide every
        # existing row from search_by_en_keyword.
        index_built = bool(self.en_keyword2data)

        data_df = pd.DataFrame([data])
        if self.datas is None:
            self.datas = data_df
        else:
            self.datas = pd.concat([self.datas, data_df], ignore_index=True)

        if index_built:
            self.en_keyword2data[data['translated_word'].lower()] = self.datas.iloc[-1]

        # Track the row in the table that backs the last customized file.
        if self.last_save_table is None:
            # Use the column layout of self.datas for the new table.
            self.last_save_table = data_df.reindex(columns=self.datas.columns)
        else:
            self.last_save_table = pd.concat([self.last_save_table, data_df], ignore_index=True)

        if if_save:
            self.save_to_parquet(self.customized_parquets[-1], self.last_save_table)

    def add_datas(self, datas, if_save=True):
        """Add several entries, saving once at the end when ``if_save`` is true."""
        for data in datas:
            self.add_data(data, if_save=False)
        # Without this guard an empty batch would pass df=None and
        # save_to_parquet would write the whole main table to the customized file.
        if if_save and self.last_save_table is not None:
            self.save_to_parquet(self.customized_parquets[-1], self.last_save_table)

    def init_from_excel(self, excel_path):
        """Rebuild ``self.datas`` from an Excel sheet and recompute all features."""
        df = pd.read_excel(excel_path)

        # Drop rows with any empty cell in the required columns.
        df.dropna(subset=['keyword', 'name_in_cultivation', 'description_in_cultivation', 'translated_word', 'description'], inplace=True)

        # Placeholder feature columns, filled in below.
        df['clip_feature'] = None
        df['bge_feature'] = None

        self.datas = df
        # The table was replaced, so the keyword index must be rebuilt lazily.
        self.en_keyword2data = {}

        self.extract_clip()
        self.extract_bge()

    def save_to_parquet(self, parquet_path=None, df = None):
        """Write ``df`` (or ``self.datas`` when ``df`` is ``None``) to parquet.

        ``parquet_path`` defaults to the default main parquet path. Nothing is
        written when both ``df`` and ``self.datas`` are ``None``.
        """
        parquet_path = parquet_path or self.default_parquet_path
        if df is None:
            if self.datas is not None:
                self.datas.to_parquet(parquet_path)
        else:
            df.to_parquet(parquet_path)

    def init_clip_extractor(self):
        """Create ``self.clip_extractor`` if it does not exist yet."""
        if self.clip_extractor is None:
            # The module is importable either directly or via the src package,
            # depending on where the program was started.
            try:
                from CLIPExtractor import CLIPExtractor
            except ImportError:
                from src.CLIPExtractor import CLIPExtractor

            cache_dir = "models"

            self.clip_extractor = CLIPExtractor(model_name = "openai/clip-vit-large-patch14",cache_dir = cache_dir)

    def extract_clip(self):
        """Recompute ``clip_feature`` for every row from ``translated_word`` and ``description``."""
        if self.clip_extractor is None:
            self.init_clip_extractor()

        clip_features = []
        for index, row in tqdm(self.datas.iterrows(), desc='Extracting CLIP features', total=len(self.datas)):
            # The joined text always contains the '.' separator, so it is
            # never empty and no empty-text fallback is needed.
            text = row['translated_word'] + '.' + row['description']
            clip_features.append(self.clip_extractor.extract_text(text))

        self.datas['clip_feature'] = clip_features

    def init_bge_extractor(self):
        """Create ``self.bge_extractor`` if it does not exist yet."""
        if self.bge_extractor is None:
            try:
                from text_embedding import TextExtractor
            except ImportError:
                from src.text_embedding import TextExtractor

            self.bge_extractor = TextExtractor('BAAI/bge-small-zh-v1.5')

    def top_k_search(self, query_feature, attribute, top_k=15):
        """Return the ``top_k`` rows most similar to ``query_feature``.

        Args:
            query_feature: Query vector.
            attribute: Name of the feature column to compare against.
            top_k: Maximum number of rows to return.

        Returns:
            A list of row dicts ordered by decreasing cosine similarity, each
            with a ``similarity`` entry and without the feature columns. The
            list is empty when no row has a value in ``attribute`` or when
            ``top_k`` is not positive.

        Raises:
            ValueError: If ``attribute`` is not a column of the data.
        """
        if attribute not in self.datas.columns:
            raise ValueError(f"Attribute {attribute} not found in the data.")

        # Positions returned by argsort refer to the filtered rows, so the
        # same filtered frame must be used for the lookup further down.
        valid_rows = self.datas[self.datas[attribute].notna()]
        if valid_rows.empty or top_k <= 0:
            return []

        query_feature = np.array(query_feature).reshape(1, -1)
        attribute_features = np.stack(valid_rows[attribute].values)

        similarities = cosine_similarity(query_feature, attribute_features)[0]

        # Highest similarity first.
        top_k_indices = np.argsort(similarities)[-top_k:][::-1]

        top_k_results = valid_rows.iloc[top_k_indices].copy()
        top_k_results = top_k_results.drop(columns=['clip_feature', 'bge_feature'], errors='ignore')
        top_k_results['similarity'] = similarities[top_k_indices]

        return top_k_results.to_dict(orient='records')

    def search_with_image_name(self, image_name):
        """Search ``clip_feature`` with the embedding of the image file ``image_name``."""
        self.init_clip_extractor()

        img_feature = self.clip_extractor.extract_image_from_file(image_name)

        return self.top_k_search(img_feature, 'clip_feature')

    def search_with_image(self, image, if_opencv = False ):
        """Search ``clip_feature`` with the embedding of an in-memory image.

        ``if_opencv`` is forwarded to the extractor's ``extract_image``.
        """
        if self.clip_extractor is None:
            self.init_clip_extractor()

        # NOTE(review): relies on extract_image accepting an if_opencv
        # keyword - confirm against the extractor in use.
        img_feature = self.clip_extractor.extract_image(image, if_opencv = if_opencv)

        return self.top_k_search(img_feature, 'clip_feature')

    def search_with_chinese(self, text):
        """Search ``bge_feature`` with the BGE embedding of ``text``."""
        if self.bge_extractor is None:
            self.init_bge_extractor()

        text_feature = self.bge_extractor.extract([text])[0].tolist()

        return self.top_k_search(text_feature, 'bge_feature')

    def extract_bge(self):
        """Recompute ``bge_feature`` for every row from its ``keyword``."""
        if self.bge_extractor is None:
            self.init_bge_extractor()

        bge_features = []
        for text in tqdm(self.datas['keyword'], desc='Extracting BGE features'):
            if text:
                feature = self.bge_extractor.extract([text])[0].tolist()
            else:
                # Falsy keywords get no feature.
                feature = None
            bge_features.append(feature)

        self.datas['bge_feature'] = bge_features
|
240 |
+
|
241 |
+
if __name__ == '__main__':
    # Usage example: load (or rebuild) the database, then run two searches.
    db = Database()
    re_extract = False
    if re_extract or db.datas is None:
        print("Rebuilding database from excel file")
        db.init_from_excel('datas/database_4000.xlsx')
        db.save_to_parquet()

    # Text search with a Chinese keyword.
    text_hits = db.search_with_chinese("钢琴")
    print(text_hits[0].keys())
    for hit in text_hits[:3]:
        print(hit)

    # Image search from a file on disk.
    image_hits = db.search_with_image_name("datas/老虎.jpg")
    for hit in image_hits[:3]:
        print(hit)
|
267 |
+
# 'keyword': '老虎狗', 'name_in_cultivation': '灵虎犬神', 'description_in_cultivation': '在九天灵脉汇聚的仙山之巅,灵虎犬神身披星图
|
268 |
+
# 斑纹,汲取日月精华,以雷霆之力守护仙脉,其双眼中映照着轮回之道,是修仙者追寻天地真理的指引,也是象征极致灵性的神秘灵兽。', 'translated_word': 'Tiger Dog', 'description': 'A Tiger Dog is a term that might refer to a mythical creature or a breed of dog with a distinctive and unusual appearance, resembling the features of a tiger. It could be characterized by its striking coat with patterns similar to those of a tiger, or by having a demeanor that is fierce and majestic like a tiger. This term is not commonly used in
|
269 |
+
# conventional contexts and might be found in stories, folktales, or in the names of unique dog breeds that have been bred to exhibit such features.', 'founder': ''
|
270 |
+
# test_new_data = {
|
271 |
+
# "keyword": "老虎狗2",
|
272 |
+
# "name_in_cultivation": "灵虎犬神",
|
273 |
+
# "description_in_cultivation": "在九天灵脉汇聚的仙山之巅,灵虎犬神身披星图斑纹,汲取日月精华,以雷霆之力守护仙脉,其双眼中映照着轮回之道,是修仙者追寻天地真理的指引,也是象征极致灵性的神秘灵兽。",
|
274 |
+
# "translated_word": "Tiger Dog",
|
275 |
+
# "description":"A Tiger Dog is a term that might refer to a mythical creature or a breed of dog with a distinctive and unusual appearance, resembling the features of a tiger. It could be characterized by its striking coat with patterns similar to those of a tiger, or by having a demeanor that is fierce and majestic like a tiger. This term is not commonly used in conventional contexts and might be found in stories, folktales, or in the names of unique dog breeds that have been bred to exhibit such features."
|
276 |
+
# }
|
277 |
+
|
278 |
+
# db.add_data(test_new_data)
|
src/Founder.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from collections import defaultdict
|
3 |
+
|
4 |
+
class Founder:
    """Track which founder discovered which word, persisted as JSON lines.

    ``self.datas`` maps word -> founder; ``self.founder2items`` is the reverse
    mapping founder -> list of words and is kept in sync by ``set_founder``.
    """

    def __init__(self, filepath='datas/founder.jsonl'):
        """Load founder data from ``filepath``; a missing file yields an empty store."""
        self.filepath = filepath
        self.datas = {}
        self.founder2items = defaultdict(list)

        try:
            self.load_founder()
        except FileNotFoundError:
            self.datas = {}

        # Initialize the reverse mapping.
        for word, founder in self.datas.items():
            self.founder2items[founder].append(word)

    def load_founder(self):
        """Load founder data from a jsonl file, skipping blank lines."""
        with open(self.filepath, 'r', encoding='utf-8') as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue
                self.datas.update(json.loads(line))

    def save_founder(self):
        """Save founder data to a jsonl file, one ``{word: founder}`` object per line."""
        import os

        # Create the parent directory so the first save on a fresh checkout works.
        parent = os.path.dirname(self.filepath)
        if parent:
            os.makedirs(parent, exist_ok=True)

        with open(self.filepath, 'w', encoding='utf-8') as file:
            for word, founder in self.datas.items():
                file.write(json.dumps({word: founder}, ensure_ascii=False) + '\n')

    def get_founder(self, word):
        """Get the founder of a given word, or ``None`` if it has none."""
        return self.datas.get(word, None)

    def set_founder(self, word, founder, enforce=False):
        """Set the founder of a word if it's not already set or if enforce is True.

        When an existing founder is overridden, the word is also removed from
        the previous founder's list so the reverse mapping stays consistent.
        The data file is rewritten after every successful change.
        """
        if word in self.datas and not enforce:
            print(f"Warning: {word} already has a founder: {self.datas[word]}. Use enforce=True to override.")
            return

        if word in self.datas:
            previous = self.datas[word]
            previous_items = self.founder2items.get(previous)
            if previous_items and word in previous_items:
                previous_items.remove(word)
                if not previous_items:
                    # Drop founders left without words so they do not show up
                    # in the ranking.
                    del self.founder2items[previous]

        self.datas[word] = founder
        self.founder2items[founder].append(word)
        self.save_founder()

    def get_all_items_from_founder(self, founder):
        """Get all words discovered by a specific founder (empty list if none)."""
        return self.founder2items.get(founder, [])

    def get_top_rank(self, top_k=20):
        """Get the top_k founders with the most discovered words.

        Returns a list of ``(founder, words)`` tuples, most words first.
        """
        sorted_founders = sorted(self.founder2items.items(), key=lambda x: len(x[1]), reverse=True)
        return sorted_founders[:top_k]
|
53 |
+
|
54 |
+
# Example usage:
|
55 |
+
# founder = Founder()
|
56 |
+
# founder.set_founder('apple', 'Alice')
|
57 |
+
# founder.set_founder('banana', 'Bob')
|
58 |
+
# print(founder.get_founder('apple'))
|
59 |
+
# print(founder.get_all_items_from_founder('Alice'))
|
60 |
+
# print(founder.get_top_rank())
|
61 |
+
|
62 |
+
if __name__ == '__main__':
    founder = Founder()

    # Register a few sample words, then print the lookups and the ranking.
    sample_entries = (
        ('test_apple', '鲁鲁道祖'),
        ('test_banana', '鲁鲁道祖'),
        ('test_orange', "文钊道祖"),
    )
    for sample_word, sample_founder in sample_entries:
        founder.set_founder(sample_word, sample_founder)

    print(founder.get_founder('test_apple'))
    print(founder.get_all_items_from_founder('Alice'))
    print(founder.get_top_rank())
|
src/GameMaster.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from glob import glob
|
3 |
+
|
4 |
+
try:
|
5 |
+
from src.Database import Database
|
6 |
+
from src.Captioner import Captioner
|
7 |
+
from src.ImageBase import Imagebase
|
8 |
+
from src.get_major_object import get_major_object, verify_keyword_in_base
|
9 |
+
from src.generate_cultivation import generate_cultivation_with_rag
|
10 |
+
except:
|
11 |
+
from Database import Database
|
12 |
+
from Captioner import Captioner
|
13 |
+
from ImageBase import Imagebase
|
14 |
+
from get_major_object import get_major_object, verify_keyword_in_base
|
15 |
+
from generate_cultivation import generate_cultivation_with_rag
|
16 |
+
|
17 |
+
|
18 |
+
class GameMaster:
|
19 |
+
def __init__(self):
    """Create the text database, image database and captioner used by the game."""
    text_db = self.init_textdb()
    self.textdb = text_db

    # Share the CLIP extractor held by the text database.
    self.clip_extractor = text_db.clip_extractor

    self.imgdb = self.init_imgdb()
    self.captioner = Captioner()

    # Default similarity threshold used by search_with_path.
    self.minimal_image_threshold = 0.9
|
29 |
+
|
30 |
+
def init_textdb(self):
    """Return a Database with its BGE and CLIP extractors initialized."""
    database = Database()
    # Initialize BGE first, then CLIP.
    for init_extractor in (database.init_bge_extractor, database.init_clip_extractor):
        init_extractor()
    return database
|
35 |
+
|
36 |
+
def init_imgdb(self):
    """Return a new Imagebase instance."""
    return Imagebase()
|
39 |
+
|
40 |
+
def random_image_text_data( self, n = 12 ):
|
41 |
+
random_img_datas = self.imgdb.random_sample(n)
|
42 |
+
# keep image_name and keywords only
|
43 |
+
image_names = [img_data['image_name'] for img_data in random_img_datas]
|
44 |
+
blank_image_path = "datas/blank_item.jpg"
|
45 |
+
for i in range(len(image_names)):
|
46 |
+
if not os.path.exists(image_names[i]):
|
47 |
+
image_names[i] = blank_image_path
|
48 |
+
|
49 |
+
keywords_zh = [img_data['keyword'] for img_data in random_img_datas]
|
50 |
+
keywords = [img_data['translated_word'] for img_data in random_img_datas]
|
51 |
+
descriptions = []
|
52 |
+
|
53 |
+
for keyword, keyword_zh in zip(keywords, keywords_zh):
|
54 |
+
result = self.textdb.search_by_en_keyword(keyword)
|
55 |
+
if result and "description_in_cultivation" in result:
|
56 |
+
description = result['description_in_cultivation']
|
57 |
+
if "name_in_cultivation" in result:
|
58 |
+
description = result['name_in_cultivation'] + "--" + description
|
59 |
+
descriptions.append(description)
|
60 |
+
else:
|
61 |
+
descriptions.append("")
|
62 |
+
|
63 |
+
#return tuple of imapge path and description
|
64 |
+
return zip(image_names, descriptions)
|
65 |
+
|
66 |
+
|
67 |
+
def search_with_path( self, image_path , threshold = None ):
|
68 |
+
# this is a relatively light weight search
|
69 |
+
image_feature = self.clip_extractor.extract_image_from_file(image_path)
|
70 |
+
|
71 |
+
# image_search_result = img_db.search_with_image_name(image_path)
|
72 |
+
image_search_result = self.imgdb.top_k_search(image_feature, top_k=1)
|
73 |
+
|
74 |
+
search_result = None
|
75 |
+
|
76 |
+
if threshold is None:
|
77 |
+
threshold = self.minimal_image_threshold
|
78 |
+
|
79 |
+
if image_search_result and len(image_search_result)>0 and image_search_result[0]['similarity'] > threshold:
|
80 |
+
|
81 |
+
# try find data with translated_word
|
82 |
+
result = self.textdb.search_by_en_keyword(image_search_result[0]['translated_word'])
|
83 |
+
if result and "name_in_cultivation" in result:
|
84 |
+
search_result = result
|
85 |
+
search_result['similarity'] = image_search_result[0]['similarity']
|
86 |
+
else:
|
87 |
+
print("Warning! Unfound keyword: ", image_search_result[0]['translated_word'])
|
88 |
+
|
89 |
+
# backup_results = None
|
90 |
+
# if search_result is None:
|
91 |
+
# try search with textdb
|
92 |
+
backup_results = self.textdb.top_k_search(image_feature, 'clip_feature', top_k = 5)
|
93 |
+
|
94 |
+
return search_result, backup_results, image_feature
|
95 |
+
|
96 |
+
def generate_cultivation_data(self, image_path, image_feature, text_search_result):
    """Build (or look up) the cultivation-world entry for an image.

    This is the expensive path: the image is captioned, the major object is
    extracted from the caption, and a new entry is generated when the object
    is not already in the text database.

    Parameters:
    - image_path: path of the image; stored as ``image_name`` in the image db.
    - image_feature: feature of the image, passed through to both databases.
    - text_search_result: reference records for the image, or None to run the
      top-5 text-database search here.

    Returns:
    - the existing text-database record, or the newly created record, or None
      when a step fails or no keyword could be extracted.
    """
    try:
        caption_response = self.captioner.caption(image_path)
    except Exception:
        # Best effort: report and give up. ``Exception`` (not a bare except)
        # so KeyboardInterrupt / SystemExit still propagate.
        print("Error occurred while captioning the image ", image_path)
        return None

    if text_search_result is None:
        # complete text search
        text_search_result = self.textdb.top_k_search(image_feature, 'clip_feature', top_k=5)

    # Order-preserving de-duplication of the reference words.
    keywords = list(dict.fromkeys(res['translated_word'] for res in text_search_result))

    try:
        json_response = get_major_object(caption_response, keywords)
    except Exception:
        print("Error occurred while getting major object from caption ", caption_response)
        return None

    in_base_data, alt_data = verify_keyword_in_base(json_response, self.textdb)

    if in_base_data is not None:
        # A new picture of a known item: register the image only, no new
        # text entry is needed.
        image_data = {
            'image_name': image_path,
            'keyword': in_base_data['keyword'],
            'translated_word': in_base_data['translated_word'],
        }
        self.imgdb.add_image(image_data, True, image_feature)
        return in_base_data

    if alt_data is None:
        return None

    try:
        generated = generate_cultivation_with_rag(alt_data, text_search_result)
    except Exception:
        print("Error occurred while generating cultivation data")
        return None

    new_data = {
        "keyword": alt_data['keyword'],
        "name_in_cultivation": generated['new_name'],
        "description_in_cultivation": generated['final_enhanced_description'],
        "translated_word": alt_data['translated_word'],
        "description": alt_data['description'],
    }
    self.textdb.add_data(new_data)
    print("Added new data to textdb: ", new_data["name_in_cultivation"])

    image_data = {
        'image_name': image_path,
        'keyword': new_data['keyword'],
        'translated_word': new_data['translated_word'],
    }
    self.imgdb.add_image(image_data, True, image_feature)
    print("Added new image to imgdb: ", image_data["keyword"])

    return new_data
|
161 |
+
|
162 |
+
|
163 |
+
|
164 |
+
if __name__ == "__main__":
    # Route outbound HTTP(S) traffic through the local proxy.
    for proxy_var in ('HTTP_PROXY', 'HTTPS_PROXY'):
        os.environ[proxy_var] = 'http://localhost:8234'

    game_master = GameMaster()

    target_folder = "temp_images"
    for index, image_path in enumerate(glob(os.path.join(target_folder, "*.jpg"))):
        print("index:" , index )

        search_result, backup_results, image_feature = game_master.search_with_path(image_path)
        if search_result:
            print(search_result)

        # NOTE(review): read as a loop-level break (only the first image is
        # probed); confirm against the original indentation.
        break

    # Full pipeline on one known image.
    test_image_path = "temp_images/向日葵.jpg"
    search_result, backup_results, image_feature = game_master.search_with_path(test_image_path)
    cultivation_data = game_master.generate_cultivation_data(
        test_image_path, image_feature, backup_results
    )
    print(cultivation_data)
|
src/ImageBase.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import os
|
3 |
+
from tqdm import tqdm
|
4 |
+
import numpy as np
|
5 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
6 |
+
|
7 |
+
class Imagebase:
    """Parquet-backed table of images searchable by CLIP feature similarity.

    Rows added through ``add_image`` carry ``image_name``, ``keyword``,
    ``translated_word`` and ``clip_feature``.
    """

    def __init__(self, parquet_path=None):
        """Load the table from ``parquet_path`` (or the default) if it exists."""
        self.default_parquet_path = 'datas/imagebase.parquet'
        self.parquet_path = parquet_path or self.default_parquet_path
        self.datas = None

        if os.path.exists(self.parquet_path):
            self.load_from_parquet(self.parquet_path)

        # Created lazily by init_clip_extractor().
        self.clip_extractor = None

    def random_sample(self, num_samples=12):
        """Return up to ``num_samples`` random rows as a list of dicts."""
        if self.datas is None:
            return []
        # DataFrame.sample raises when asked for more rows than exist.
        count = min(num_samples, len(self.datas))
        return self.datas.sample(count).to_dict(orient='records')

    def load_from_parquet(self, parquet_path):
        """Replace the in-memory table with the content of ``parquet_path``."""
        self.datas = pd.read_parquet(parquet_path)

    def save_to_parquet(self, parquet_path=None):
        """Write the table to ``parquet_path``, defaulting to the path this
        instance was configured with (no-op when no data is loaded)."""
        parquet_path = parquet_path or self.parquet_path
        if self.datas is not None:
            self.datas.to_parquet(parquet_path)

    def init_clip_extractor(self):
        """Create the CLIP extractor on first use."""
        if self.clip_extractor is None:
            try:
                from CLIPExtractor import CLIPExtractor
            except ImportError:
                from src.CLIPExtractor import CLIPExtractor

            cache_dir = "D:\\aistudio\\LubaoGithub\\models"
            self.clip_extractor = CLIPExtractor(model_name="openai/clip-vit-large-patch14", cache_dir=cache_dir)

    def top_k_search(self, query_feature, top_k=15):
        """Return the ``top_k`` rows most similar to ``query_feature``.

        Each result is a row dict without ``clip_feature`` plus a
        ``similarity`` score, ordered from most to least similar.

        Raises:
        - ValueError: if the table has no ``clip_feature`` column.
        """
        if self.datas is None:
            return []
        if 'clip_feature' not in self.datas.columns:
            raise ValueError("clip_feature column not found in the data.")

        # Keep rows and features aligned: select the rows that have a
        # feature instead of dropping NaNs from the feature column only.
        usable = self.datas[self.datas['clip_feature'].notna()]
        if usable.empty:
            return []

        query_feature = np.array(query_feature).reshape(1, -1)
        attribute_features = np.stack(usable['clip_feature'].values)

        similarities = cosine_similarity(query_feature, attribute_features)[0]
        top_k_indices = np.argsort(similarities)[-top_k:][::-1]

        top_k_results = usable.iloc[top_k_indices].copy()
        top_k_results['similarity'] = similarities[top_k_indices]

        # Drop the 'clip_feature' column
        top_k_results = top_k_results.drop(columns=['clip_feature'])

        return top_k_results.to_dict(orient='records')

    def search_with_image_name(self, image_name):
        """Search with the feature extracted from the image file ``image_name``."""
        self.init_clip_extractor()
        img_feature = self.clip_extractor.extract_image_from_file(image_name)
        return self.top_k_search(img_feature)

    def search_with_image(self, image, if_opencv=False):
        """Search with the feature extracted from an in-memory image."""
        self.init_clip_extractor()
        img_feature = self.clip_extractor.extract_image(image, if_opencv=if_opencv)
        return self.top_k_search(img_feature)

    def add_image(self, data, if_save = True, image_feature = None):
        """Append one image row; the caller's ``data`` dict is not modified.

        Parameters:
        - data: mapping with ``image_name``, ``keyword``, ``translated_word``.
        - if_save: write the table to parquet after adding.
        - image_feature: precomputed feature; extracted from the image file
          when None.

        Raises:
        - ValueError: if a required field is missing.
        """
        required_fields = ['image_name', 'keyword', 'translated_word']
        if not all(field in data for field in required_fields):
            raise ValueError(f"Data must contain the following fields: {required_fields}")

        if image_feature is None:
            self.init_clip_extractor()
            image_feature = self.clip_extractor.extract_image_from_file(data['image_name'])

        row = dict(data)
        row['clip_feature'] = image_feature
        new_rows = pd.DataFrame([row])

        if self.datas is None:
            self.datas = new_rows
        else:
            self.datas = pd.concat([self.datas, new_rows], ignore_index=True)

        if if_save:
            self.save_to_parquet()

    def add_images(self, datas):
        """Append several rows and save once at the end."""
        for data in datas:
            self.add_image(data, if_save=False)
        self.save_to_parquet()
|
104 |
+
|
105 |
+
import os
|
106 |
+
from glob import glob
|
107 |
+
|
108 |
+
def scan_and_update_imagebase(db, target_folder="temp_images", similarity_threshold=0.9):
    """Add every ``*.jpg`` in ``target_folder`` that is not already in ``db``.

    An image counts as a duplicate when its best match in ``db`` has a
    similarity above ``similarity_threshold``. The file name without its
    extension is used as both ``keyword`` and ``translated_word``.

    Parameters:
    - db: object providing ``search_with_image_name(path)`` and
      ``add_image(data)``.
    - target_folder: directory to scan.
    - similarity_threshold: duplicate cut-off.

    Returns:
    - (duplicate_count, added_count)
    """
    duplicate_count = 0
    added_count = 0

    # Sorted so the scan order (and thus which near-duplicate is kept) is
    # deterministic; glob order is not.
    for image_path in sorted(glob(os.path.join(target_folder, "*.jpg"))):
        keyword = os.path.basename(image_path).rsplit('.', 1)[0]

        results = db.search_with_image_name(image_path)
        if results and results[0]['similarity'] > similarity_threshold:
            print(f"Image '{image_path}' is considered a duplicate.")
            duplicate_count += 1
            continue

        db.add_image({
            'image_name': image_path,
            'keyword': keyword,
            'translated_word': keyword,
        })
        print(f"Image '{image_path}' added to the database.")
        added_count += 1

    print(f"Total duplicate images found: {duplicate_count}")
    print(f"Total new images added to the database: {added_count}")
    return duplicate_count, added_count
|
138 |
+
|
139 |
+
if __name__ == '__main__':
    # Target directory
    scan_folder = "temp_images"

    # Scan the folder and update the database
    scan_and_update_imagebase(Imagebase(), scan_folder)
|
147 |
+
|
148 |
+
# Usage example
|
149 |
+
# img_db = Imagebase()
|
150 |
+
|
151 |
+
# new_image_data = {
|
152 |
+
# 'image_name': "datas/老虎.jpg",
|
153 |
+
# 'keyword': 'tiger',
|
154 |
+
# 'translated_word': '老虎'
|
155 |
+
# }
|
156 |
+
|
157 |
+
# img_db.add_image(new_image_data)
|
158 |
+
|
159 |
+
# image_path = "datas/老虎.jpg"
|
160 |
+
# results = img_db.search_with_image_name(image_path)
|
161 |
+
# for result in results[:3]:
|
162 |
+
# print(result)
|
src/ZhipuClient.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from zhipuai import ZhipuAI
|
2 |
+
import os
|
3 |
+
|
4 |
+
class ZhipuClient:
    """Small wrapper around the ZhipuAI chat-completions client."""

    def __init__(self, api_key_file_path = None):
        """Create the client from an API key stored in a text file.

        Parameters:
        - api_key_file_path: file holding the key; when None, the first
          existing file among the default candidates is used.

        Raises:
        - ValueError: if no path is given and no candidate file exists.
        """
        if api_key_file_path is None:
            cands = ['./datas/zhipu_key.txt', '../datas/zhipu_key.txt']
            api_key_file_path = next((cand for cand in cands if os.path.exists(cand)), None)
            if api_key_file_path is None:
                raise ValueError("No valid api key file found.")

        self.api_key = self._load_access_token(api_key_file_path)
        self.client = ZhipuAI(api_key=self.api_key)

    def _load_access_token(self, file_path):
        """Return the content of ``file_path`` with surrounding whitespace removed."""
        # utf-8-sig also drops a leading BOM, which would otherwise end up
        # inside the key.
        with open(file_path, 'r', encoding='utf-8-sig') as file:
            return file.read().strip()

    def prompt2response(self, prompt, model="glm-4"):
        """Send ``prompt`` as a single user message and return the reply text.

        Parameters:
        - prompt: the user message.
        - model: name of the model to call.
        """
        response = self.client.chat.completions.create(
            model=model,
            messages=[
                {"role": "user", "content": prompt}
            ],
        )
        return response.choices[0].message.content
|
32 |
+
|
33 |
+
# Usage:
|
34 |
+
# zhipu_client = ZhipuClient('../datas/zhipu_key.txt')
|
35 |
+
# response = zhipu_client.prompt2response('Your prompt here')
|
src/__pycache__/Captioner.cpython-310.pyc
ADDED
Binary file (4.22 kB). View file
|
|
src/__pycache__/Database.cpython-310.pyc
ADDED
Binary file (7 kB). View file
|
|
src/__pycache__/GameMaster.cpython-310.pyc
ADDED
Binary file (4.83 kB). View file
|
|
src/__pycache__/ImageBase.cpython-310.pyc
ADDED
Binary file (4.53 kB). View file
|
|
src/__pycache__/ZhipuClient.cpython-310.pyc
ADDED
Binary file (1.34 kB). View file
|
|
src/__pycache__/generate_cultivation.cpython-310.pyc
ADDED
Binary file (5 kB). View file
|
|
src/__pycache__/get_major_object.cpython-310.pyc
ADDED
Binary file (4.22 kB). View file
|
|
src/__pycache__/text_embedding.cpython-310.pyc
ADDED
Binary file (7.49 kB). View file
|
|
src/generate_cultivation.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
def data2reference( top_k_items, output_n = 3 ):
    """Render up to ``output_n`` distinct reference items as a prompt section.

    Items are de-duplicated by their ``keyword``; each kept item becomes one
    JSON line with ``name_in_life``, ``name_in_cultivation`` and
    ``description_in_cultivation``.

    Parameters:
    - top_k_items: iterable of dicts with ``keyword``,
      ``name_in_cultivation`` and ``description_in_cultivation``.
    - output_n: maximum number of items to emit; 0 or less emits none.

    Returns:
    - the text ``#Reference:`` followed by the JSON lines, separated by blank
      lines.
    """
    outputted_items = set()
    entries = []

    for item in top_k_items:
        # Checked before emitting so that output_n <= 0 yields no entries.
        if len(outputted_items) >= output_n:
            break

        item_in_life = item["keyword"]
        if item_in_life in outputted_items:
            continue
        outputted_items.add(item_in_life)

        entries.append(json.dumps({
            "name_in_life": item_in_life,
            "name_in_cultivation": item["name_in_cultivation"],
            "description_in_cultivation": item["description_in_cultivation"],
        }, ensure_ascii=False))

    return ("#Reference:\n" + "\n\n".join(entries)).strip()
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
def data2prompt(query_item , top_k_items):
    """Assemble the generation prompt: references, task, input, CoT fields.

    Parameters:
    - query_item: dict describing the real-world item; ``keyword`` and
      ``description`` are rendered when present.
    - top_k_items: reference records passed to ``data2reference``.

    Returns:
    - the full prompt string.
    """
    sections = [
        data2reference(top_k_items, 3),
        "\n请参考Reference中的物品描述,将Input中的输入物品,联系改写成修仙世界中的对应物品\n",
        "# Input:\n",
    ]

    # NOTE(review): the fallback dump is read as pairing with the "keyword"
    # check; confirm against the original indentation.
    if "keyword" not in query_item:
        # No structured name available: dump the whole item.
        sections.append(json.dumps(query_item, ensure_ascii=False) + "\n")
    else:
        sections.append(f"input_name:{query_item['keyword']}\n")
        if "description" in query_item:
            sections.append(f"description_in_life:{query_item['description']}\n")

    sections.append(
"""Let's think it step by step,以json形式输出逐个字段。包含以下字段
- name_in_life: 进一步明确要生成描述的物品名称
- name_in_cultivation_1: 尝试编写物品在修仙界对应的名称
- description_in_cultivation_1: 尝试编写物品在修仙界对应的描述
- echo_1: "我将分析description_in_cultivation_1与Reference中的差异,分析description_in_cultivation_1是否已经足够生动"
- critique: 相比于Reference中的描述,分析description_in_cultivation_1在哪些方面有所欠缺
- echo_2: "根据input_name和description_in_cultivation_1,我将分析从物体的哪些属性,可以进一步加强、夸张和修改描述"
- analysis: 分析从物体的哪些属性,可以进一步加强、夸张和修改描述
- echo_3: "我将尝试3次,从不同角度加强description_in_cultivation_1的描述"
- candidate_descriptions: 从不同角度,输出3次不同的加强后的描述
- analysis_candidates: 分析各个candidates有什么优点
- echo_4: "根据analysis_candidates,我将merge出一个最终的描述"
- final_enhanced_description: 通过各个candidates的优点, merge出一个最终的描述
- echo_5: "我将分析根据final_description,是否简易将物品名称替换为新的名词"
- name_fit_analysis: 分析item_name是否还匹配final_description的描述,是否需要给input_name起一个更响亮的名字
- new_name: 如果需要,给input_name起一个更响亮的名字, 如果不需要,则仍然输出name_in_cultivation_1
"""
    )

    return "".join(sections)
|
67 |
+
|
68 |
+
try:
|
69 |
+
from src.ZhipuClient import ZhipuClient
|
70 |
+
except:
|
71 |
+
from ZhipuClient import ZhipuClient
|
72 |
+
|
73 |
+
zhipu_client = None
|
74 |
+
|
75 |
+
|
76 |
+
import json
|
77 |
+
|
78 |
+
def markdown_to_json(markdown_str):
    """Parse JSON that may be wrapped in a Markdown code fence.

    Surrounding whitespace, an opening fence (with or without a ``json``
    tag) and a closing fence are removed when present; a missing closing
    fence is tolerated.

    Raises:
    - json.JSONDecodeError: if the remaining text is not valid JSON.
    """
    text = markdown_str.strip()

    if text.startswith("```"):
        text = text[3:]
        if text[:4].lower() == "json":
            text = text[4:]
        # Only cut the closing fence if it is really there; fixed-offset
        # slicing would eat JSON characters otherwise.
        if text.endswith("```"):
            text = text[:-3]

    return json.loads(text.strip())
|
89 |
+
|
90 |
+
import re
|
91 |
+
|
92 |
+
def forced_extract(input_str, keywords):
    """Pull ``"key": "value"`` pairs out of text that failed JSON parsing.

    Parameters:
    - input_str: the raw text to scan.
    - keywords: keys to look for.

    Returns:
    - dict with every keyword; the value is the raw text between the quotes
      of the first match (escape sequences are left as written), or "" when
      the key is not found.
    """
    result = {key: "" for key in keywords}

    for key in keywords:
        # Raw pattern with the key escaped; the value part accepts escaped
        # characters (e.g. \") so it is not cut at the first inner quote.
        pattern = rf'"{re.escape(key)}"\s*:\s*"((?:\\.|[^"\\])*)"'
        match = re.search(pattern, input_str, re.DOTALL)
        if match:
            result[key] = match.group(1)

    return result
|
103 |
+
|
104 |
+
def generate_cultivation_with_rag( query_item, search_result ):
    """Ask the LLM to rewrite ``query_item`` as a cultivation-world item.

    Parameters:
    - query_item: the item to rewrite (see ``data2prompt``).
    - search_result: reference records used as examples in the prompt.

    Returns:
    - dict parsed from the model reply, guaranteed to contain ``new_name``
      and ``final_enhanced_description`` (possibly empty strings).
    """
    global zhipu_client
    if zhipu_client is None:
        # Created on first use and shared by later calls.
        zhipu_client = ZhipuClient()

    prompt = data2prompt(query_item, search_result)
    response = zhipu_client.prompt2response(prompt)

    try:
        json_response = markdown_to_json(response)
    except Exception:
        json_response = None

    # Fall back to regex extraction when the reply is not parseable JSON, or
    # parses to something other than an object.
    if not isinstance(json_response, dict):
        keyword_list = ["name_in_life", "name_in_cultivation_1","description_in_cultivation_1", "final_enhanced_description", "new_name"]
        json_response = forced_extract(response, keyword_list)

    # Fill the two fields callers read from their first-draft counterparts.
    if not json_response.get("new_name"):
        json_response["new_name"] = json_response.get("name_in_cultivation_1", "")

    if not json_response.get("final_enhanced_description"):
        json_response["final_enhanced_description"] = (
            json_response.get("description_in_cultivation_1") or json_response["new_name"]
        )

    return json_response
|
131 |
+
|
132 |
+
if __name__ == '__main__':
    # Manual end-to-end check: caption one image, extract its major object,
    # and generate a cultivation entry when the object is not in the base.
    try:
        from src.Database import Database
    except ImportError:
        from Database import Database

    db = Database()

    try:
        from src.Captioner import Captioner
    except ImportError:
        from Captioner import Captioner

    import os
    os.environ['HTTP_PROXY'] = 'http://localhost:8234'
    os.environ['HTTPS_PROXY'] = 'http://localhost:8234'

    captioner = Captioner()

    test_image = "temp_images/3or47vg0.jpg"
    caption_response = captioner.caption(test_image)

    search_result = db.search_with_image_name( test_image )

    # Order-preserving de-duplication of the reference words.
    keywords = list(dict.fromkeys(res['translated_word'] for res in search_result))

    from get_major_object import get_major_object, verify_keyword_in_base

    json_response = get_major_object(caption_response , keywords)

    print(json_response)

    print()

    in_base_data , alt_data = verify_keyword_in_base(json_response , db)

    if alt_data is not None:
        result = generate_cultivation_with_rag(alt_data , search_result)
        print(result)
|
182 |
+
|
183 |
+
|
src/get_comments_from_level.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def get_comments_from_level(inbase_similarity_level, inlibrary_similarity_level):
    """Return the comment template for a pair of similarity levels.

    The returned text still contains the literal ``{name}`` placeholder for
    the item name. Level pairs outside the table get a generic comment.
    """
    name = "{name}"  # Placeholder for the item name

    # rows[inbase_level][inlibrary_level]
    rows = (
        (
            f"道友,在下才疏学浅不太认识这个东西,我感觉这个东西有点像古籍上所说的{name},还要找经验更丰富的长老来确定下",
            f"这个东西似乎有些眼熟,但天机阁尚未收录,古籍上的描述也有些模糊,可能需要长老们进一步鉴定,我猜是{name}",
            f"这东西颇为罕见,天机阁未曾有过记录,但古籍中的描述与{name}颇为相似,我将呈给长老们以作鉴定",
            f"想必这一定是{name}吧,虽然天机阁还没有收录过这个东西,倒是和修仙古籍上的记载非常相像。我赶紧拿给长老再鉴定下",
        ),
        (
            f"天机阁的记录中似乎没有这个东西,但我依稀记得古籍中提到过{name},还需长老进一步确认",
            f"这个物品有些特别,天机阁的记录不多,古籍中的描述也只是一笔带过,可能是{name},还需长老鉴定",
            f"此物颇为罕见,天机阁记录较少,但古籍中的描述与{name}有一定相似之处,长老们或能给出答案",
            f"虽然古籍中对{name}的描述详细,但天机阁中却鲜有记录,或许这是一件稀世之宝",
        ),
        (
            f"天机阁中对此物知之甚少,但古籍中曾提到{name},这件物品或许不简单,需长老们鉴定",
            f"天机阁中对此物的记录不多,古籍中对{name}的描述也有限,但似乎是一件非凡之物",
            f"这件物品在古籍中有所记载,天机阁也有少量收录,看来是{name}无疑,但还需长老确认",
            f"虽然在古籍中有记载,天机阁过往有一点点收录,但也算稀世珍宝,{name}确实非凡",
        ),
        (
            f"天机阁中没有记录,但古籍中对{name}的描述颇为详细,这件物品可能是个谜",
            f"天机阁中记录较少,但古籍中对{name}的描述详尽,这件物品或许有着不同寻常的来历",
            f"古籍中记载{name}颇多,天机阁中也有所收录,看来这东西并不罕见",
            f"{name}这种东西很常见啊,天机阁的库房里面都有不少呢",
        ),
    )
    comments_mapping = {
        (inbase, inlibrary): text
        for inbase, row in enumerate(rows)
        for inlibrary, text in enumerate(row)
    }

    level_pair = (inbase_similarity_level, inlibrary_similarity_level)
    comment = comments_mapping.get(level_pair, "道友,我会给出初步的鉴定")
    return comment + "。"
|
26 |
+
|
27 |
+
# Example usage:
|
28 |
+
# comments = get_comments_from_level(2, 3)
|
29 |
+
# print(comments)
|
src/get_major_object.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def data2prompt(caption_response, ref_words):
    """Build the prompt that asks the LLM for the caption's major object.

    Parameters:
    - caption_response: caption text of the image.
    - ref_words: candidate words; the first five are offered as possible
      synonyms.

    Returns:
    - task line + caption section + step-by-step JSON field list.
    """
    ref_word_str = ",".join(ref_words[:5])

    header = "Based on the following Caption Response, you will output a description of the Major Object's name."
    caption_section = "# Caption Response:\n" + caption_response + "\n"
    field_list = f"""
Let's think it step by step. Output each field in JSON format. Include the following fields:
- major_object: From the caption response, identify the major_object. If not present, extract it again from the detailed_description or caption_response.
- better_major_object: Reread the description in the caption response to see if there's a more suitable word for the major object. If not, still output major_object.
- echo_1: "I will generate a simple description in about 200 words in English for the better_major_object, introducing what the input object is."
- description: Generate a WIKI description for the better_major_object (explain what is the better_major_object).
- major_object_chinese: Translate the better_major_object into Chinese.
- echo_2: "I will check whether there is synonym of the major_object_chinese in the '{ref_word_str}'."
- synonym: If present, output the synonym directly; otherwise, output "NOT_INCLUDED."
- recheck: Based on the content of the Caption Response, determine whether the synonym is accurate. If accurate, output "ACCURATE"; otherwise, output "NOT_ACCURATE."
"""
    return "".join((header, caption_section, field_list))
|
22 |
+
|
23 |
+
|
24 |
+
# def data2prompt(caption_response , ref_words ):
|
25 |
+
|
26 |
+
|
27 |
+
# ref_word_str = ",".join(ref_words[:5])
|
28 |
+
|
29 |
+
# ref_str = "# Reference Word:\n"+ref_word_str+"\n\n"
|
30 |
+
|
31 |
+
# task_prompt = "你将根据下面的Caption Response,输出Major Object的名称描述"
|
32 |
+
|
33 |
+
# input_str = "# Caption Response:\n"+caption_response+"\n"
|
34 |
+
|
35 |
+
# CoT_prompt = \
|
36 |
+
# """
|
37 |
+
# Let's think it step by step,以json形式输出逐个字段。包含以下字段
|
38 |
+
# - major_object: 从caption response中,确认major_object,如果没有,则从detailed_description或者caption_response中重新抽取
|
39 |
+
# - better_major_object: 重新阅读caption response中的描述,看看是否有更合适的major object的词语,如果没有则仍然输出major_object
|
40 |
+
# - echo_1: "I will generate a simple description in about 200 words in English for the input word, introducing what the input object is"
|
41 |
+
# - description: generate the description for the input object
|
42 |
+
# - major_object_chinese: 将major_object翻译为中文
|
43 |
+
# - echo_2: "我将判断reference word中,是否存在major_object的同义词"
|
44 |
+
# - 同义词: 如果存在,则直接输出同义词,否则输出"NOT_INCLUDED"
|
45 |
+
# - recheck: 结合Caption Response的内容,判断同义词是否准确,如果准确,则输出"ACCURATE",否则输出"NOT_ACCURATE"
|
46 |
+
# """
|
47 |
+
# return ref_str+task_prompt+input_str+CoT_prompt
|
48 |
+
|
49 |
+
try:
|
50 |
+
from src.ZhipuClient import ZhipuClient
|
51 |
+
except:
|
52 |
+
from ZhipuClient import ZhipuClient
|
53 |
+
|
54 |
+
zhipu_client = None
|
55 |
+
|
56 |
+
import json
|
57 |
+
|
58 |
+
def markdown_to_json(markdown_str):
    """Parse JSON that may be wrapped in a Markdown code fence.

    Surrounding whitespace, an opening fence (with or without a ``json``
    tag) and a closing fence are removed when present; a missing closing
    fence is tolerated.

    Raises:
    - json.JSONDecodeError: if the remaining text is not valid JSON.
    """
    text = markdown_str.strip()

    if text.startswith("```"):
        text = text[3:]
        if text[:4].lower() == "json":
            text = text[4:]
        # Only cut the closing fence if it is really there; fixed-offset
        # slicing would eat JSON characters otherwise.
        if text.endswith("```"):
            text = text[:-3]

    return json.loads(text.strip())
|
69 |
+
|
70 |
+
import re
|
71 |
+
|
72 |
+
def forced_extract(input_str, keywords):
    """Pull ``"key": "value"`` pairs out of text that failed JSON parsing.

    Parameters:
    - input_str: the raw text to scan.
    - keywords: keys to look for.

    Returns:
    - dict with every keyword; the value is the raw text between the quotes
      of the first match (escape sequences are left as written), or "" when
      the key is not found.
    """
    result = {key: "" for key in keywords}

    for key in keywords:
        # Raw pattern with the key escaped; the value part accepts escaped
        # characters (e.g. \") so it is not cut at the first inner quote.
        pattern = rf'"{re.escape(key)}"\s*:\s*"((?:\\.|[^"\\])*)"'
        match = re.search(pattern, input_str, re.DOTALL)
        if match:
            result[key] = match.group(1)

    return result
|
83 |
+
|
84 |
+
def get_major_object(caption_response, ref_words):
    """Ask the LLM which object a caption is mainly about.

    Parameters:
    - caption_response: caption text of the image.
    - ref_words: candidate words offered to the model as possible synonyms.

    Returns:
    - dict parsed from the model reply; when the reply is not a JSON object
      the known fields are extracted by regex instead (missing ones are "").
    """
    global zhipu_client
    if zhipu_client is None:
        # Created on first use and shared by later calls.
        zhipu_client = ZhipuClient()

    prompt = data2prompt(caption_response , ref_words)
    response = zhipu_client.prompt2response(prompt)

    try:
        json_response = markdown_to_json(response)
    except Exception:
        json_response = None

    # Callers index the result by field name, so anything but a dict goes
    # through the regex fallback.
    if not isinstance(json_response, dict):
        keyword_list = ["major_object", "better_major_object", "description", "major_object_chinese", "synonym", "recheck"]
        json_response = forced_extract(response, keyword_list)

    return json_response
|
98 |
+
|
99 |
+
def verify_keyword_in_base( json_response , database ):
    """Check whether the extracted major object already exists in ``database``.

    Candidate keywords are, in order: ``better_major_object``,
    ``major_object`` and, when ``recheck`` says it is accurate, ``synonym``.
    Empty or non-string values are ignored.

    Parameters:
    - json_response: dict produced by ``get_major_object``.
    - database: object providing ``search_by_en_keyword(keyword)``, which
      returns a record or None.

    Returns:
    - (record, None) for the first candidate found in the database;
    - (None, data) when none is found, where ``data`` holds ``keyword``
      (Chinese name), ``translated_word`` (English name) and ``description``
      for a new entry;
    - (None, None) when there is no usable candidate.
    """
    def _text(field):
        # Field value as stripped text; "" for missing / non-string values.
        value = json_response.get(field)
        return value.strip() if isinstance(value, str) else ""

    candidates = [_text("better_major_object").lower(), _text("major_object").lower()]

    # The prompt spells the negative sentinels with a trailing period
    # ("NOT_INCLUDED."), so compare with it removed.
    if _text("recheck").rstrip(".") == "ACCURATE":
        synonym = _text("synonym")
        if synonym.rstrip(".") != "NOT_INCLUDED":
            candidates.append(synonym.lower())

    # Drop empties and duplicates, keeping priority order.
    keyword2verify = list(dict.fromkeys(word for word in candidates if word))

    for keyword in keyword2verify:
        res = database.search_by_en_keyword(keyword)
        if res is not None:
            return res, None

    if not keyword2verify:
        return None, None

    # Nothing in the base: describe a new entry instead.
    translated_word = keyword2verify[0]
    data = {
        "keyword": _text("major_object_chinese") or translated_word,
        "translated_word": translated_word,
        "description": _text("description") or translated_word,
    }

    return None, data
|
142 |
+
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
if __name__ == '__main__':
    # Manual check: caption one image, extract its major object and report
    # whether it is already in the database.
    try:
        from src.Database import Database
    except ImportError:
        from Database import Database

    db = Database()

    try:
        from src.Captioner import Captioner
    except ImportError:
        from Captioner import Captioner

    import os
    os.environ['HTTP_PROXY'] = 'http://localhost:8234'
    os.environ['HTTPS_PROXY'] = 'http://localhost:8234'

    captioner = Captioner()

    test_image = "temp_images/3or47vg0.jpg"
    caption_response = captioner.caption(test_image)

    search_result = db.search_with_image_name( test_image )

    # Order-preserving de-duplication of the reference words.
    keywords = list(dict.fromkeys(res['translated_word'] for res in search_result))

    json_response = get_major_object(caption_response , keywords)

    print(json_response)

    print()

    in_base_data , alt_data = verify_keyword_in_base(json_response , db)

    if in_base_data is not None:
        print(in_base_data)

    if alt_data is not None:
        print(alt_data)
|
193 |
+
|
194 |
+
# {'keyword': '埃菲尔铁塔', 'translated_word': 'eiffel tower', 'description': "The Eiffel Tower is an iconic symbol of Paris and one of the most recognizable stru
|
195 |
+
# ower', 'description': "The Eiffel Tower is an iconic symbol of Paris and one of the most recognizable structures in the world. Designed and constructed by the engineer Gustave Eiffel and his company for the 1889 Exposition Universelle (World's Fair) to celebrate the 100th anniversary of the French Revolution, the tower was initially criticized by some of France's leading artists and intellectuals. However, it quickly became a beloved landmark and a symbol of French pride. Standing 324 meters tall, the tower is made of wrought iron and consists of thousands of metal parts, including over 18,000 individual iron rivets. It is renowned for its architectural and engineering design, and it is visited by millions of people each year, making it one of the most visited paid monuments in the world."}
|
src/text_embedding.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModel
|
3 |
+
import os
|
4 |
+
|
5 |
+
class TextExtractor:
    """Sentence-embedding extractor built on a HuggingFace tokenizer/model pair."""

    def __init__(self, model_name, proxy=None):
        """
        Load the tokenizer and model for ``model_name`` and put the model in eval mode.

        Parameters:
        - model_name (str): Name handed to ``AutoTokenizer.from_pretrained`` and
          ``AutoModel.from_pretrained``.
        - proxy (str, optional): Proxy address exported as HTTP_PROXY / HTTPS_PROXY.
          ``None`` selects 'http://localhost:8234'; an empty string leaves the
          environment untouched.
        """
        if proxy is None:
            proxy = 'http://localhost:8234'

        if proxy:
            # These assignments affect the whole process, not just this instance.
            os.environ['HTTP_PROXY'] = proxy
            os.environ['HTTPS_PROXY'] = proxy

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModel.from_pretrained(model_name)
        except Exception:
            # Was a bare ``except:``; narrowing it lets KeyboardInterrupt and
            # SystemExit propagate instead of silently triggering the fallback.
            print('try switch on local_files_only')
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
            self.model = AutoModel.from_pretrained(model_name, local_files_only=True)

        self.model.eval()

    def extract(self, sentences):
        """
        Extract L2-normalized sentence embeddings for the provided text.

        Parameters:
        - sentences (str or list of str): Text forwarded unchanged to the tokenizer.
          Other code in this file passes a single string, which yields one row.

        Returns:
        - torch.Tensor: One unit-length embedding per input row.
        """
        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

        with torch.no_grad():
            model_output = self.model(**encoded_input)
            # First output of the model, position 0 along the second axis
            # (presumably the [CLS] token of the last hidden state -- confirm
            # against the model card).
            sentence_embeddings = model_output[0][:, 0]

        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings
|
48 |
+
|
49 |
+
import pandas as pd
|
50 |
+
def get_qas(excel_file=None):
    """
    Load question/summary pairs from an Excel sheet.

    Rows whose "question" or "summary" cell is missing are dropped. Only the
    "id", "question" and "summary" columns are read, so the sheet does not
    need any other column.

    Parameters:
    - excel_file (str, optional): Path to the Excel file. Defaults to
      'data/output_fixid.xlsx'.

    Returns:
    - list of dict: One ``{"id": ..., "value": {"texts": [question, summary]}}``
      entry per remaining row, in sheet order.
    """
    default_excel_file = 'data/output_fixid.xlsx'
    if excel_file is None:
        excel_file = default_excel_file

    # Read the Excel file
    df = pd.read_excel(excel_file)

    df = df[df["question"].notna()]
    df = df[df["summary"].notna()]

    datas = []

    # Walk every remaining row of the DataFrame
    for _, row in df.iterrows():
        datas.append({
            "id": row['id'],
            "value": {"texts": [row['question'], row['summary']]},
        })

    return datas
|
85 |
+
|
86 |
+
from tqdm import tqdm
|
87 |
+
|
88 |
+
def extract_embedding(datas, text_extractor):
    """
    Attach an embedding to every item of ``datas`` (in place).

    Parameters:
    - datas (list of dict): Items shaped like ``{"value": {"texts": [...]}}``.
    - text_extractor: Object whose ``extract(text)`` returns something with a
      ``tolist()`` method (a tensor in this file).

    Returns:
    - list of dict: The same ``datas`` list; each item gains
      ``item["value"]["embedding"]``.
    """
    for item in tqdm(datas):
        payload = item["value"]
        # All texts of an item are joined into one string before embedding.
        joined = "。".join(payload["texts"])
        # Tensor -> nested Python lists so the result can be serialized later.
        payload["embedding"] = text_extractor.extract(joined).tolist()
    return datas
|
105 |
+
|
106 |
+
def save_parquet(datas, file_path):
    """
    Write ids, joined texts and embeddings to a Parquet file.

    Parameters:
    - datas (list of dict): Items carrying "id" plus "value" -> "texts" / "embedding".
    - file_path (str): Destination path of the Parquet file.
    """
    # One flat record per item; the texts are joined the same way as when embedding.
    records = [
        {
            "id": item["id"],
            "text": "。".join(item["value"]["texts"]),
            "embedding": item["value"]["embedding"],
        }
        for item in datas
    ]

    pd.DataFrame(records).to_parquet(file_path, index=False)
|
132 |
+
|
133 |
+
import pandas as pd
|
134 |
+
import os
|
135 |
+
|
136 |
+
def get_id2embedding(regen=False, parquet_file='datas/qa_with_embedding.parquet'):
    """
    Build a mapping from item id to its embedding vector.

    The embeddings are read from ``parquet_file``. The file is (re)created
    first when ``regen`` is true or the file does not exist.

    Parameters:
    - regen (bool): Force regeneration of the Parquet file.
    - parquet_file (str): Path to the Parquet file.

    Returns:
    - dict: id -> first row of the stored embedding.
    """
    # NOTE(review): the default path uses 'datas/' while other paths in this
    # file use 'data/' -- confirm which directory is intended.
    must_build = regen or not os.path.exists(parquet_file)
    if must_build:
        print("Regenerating embeddings...")
        extractor = TextExtractor('BAAI/bge-small-zh-v1.5')

        qa_items = get_qas()
        print("Extracting embeddings for", len(qa_items), "data items")

        save_parquet(extract_embedding(qa_items, extractor), parquet_file)

    frame = pd.read_parquet(parquet_file)

    # Each stored embedding is a nested sequence; only its first row is kept.
    return {row['id']: row['embedding'][0] for _, row in frame.iterrows()}
|
168 |
+
|
169 |
+
import torch
|
170 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
171 |
+
import heapq
|
172 |
+
|
173 |
+
def __get_id2top30map(id2embedding, top_k=30):
    """
    Map every id to the ids of its nearest neighbors by cosine similarity.

    Parameters:
    - id2embedding (dict): Maps ids to equally sized sequences of floats.
    - top_k (int): Maximum number of neighbors kept per id (default 30).

    Returns:
    - dict: id -> list of up to ``top_k`` neighbor ids, most similar first.
      An id never appears in its own list. Fewer than ``top_k`` ids are
      returned when not enough other items exist. Empty input gives ``{}``.
    """
    ids = list(id2embedding.keys())
    if not ids:
        return {}

    embeddings = torch.tensor([id2embedding[key] for key in ids], dtype=torch.float64)

    # Cosine similarity matrix: dot products of the L2-normalized rows.
    unit_rows = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    cos_sim_matrix = unit_rows @ unit_rows.T

    id2top30map = {}
    for i, key in enumerate(ids):
        # Plain list of floats: cheap to index inside the ranking key below.
        sim_scores = cos_sim_matrix[i].tolist()

        # Exclude the current index up front. Removing it afterwards raised
        # ValueError whenever it was not among the selected indices, e.g.
        # with many duplicate embeddings.
        candidates = (j for j in range(len(ids)) if j != i)
        top_indices = heapq.nlargest(top_k, candidates, key=sim_scores.__getitem__)

        # Map the indices back to ids
        id2top30map[key] = [ids[j] for j in top_indices]

    return id2top30map
|
204 |
+
|
205 |
+
import pickle
|
206 |
+
|
207 |
+
def get_id2top30map(id2embedding=None):
    """
    Return the id -> nearest-neighbor-ids mapping.

    When ``id2embedding`` is given, the mapping is computed from it and nothing
    is cached. Otherwise the pickle at 'data/id2top30map.pkl' is loaded if it
    exists; if not, the mapping is computed from ``get_id2embedding`` and
    written to that pickle.

    Parameters:
    - id2embedding (dict, optional): Maps ids to embeddings.

    Returns:
    - dict: id -> list of neighbor ids.
    """
    cache_path = "data/id2top30map.pkl"

    # Caller supplied embeddings: compute directly, no cache involved.
    if id2embedding is not None:
        return __get_id2top30map(id2embedding)

    # Cache hit.
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as fh:
            return pickle.load(fh)

    # Cache miss: build from the stored embeddings and persist the result.
    print("No embedding found, generating new one...")
    neighbor_map = __get_id2top30map(get_id2embedding(regen=False))
    with open(cache_path, 'wb') as fh:
        pickle.dump(neighbor_map, fh)
    return neighbor_map
|
223 |
+
|
224 |
+
|
225 |
+
|
226 |
+
if __name__ == '__main__':
    # Manual driver; the ``if False`` / ``if True`` switches select which stage runs.
    if False:
        # Stage 1 (disabled): embed the QA sheet and store the result as Parquet.
        model_name = 'BAAI/bge-small-zh-v1.5'
        sentences = ["样例数据-1", "样例数据-2"]

        text_extractor = TextExtractor(model_name)
        embeddings = text_extractor.extract(sentences)
        print("Sentence embeddings:", embeddings)

        datas = get_qas()

        print("extract embedding for ", len(datas), " datas")

        datas = extract_embedding(datas, text_extractor)

        default_parquet_save_name = "data/qa_with_embedding.parquet"

        save_parquet(datas, default_parquet_save_name)

    if True:
        # Stage 2: load embeddings and the neighbor map, print one sample entry.
        id2embedding = get_id2embedding(regen=False)
        print(len(id2embedding[4]))
        id2top30map = get_id2top30map(None)
        print("ID to Top 30 Neighbors dictionary:", id2top30map[4])

    if True:
        # Stage 3: breadth-first walk over the neighbor graph from one start id,
        # then copy the images of every visited id into one folder.
        import shutil
        from collections import deque

        start_id = 332
        visited_ids = [start_id]      # visit order, used for the image list below
        seen_ids = {start_id}         # same content as visited_ids, O(1) membership
        current_queue = deque([start_id])

        expend_num = 5                # max new neighbors taken per expanded node

        for iteration in range(10):
            if not current_queue:
                # Nothing left to expand; the original popped unconditionally.
                break
            current_node = current_queue.popleft()
            top30 = id2top30map[current_node]
            current_expend = []
            for neighbor_id in top30:
                if neighbor_id not in seen_ids:
                    seen_ids.add(neighbor_id)
                    visited_ids.append(neighbor_id)
                    current_queue.append(neighbor_id)
                    current_expend.append(neighbor_id)
                    if len(current_expend) >= expend_num:
                        break
            display_text = f"{current_node} | ->" + ",".join([str(i) for i in current_expend])
            print(display_text)

        from get_qa_and_image import get_qa_and_image
        image_datas = get_qa_and_image()

        id2index = {data['id']: i for i, data in enumerate(image_datas)}

        # Visited ids without an image entry are skipped.
        indexes = [id2index[i] for i in visited_ids if i in id2index]
        image_names = [image_datas[index]['value']['image'] for index in indexes]

        target_copy_folder = "data/asso_collection"
        # Without an existing directory shutil.copy treats the target as a file
        # name, and every image would overwrite the previous one.
        os.makedirs(target_copy_folder, exist_ok=True)

        # copy image into target_copy_folder
        for image_name in image_names:
            shutil.copy(image_name, target_copy_folder)
|