Spaces:
Runtime error
Runtime error
import time | |
import gradio as gr | |
from huggingface_hub import snapshot_download | |
import os | |
import zipfile | |
from PIL import Image, UnidentifiedImageError | |
from transformers import AutoProcessor, CLIPModel | |
from vector_db.vector_db_client import VectorDB | |
from tcvectordb.model.document import Document | |
import uuid | |
import traceback | |
import numpy as np | |
# Dotted configuration keys; the panel class resolves them through ``config.get``.
LOCAL_MODEL_PATH = "download_model.local_model_path"
MODEL_NAME = "download_model.model_name"
LOCAL_GRAPH_PATH = "graph_upload.local_graph_path"

# Point Hugging Face downloads at the hf-mirror.com mirror.
# NOTE(review): huggingface_hub is already imported above this line; confirm
# the library still honours a value that is set after it has been imported.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# Inline stylesheet injected into the page: flex rules that let the columns of
# a row stretch to a common height.
init_css = """
<style>
.equal-height-row {
display: flex;
}
.equal-height-column {
flex: 1;
display: flex;
flex-direction: column;
}
.equal-height-column > * {
flex: 1;
}
</style>
"""
class Initial_and_Upload:
    """Gradio panel for model download, VectorDB initialisation and image upload.

    The panel offers three actions:

    * download a Hugging Face model into a local cache directory,
    * create the database and the image-embedding collection in the VectorDB,
    * upload a single image or a ZIP archive of images, embed every image with
      the locally cached CLIP model and upsert the vectors into the collection.
    """

    # Lower-case file extensions treated as images when scanning a ZIP archive.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

    def __init__(self, config, vdb: VectorDB):
        """Resolve model / image directories from *config* and keep *vdb*.

        Args:
            config: Object exposing ``get(key)``; queried with the module-level
                ``MODEL_NAME``, ``LOCAL_MODEL_PATH`` and ``LOCAL_GRAPH_PATH`` keys.
            vdb: Vector database wrapper used for initialisation and upserts.
        """
        self.vdb = vdb
        self.model_name = config.get(MODEL_NAME)
        self.local_model_path = config.get(LOCAL_MODEL_PATH)
        self.local_graph_path = config.get(LOCAL_GRAPH_PATH)
        # Both cache directories are anchored two levels above this file so
        # they do not depend on the current working directory.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.model_cache_directory = os.path.join(project_root, self.local_model_path, self.model_name)
        self.graph_cache_directory = os.path.join(project_root, self.local_graph_path)

    def initial_model(self):
        """Load the CLIP model and its processor from the local model cache.

        Returns:
            tuple: ``(model, processor)``.
        """
        model = CLIPModel.from_pretrained(self.model_cache_directory)
        processor = AutoProcessor.from_pretrained(self.model_cache_directory)
        return model, processor

    # ``progress=gr.Progress()`` is the Gradio idiom for requesting a progress
    # tracker; it is intentionally a default argument.
    def _download_model(self, model_name, progress=gr.Progress()):
        """Download a Hugging Face model into ``self.model_cache_directory``.

        Args:
            model_name (str): Repository id of the model on Hugging Face.
            progress: Gradio progress tracker.

        Returns:
            str: Human-readable log of the attempt (success or failure reason).
        """
        os.environ['TRANSFORMERS_CACHE'] = self.model_cache_directory
        text = f"[正在尝试下载] 模型 {model_name},因为涉及到模型相关的多个文件下载,进度仅在后台显示。\n"
        progress(0.5, desc=text)
        try:
            # Created inside the ``try`` so that a filesystem error is reported
            # in the returned text instead of escaping to the UI.
            os.makedirs(self.model_cache_directory, exist_ok=True)
            snapshot_download(
                repo_id=model_name,
                local_dir=self.model_cache_directory,
                local_dir_use_symlinks=False,
                # huggingface_hub is imported at the top of this file before
                # HF_ENDPOINT is assigned, so the endpoint is handed over
                # explicitly instead of relying on the environment variable.
                endpoint=os.environ.get("HF_ENDPOINT"),
            )
            progress(1, f"模型 {model_name} 已下载并保存在 {self.model_cache_directory}")
            text += f"模型 {model_name} 已下载并保存在 {self.model_cache_directory}"
            time.sleep(0.3)
            return text
        except Exception as e:
            text += f"[下载失败] 失败原因:{e}"
            # Keep the full traceback in the server log, as the database
            # initialisation handler does.
            print(traceback.format_exc())
            return text

    def _process_image(self, image_path, emb_model, processor):
        """Convert a single image file into its embedding.

        Args:
            image_path (str): Path of the image file.
            emb_model: Model exposing ``get_image_features``.
            processor: Callable turning a PIL image into model inputs.

        Returns:
            The value returned by ``emb_model.get_image_features``.

        Raises:
            PIL.UnidentifiedImageError: If the file cannot be read as an image.
        """
        # Function-scope import: torch is only needed to switch off autograd
        # during inference.
        import torch

        # The context manager releases the file handle once the pixels have
        # been handed to the processor.
        with Image.open(image_path) as image:
            inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_features = emb_model.get_image_features(**inputs)
        return image_features

    def _is_image_entry(self, entry_name):
        """Return True when a ZIP entry name looks like a real image file."""
        if entry_name.startswith('__MACOSX'):
            return False
        # macOS resource forks ("._name") can sit in any folder of the archive,
        # so the check is made on the entry's base name.
        if os.path.basename(entry_name).startswith('._'):
            return False
        return entry_name.lower().endswith(self._IMAGE_EXTENSIONS)

    def _embed_and_store(self, image_path, model, processor, collection):
        """Embed one image, upsert it under a random UUID and return the vector."""
        image_vector = self._process_image(image_path, model, processor).squeeze().tolist()
        document = Document(id=str(uuid.uuid4()), vector=image_vector, local_graph_path=image_path)
        collection.upsert(documents=[document], build_index=True)
        return image_vector

    def _handle_upload(self, file, progress=gr.Progress()):
        """Embed an uploaded image, or every image of an uploaded ZIP archive.

        Args:
            file: Upload handed over by ``gr.File``; either an object with a
                ``name`` attribute holding the path, or the path itself.
            progress: Gradio progress tracker.

        Returns:
            tuple: ``(output_text, image_vectors)`` - a log of what happened and
            the list of embedding vectors that were stored.
        """
        output_text = ""
        image_vectors = []

        # Clicking the button without selecting a file yields ``None``.
        if file is None:
            output_text += "未选择文件,请先上传图片或ZIP压缩包。"
            return output_text, image_vectors
        source_path = getattr(file, "name", file)

        # Without the cached model nothing can be embedded; stop here.
        if not os.path.exists(self.model_cache_directory):
            output_text += f"缓存目录 {self.model_cache_directory} 不存在,无法初始化模型。"
            return output_text, image_vectors

        model, processor = self.initial_model()
        collection = self.vdb.get_collection()
        # ZIP contents and single images share one absolute directory, so the
        # stored paths do not depend on the working directory.
        os.makedirs(self.graph_cache_directory, exist_ok=True)

        if zipfile.is_zipfile(source_path):
            with zipfile.ZipFile(source_path, 'r') as zip_ref:
                zip_ref.extractall(self.graph_cache_directory)
                image_files = [name for name in zip_ref.namelist() if self._is_image_entry(name)]
            total_files = len(image_files)
            for i, file_name in enumerate(image_files):
                image_path = os.path.join(self.graph_cache_directory, file_name)
                try:
                    image_vectors.append(self._embed_and_store(image_path, model, processor, collection))
                    output_text += f"处理图片: {file_name}\n"
                except UnidentifiedImageError:
                    output_text += f"无法识别图片文件: {file_name}\n"
                # Advance the progress bar after every entry, good or bad.
                progress((i + 1) / total_files)
            output_text += "上传的是ZIP压缩包,已解压缩并处理所有图片。"
        else:
            try:
                # Copy the upload into the image directory before embedding it.
                image_path = os.path.join(self.graph_cache_directory, os.path.basename(source_path))
                with open(source_path, "rb") as f_src, open(image_path, "wb") as f_dst:
                    f_dst.write(f_src.read())
                image_vectors.append(self._embed_and_store(image_path, model, processor, collection))
                output_text += "上传的是图片文件,并已处理。\n"
                progress(1.0)
            except (IOError, SyntaxError) as e:
                output_text += f"无法识别文件类型:{e}\n"

        return output_text, image_vectors

    def _initialize_vector_db(self, progress=gr.Progress()):
        """Check connectivity, then create the database and the image collection.

        Args:
            progress: Gradio progress tracker.

        Returns:
            str: Log of every step, or of the failure reason.
        """
        log = []

        def report(fraction, message):
            # Show the step on the progress bar and keep it for the final log.
            progress(fraction, desc=message)
            log.append(message)

        report(0, f"[正在尝试连接] VectorDB {self.vdb.address}\n")
        try:
            client = self.vdb.create_client()
            try:
                # Used purely as a connectivity probe; the result is discarded.
                client.list_databases()
            finally:
                # Close the probe client even when the probe itself fails.
                client.close()
            report(0.05, f"[连接成功] VectorDB {self.vdb.address}\n")

            report(0.1, f"[正在初始化] ai database '{self.vdb.db_name}'\n")
            self.vdb.init_database()
            report(0.3, f"[初始化完成] ai database '{self.vdb.db_name}'\n")

            report(0.5, f"[正在初始化] ai collection '{self.vdb.ai_graph_emb_collection}'\n")
            self.vdb.init_graph_collection()
            report(0.9, f"[初始化完成] ai collection '{self.vdb.ai_graph_emb_collection}'\n")

            report(1, "您可以去图片上传栏目上传图片或ZIP压缩包,然后进一步进行[图片搜索]")
            time.sleep(0.3)
        except Exception as e:
            log.append(f"[数据库访问失败] 失败原因:{e}")
            print(traceback.format_exc())
        return "".join(log)

    def get_init_panel(self):
        """Build and return the Gradio ``Blocks`` UI wiring the three actions."""
        with gr.Blocks() as demo:
            gr.HTML(init_css)
            with gr.Row():
                # Left column: model download.
                with gr.Column():
                    model_name_input = gr.Textbox(lines=1, label="模型名称", placeholder="请输入Hugging Face模型名称...", value=self.model_name)
                    output = gr.Textbox(lines=10, label="下载进度", placeholder="下载进度将在这里显示...")
                    init_button = gr.Button("开始下载模型")
                    init_button.click(
                        fn=self._download_model,
                        inputs=[model_name_input],
                        outputs=output,
                    )
                # Right column: vector database initialisation.
                with gr.Column():
                    # NOTE(review): ``lines`` is fractional here, presumably to
                    # line the two columns up - confirm Gradio accepts it.
                    db_init_output = gr.Textbox(lines=14.5, label="数据库初始化结果", placeholder="数据库初始化结果将在这里显示...")
                    db_init_button = gr.Button("初始化向量数据库")
                    db_init_button.click(
                        fn=self._initialize_vector_db,
                        inputs=[],
                        outputs=db_init_output,
                    )
            with gr.Row():
                upload_file = gr.File(label="上传图片或ZIP压缩包")
            with gr.Row():
                upload_output = gr.Textbox(lines=10, label="上传结果", placeholder="上传结果将在这里显示...")
            with gr.Row():
                upload_button = gr.Button("上传文件")
                upload_button.click(
                    fn=self._handle_upload,
                    inputs=[upload_file],
                    # The second return value (the vectors) goes into a
                    # throw-away State component.
                    outputs=[upload_output, gr.State()],
                )
        return demo