import time import gradio as gr from huggingface_hub import snapshot_download import os import zipfile from PIL import Image, UnidentifiedImageError from transformers import AutoProcessor, CLIPModel from vector_db.vector_db_client import VectorDB from tcvectordb.model.document import Document import uuid import traceback import numpy as np # 生成随机的 UUID LOCAL_MODEL_PATH = "download_model.local_model_path" MODEL_NAME = "download_model.model_name" LOCAL_GRAPH_PATH="graph_upload.local_graph_path" os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" init_css=""" """ class Initial_and_Upload: def __init__(self, config,vdb: VectorDB): self.vdb = vdb self.model_name = config.get(MODEL_NAME) self.local_model_path = config.get(LOCAL_MODEL_PATH) self.local_graph_path=config.get(LOCAL_GRAPH_PATH) self.model_cache_directory = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), self.local_model_path, self.model_name) self.graph_cache_directory = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), self.local_graph_path) def initial_model(self): model = CLIPModel.from_pretrained(self.model_cache_directory) processor = AutoProcessor.from_pretrained(self.model_cache_directory) return model,processor def _download_model(self, model_name, progress=gr.Progress()): """ 下载指定的Hugging Face模型并保存在指定位置。 参数: model_name (str): 模型在Hugging Face上的名字。 save_directory (str): 模型文件保存的位置。 """ os.environ['TRANSFORMERS_CACHE'] = self.model_cache_directory # 创建保存目录(如果不存在) if not os.path.exists(self.model_cache_directory): os.makedirs(self.model_cache_directory) text = f"[正在尝试下载] 模型 {model_name},因为涉及到模型相关的多个文件下载,进度仅在后台显示。\n" progress(0.5, desc=text) try: # 下载模型 snapshot_download( repo_id=model_name, local_dir=self.model_cache_directory, local_dir_use_symlinks=False, ) progress(1, f"模型 {model_name} 已下载并保存在 {self.model_cache_directory}") text += f"模型 {model_name} 已下载并保存在 {self.model_cache_directory}" time.sleep(0.3) return text except Exception as e: text += f"[下载失败] 失败原因:{e}" return text def _process_image(self, image_path,emb_model,processor): """ 处理单个图片文件,将其转换为向量。 参数: image_path (str): 图片文件的路径。 返回: torch.Tensor: 图片的向量表示。 """ image = Image.open(image_path) # image.verify() # 验证图片是否有效 inputs = processor(images=image, return_tensors="pt") image_features = emb_model.get_image_features(**inputs) return image_features def _handle_upload(self, file, progress=gr.Progress()): """ 处理上传的文件,识别是图片还是ZIP压缩包,并将图片转换为向量。 参数: file (file): 上传的文件。 返回: str: 文件类型和处理结果。 """ output_text = "" image_vectors = [] if not os.path.exists(self.model_cache_directory): output_text += f"缓存目录 {self.model_cache_directory} 不存在,无法初始化模型。" else: model, processor = self.initial_model() collection = self.vdb.get_collection() if zipfile.is_zipfile(file.name): with zipfile.ZipFile(file.name, 'r') as zip_ref: zip_ref.extractall(self.local_graph_path) image_files = [file_name for file_name in zip_ref.namelist() if file_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')) and not file_name.startswith('__MACOSX') and not file_name.startswith('._')] total_files = len(image_files) for i, file_name in enumerate(image_files): image_path = os.path.join(self.local_graph_path, file_name) try: image_vector = self._process_image(image_path, model, processor).squeeze().tolist() # 转换为一维列表 random_uuid = str(uuid.uuid4()) # 转换为字符串 collection.upsert(documents=[Document(id=random_uuid, vector=image_vector, local_graph_path=image_path)], build_index=True) output_text += f"处理图片: {file_name}\n" except UnidentifiedImageError: output_text += f"无法识别图片文件: {file_name}\n" # 更新进度 progress((i + 1) / total_files) output_text += "上传的是ZIP压缩包,已解压缩并处理所有图片。" else: try: # 保存单张图片到指定文件夹 image_path = os.path.join(self.graph_cache_directory, os.path.basename(file.name)) with open(file.name, "rb") as f_src: with open(image_path, "wb") as f_dst: f_dst.write(f_src.read()) image_vector = self._process_image(image_path, model, processor).squeeze().tolist() # 转换为一维列表 random_uuid = str(uuid.uuid4()) # 转换为字符串 collection.upsert(documents=[Document(id=random_uuid, vector=image_vector, local_graph_path=image_path)], build_index=True) output_text += "上传的是图片文件,并已处理。\n" # 更新进度 progress(1.0) except (IOError, SyntaxError) as e: output_text += f"无法识别文件类型:{e}\n" # 返回处理结果和图片向量 return output_text, image_vectors def _initialize_vector_db(self, progress=gr.Progress()): """ 初始化向量数据库。 返回: str: 初始化结果。 """ output_text = f"[正在尝试连接] VectorDB {self.vdb.address}\n" progress(0, desc=output_text) try: client = self.vdb.create_client() client.list_databases() progress(0.05, f"[连接成功] VectorDB {self.vdb.address}\n") output_text += f"[连接成功] VectorDB {self.vdb.address}\n" client.close() progress(0.1, f"[正在初始化] ai database '{self.vdb.db_name}'\n") output_text += f"[正在初始化] ai database '{self.vdb.db_name}'\n" self.vdb.init_database() progress(0.3, f"[初始化完成] ai database '{self.vdb.db_name}'\n") output_text += f"[初始化完成] ai database '{self.vdb.db_name}'\n" progress(0.5, f"[正在初始化] ai collection '{self.vdb.ai_graph_emb_collection}'\n") output_text += f"[正在初始化] ai collection '{self.vdb.ai_graph_emb_collection}'\n" self.vdb.init_graph_collection() progress(0.9, f"[初始化完成] ai collection '{self.vdb.ai_graph_emb_collection}'\n") output_text += f"[初始化完成] ai collection '{self.vdb.ai_graph_emb_collection}'\n" progress(1, f"您可以去图片上传栏目上传图片或ZIP压缩包,然后进一步进行[图片搜索]") output_text += f"您可以去图片上传栏目上传图片或ZIP压缩包,然后进一步进行[图片搜索]" time.sleep(0.3) except Exception as e: output_text += f"[数据库访问失败] 失败原因:{e}" error_trace = traceback.format_exc() print(error_trace) return output_text def get_init_panel(self): with gr.Blocks() as demo: gr.HTML(init_css) with gr.Row(): with gr.Column(): model_name_input = gr.Textbox(lines=1, label="模型名称", placeholder="请输入Hugging Face模型名称...", value=self.model_name) output = gr.Textbox(lines=10, label="下载进度", placeholder="下载进度将在这里显示...") init_button = gr.Button("开始下载模型") init_button.click( fn=self._download_model, inputs=[model_name_input], outputs=output ) with gr.Column(): db_init_output = gr.Textbox(lines=14.5, label="数据库初始化结果", placeholder="数据库初始化结果将在这里显示...") db_init_button = gr.Button("初始化向量数据库") db_init_button.click( fn=self._initialize_vector_db, inputs=[], outputs=db_init_output ) with gr.Row(): upload_file = gr.File(label="上传图片或ZIP压缩包") with gr.Row(): upload_output = gr.Textbox(lines=10, label="上传结果", placeholder="上传结果将在这里显示...") with gr.Row(): upload_button = gr.Button("上传文件") upload_button.click( fn=self._handle_upload, inputs=[upload_file], outputs=[upload_output, gr.State()] ) return demo