# TencentVDB_graph_search / pages/initial_and_upload.py
# qcloud
# 1
# 5bfdfae
import os
import shutil
import time
import traceback
import uuid
import zipfile

import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image, UnidentifiedImageError
from tcvectordb.model.document import Document
from transformers import AutoProcessor, CLIPModel

from vector_db.vector_db_client import VectorDB
# Keys looked up via config.get() in Initial_and_Upload.__init__.
LOCAL_MODEL_PATH = "download_model.local_model_path"
MODEL_NAME = "download_model.model_name"
LOCAL_GRAPH_PATH = "graph_upload.local_graph_path"
# Default to the hf-mirror.com Hugging Face mirror, but respect an HF_ENDPOINT the
# user has already exported instead of silently overwriting it.
# NOTE(review): huggingface_hub is imported above this line; if it reads HF_ENDPOINT
# only at import time, this assignment alone may not redirect downloads — confirm.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
# Stylesheet injected into the init panel via gr.HTML. It defines flex-based
# equal-height row/column classes.
# NOTE(review): no component in get_init_panel sets elem_classes to these names,
# so the rules appear unused — confirm before relying on them.
init_css = """
<style>
.equal-height-row {
display: flex;
}
.equal-height-column {
flex: 1;
display: flex;
flex-direction: column;
}
.equal-height-column > * {
flex: 1;
}
</style>
"""
class Initial_and_Upload:
    """Gradio panel for model download, VectorDB initialisation and image upload.

    The panel has three parts:
      * a form that downloads a Hugging Face model snapshot into
        ``model_cache_directory``;
      * a button that initialises the VectorDB database and the image-embedding
        collection;
      * an upload form that embeds a single image (or every image inside a ZIP
        archive) with the cached CLIP model and upserts the vectors into the
        collection.
    """

    # File extensions treated as images when scanning a ZIP archive.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

    def __init__(self, config, vdb: VectorDB):
        """Resolve the model and image directories from ``config``.

        Args:
            config: object with a ``get(key)`` method; queried for MODEL_NAME,
                LOCAL_MODEL_PATH and LOCAL_GRAPH_PATH.
            vdb: VectorDB wrapper used for initialisation and upserts.
        """
        self.vdb = vdb
        self.model_name = config.get(MODEL_NAME)
        self.local_model_path = config.get(LOCAL_MODEL_PATH)
        self.local_graph_path = config.get(LOCAL_GRAPH_PATH)
        # Project root: the parent of the directory that contains this file.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.model_cache_directory = os.path.join(project_root, self.local_model_path, self.model_name)
        self.graph_cache_directory = os.path.join(project_root, self.local_graph_path)

    def initial_model(self):
        """Load the CLIP model and its processor from ``model_cache_directory``.

        Returns:
            tuple: ``(CLIPModel, AutoProcessor)`` loaded via ``from_pretrained``.
        """
        model = CLIPModel.from_pretrained(self.model_cache_directory)
        processor = AutoProcessor.from_pretrained(self.model_cache_directory)
        return model, processor

    def _download_model(self, model_name, progress=gr.Progress()):
        """Download a Hugging Face model snapshot into ``model_cache_directory``.

        Args:
            model_name (str): repo id of the model on the Hugging Face Hub.
            progress: Gradio progress tracker injected by the click handler.

        Returns:
            str: log text describing the attempt (including the failure reason on error).
        """
        os.environ['TRANSFORMERS_CACHE'] = self.model_cache_directory
        # Create the target directory if needed.
        os.makedirs(self.model_cache_directory, exist_ok=True)
        text = f"[正在尝试下载] 模型 {model_name},因为涉及到模型相关的多个文件下载,进度仅在后台显示。\n"
        progress(0.5, desc=text)
        try:
            snapshot_download(
                repo_id=model_name,
                local_dir=self.model_cache_directory,
                local_dir_use_symlinks=False,
                # huggingface_hub is imported before this module sets HF_ENDPOINT and
                # may have already resolved its default endpoint; pass the configured
                # endpoint explicitly so the mirror is actually used (None = library default).
                endpoint=os.environ.get("HF_ENDPOINT"),
            )
        except Exception as e:
            text += f"[下载失败] 失败原因:{e}"
            print(traceback.format_exc())
            return text
        progress(1, f"模型 {model_name} 已下载并保存在 {self.model_cache_directory}")
        text += f"模型 {model_name} 已下载并保存在 {self.model_cache_directory}"
        time.sleep(0.3)
        return text

    def _process_image(self, image_path, emb_model, processor):
        """Embed one image file with the CLIP model.

        Args:
            image_path (str): path of the image file.
            emb_model: model exposing ``get_image_features``.
            processor: processor that turns a PIL image into model inputs.

        Returns:
            torch.Tensor: output of ``emb_model.get_image_features`` for the image.
        """
        # The context manager releases the file handle even if preprocessing fails;
        # preprocessing must happen inside it because PIL loads pixel data lazily.
        with Image.open(image_path) as image:
            inputs = processor(images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            image_features = emb_model.get_image_features(**inputs)
        return image_features

    @staticmethod
    def _resolve_upload_path(file):
        """Return the filesystem path of an uploaded file, or None if nothing was uploaded.

        Accepts either an object exposing ``.name`` (what this module originally
        assumed) or a plain path string.
        NOTE(review): which form arrives depends on the installed Gradio version /
        ``gr.File(type=...)`` — confirm.
        """
        if file is None:
            return None
        return getattr(file, "name", file)

    @classmethod
    def _is_image_member(cls, member_name):
        """Return True for ZIP members that look like user images (not macOS metadata)."""
        base_name = os.path.basename(member_name)
        return (
            member_name.lower().endswith(cls._IMAGE_EXTENSIONS)
            and not member_name.startswith('__MACOSX')
            # AppleDouble files ("._foo.jpg") can sit in any sub-directory, so test the basename.
            and not base_name.startswith('._')
        )

    def _embed_and_store(self, image_path, model, processor, collection):
        """Embed ``image_path`` and upsert it into ``collection`` under a random UUID."""
        image_vector = self._process_image(image_path, model, processor).squeeze().tolist()  # flatten to a 1-D list
        collection.upsert(
            documents=[Document(id=str(uuid.uuid4()), vector=image_vector, local_graph_path=image_path)],
            build_index=True,
        )

    def _upload_zip(self, zip_path, model, processor, collection, progress):
        """Extract every image in a ZIP archive into ``graph_cache_directory`` and index it.

        Returns:
            str: per-file log lines followed by a summary line.
        """
        output_text = ""
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            image_members = [name for name in zip_ref.namelist() if self._is_image_member(name)]
            total_files = len(image_members)
            for i, member_name in enumerate(image_members):
                try:
                    # extract() sanitises the member name and returns the path it actually
                    # wrote, which may differ from a naive join of the raw member name.
                    image_path = zip_ref.extract(member_name, self.graph_cache_directory)
                    self._embed_and_store(image_path, model, processor, collection)
                    output_text += f"处理图片: {member_name}\n"
                except UnidentifiedImageError:
                    output_text += f"无法识别图片文件: {member_name}\n"
                except OSError as e:
                    # Keep going: one unreadable/truncated file should not abort the batch.
                    output_text += f"无法处理图片文件: {member_name},原因:{e}\n"
                progress((i + 1) / total_files)
        output_text += "上传的是ZIP压缩包,已解压缩并处理所有图片。"
        return output_text

    def _upload_single_image(self, upload_path, model, processor, collection, progress):
        """Copy one uploaded image into ``graph_cache_directory`` and index it.

        Returns:
            str: a one-line result message.
        """
        try:
            image_path = os.path.join(self.graph_cache_directory, os.path.basename(upload_path))
            shutil.copyfile(upload_path, image_path)
            self._embed_and_store(image_path, model, processor, collection)
            progress(1.0)
            return "上传的是图片文件,并已处理。\n"
        except (OSError, SyntaxError) as e:
            return f"无法识别文件类型:{e}\n"

    def _handle_upload(self, file, progress=gr.Progress()):
        """Handle an uploaded image or ZIP archive and store image embeddings in VectorDB.

        Args:
            file: the value produced by the ``gr.File`` component (see
                ``_resolve_upload_path`` for accepted forms); may be None.
            progress: Gradio progress tracker injected by the click handler.

        Returns:
            tuple[str, list]: result text, and a list kept for the two-output Gradio
            wiring. Vectors are written to VectorDB rather than returned, so the list
            is always empty.
        """
        image_vectors = []
        upload_path = self._resolve_upload_path(file)
        if upload_path is None:
            return "请先选择要上传的图片或ZIP压缩包。\n", image_vectors
        if not os.path.exists(self.model_cache_directory):
            # Without the model there is nothing to embed with; stop here rather than
            # failing later on an undefined model.
            return f"缓存目录 {self.model_cache_directory} 不存在,无法初始化模型。", image_vectors
        model, processor = self.initial_model()
        collection = self.vdb.get_collection()
        # Both branches write into the same absolute directory, so stored paths are
        # consistent and independent of the process's working directory.
        os.makedirs(self.graph_cache_directory, exist_ok=True)
        if zipfile.is_zipfile(upload_path):
            output_text = self._upload_zip(upload_path, model, processor, collection, progress)
        else:
            output_text = self._upload_single_image(upload_path, model, processor, collection, progress)
        return output_text, image_vectors

    def _initialize_vector_db(self, progress=gr.Progress()):
        """Check VectorDB connectivity, then initialise the database and the image collection.

        Returns:
            str: a step-by-step log (including the failure reason on error).
        """
        output_text = f"[正在尝试连接] VectorDB {self.vdb.address}\n"
        progress(0, desc=output_text)
        try:
            client = self.vdb.create_client()
            try:
                client.list_databases()
                progress(0.05, f"[连接成功] VectorDB {self.vdb.address}\n")
                output_text += f"[连接成功] VectorDB {self.vdb.address}\n"
            finally:
                # Release the probe connection even when listing databases fails.
                client.close()
            progress(0.1, f"[正在初始化] ai database '{self.vdb.db_name}'\n")
            output_text += f"[正在初始化] ai database '{self.vdb.db_name}'\n"
            self.vdb.init_database()
            progress(0.3, f"[初始化完成] ai database '{self.vdb.db_name}'\n")
            output_text += f"[初始化完成] ai database '{self.vdb.db_name}'\n"
            progress(0.5, f"[正在初始化] ai collection '{self.vdb.ai_graph_emb_collection}'\n")
            output_text += f"[正在初始化] ai collection '{self.vdb.ai_graph_emb_collection}'\n"
            self.vdb.init_graph_collection()
            progress(0.9, f"[初始化完成] ai collection '{self.vdb.ai_graph_emb_collection}'\n")
            output_text += f"[初始化完成] ai collection '{self.vdb.ai_graph_emb_collection}'\n"
            progress(1, "您可以去图片上传栏目上传图片或ZIP压缩包,然后进一步进行[图片搜索]")
            output_text += "您可以去图片上传栏目上传图片或ZIP压缩包,然后进一步进行[图片搜索]"
            time.sleep(0.3)
        except Exception as e:
            output_text += f"[数据库访问失败] 失败原因:{e}"
            print(traceback.format_exc())
        return output_text

    def get_init_panel(self):
        """Build the Gradio Blocks UI for model download, DB initialisation and upload."""
        with gr.Blocks() as demo:
            gr.HTML(init_css)
            with gr.Row():
                with gr.Column():
                    model_name_input = gr.Textbox(lines=1, label="模型名称", placeholder="请输入Hugging Face模型名称...", value=self.model_name)
                    output = gr.Textbox(lines=10, label="下载进度", placeholder="下载进度将在这里显示...")
                    init_button = gr.Button("开始下载模型")
                    init_button.click(
                        fn=self._download_model,
                        inputs=[model_name_input],
                        outputs=output
                    )
                with gr.Column():
                    # NOTE(review): fractional line count, presumably tuned to align with
                    # the left column's height — confirm it renders as intended.
                    db_init_output = gr.Textbox(lines=14.5, label="数据库初始化结果", placeholder="数据库初始化结果将在这里显示...")
                    db_init_button = gr.Button("初始化向量数据库")
                    db_init_button.click(
                        fn=self._initialize_vector_db,
                        inputs=[],
                        outputs=db_init_output
                    )
            with gr.Row():
                upload_file = gr.File(label="上传图片或ZIP压缩包")
            with gr.Row():
                upload_output = gr.Textbox(lines=10, label="上传结果", placeholder="上传结果将在这里显示...")
            with gr.Row():
                upload_button = gr.Button("上传文件")
                upload_button.click(
                    fn=self._handle_upload,
                    inputs=[upload_file],
                    outputs=[upload_output, gr.State()]
                )
        return demo