import os
import gradio as gr
import easyocr
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image
import numpy as np
import torch

# Load environment variables
load_dotenv()

# Initialize the OpenAI client
try:
    # First try to read the key from the environment
    openai_api_key = os.getenv('OPENAI_API_KEY')
    if not openai_api_key:
        # If it is not set there, try loading it from a .env file
        if os.path.exists('.env'):
            load_dotenv('.env')
            openai_api_key = os.getenv('OPENAI_API_KEY')
    if not openai_api_key:
        raise ValueError("No OpenAI API key found in environment variables or .env file")
    client = OpenAI(api_key=openai_api_key)
    print("Successfully initialized OpenAI client")
except Exception as e:
    print(f"Error initializing OpenAI client: {str(e)}")
    raise

# Set an environment variable to disable CUDA warnings
os.environ['CUDA_VISIBLE_DEVICES'] = ''

# Directory used to cache the downloaded OCR models
MODEL_CACHE_DIR = os.path.join(os.path.dirname(__file__), 'models')
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

# Simplified GPU check
def check_gpu():
    try:
        if torch.cuda.is_available():
            return 'cuda'
    except Exception:
        pass
    return 'cpu'

# Initialize the device
device = check_gpu()
print(f"Running on device: {device}")

# Pre-download the OCR models
def download_models():
    try:
        print("Checking for pre-downloaded models...")
        model_files = [
            os.path.join(MODEL_CACHE_DIR, 'craft_mlt_25k.pth'),
            os.path.join(MODEL_CACHE_DIR, 'chinese_sim.pth'),
            os.path.join(MODEL_CACHE_DIR, 'english_g2.pth')
        ]
        all_models_exist = all(os.path.exists(f) for f in model_files)
        if not all_models_exist:
            print("Some models need to be downloaded...")
            # Instantiating a Reader in CPU mode forces the model download
            temp_reader = easyocr.Reader(
                ['ch_sim', 'en'],
                gpu=False,
                model_storage_directory=MODEL_CACHE_DIR,
                download_enabled=True,
                verbose=True
            )
            print("Model download completed")
        else:
            print("All models already downloaded")
    except Exception as e:
        print(f"Error during model download: {str(e)}")

# Download the models
download_models()

# Initialize EasyOCR
def initialize_easyocr():
    try:
        print("Initializing EasyOCR...")
        reader = easyocr.Reader(
            ['ch_sim', 'en'],
            gpu=False,  # force CPU mode
            model_storage_directory=MODEL_CACHE_DIR,
            download_enabled=False,  # disable automatic downloads
            verbose=True
        )
        print("EasyOCR initialization completed!")
        return reader
    except Exception as e:
        print(f"Error initializing EasyOCR: {str(e)}")
        raise

# Create the shared reader instance
reader = initialize_easyocr()


def process_image(image):
    """Process an uploaded image and return the recognized text and its analysis."""
    if image is None:
        return "请上传图片", "等待图片上传..."
    try:
        # Extract the text
        text = extract_text_from_image(image)
        if not text.strip():
            return "未能识别到文字内容,请尝试上传清晰的图片", "无法分析空白内容"
        # Analyze the content
        analysis = analyze_slide(text)
        return text, analysis
    except Exception as e:
        return f"处理出错: {str(e)}", "请重试或联系管理员"


def extract_text_from_image(image):
    """Extract text from an image with EasyOCR."""
    try:
        if isinstance(image, str):
            image_path = image
        else:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            image_path = "temp_image.png"
            image.save(image_path)

        print("开始识别文字...")
        result = reader.readtext(
            image_path,
            detail=1,
            paragraph=True
        )
        print("文字识别完成")

        # Remove the temporary file
        if image_path == "temp_image.png" and os.path.exists(image_path):
            os.remove(image_path)

        # Collect the recognized text fragments
        sorted_text = []
        for item in result:
            # Check the format of each result entry
            if isinstance(item, (list, tuple)):
                if len(item) >= 2:  # make sure there is at least a bbox and text
                    text = item[1]
                    prob = item[2] if len(item) >= 3 else 1.0
                    if prob > 0.5:  # keep only results with confidence > 0.5
                        sorted_text.append(text)
            elif isinstance(item, dict):
                # Handle a possible dict format
                text = item.get('text', '')
                prob = item.get('confidence', 1.0)
                if prob > 0.5:
                    sorted_text.append(text)

        final_text = ' '.join(sorted_text)
        if not final_text.strip():
            return "未能识别到清晰的文字,请尝试上传更清晰的图片"

        print(f"识别到的文字: {final_text[:100]}...")  # print the first 100 characters for debugging
        return final_text
    except Exception as e:
        print(f"文字识别出错: {str(e)}")
        import traceback
        traceback.print_exc()  # print the full traceback
        return f"图片处理出错: {str(e)}"


def analyze_slide(text):
    """Analyze the slide content with GPT-3.5-turbo."""
    try:
        prompt = f"""请分析以下幻灯片内容,并提供清晰的讲解:

内容:{text}

请按照以下结构组织回答:
1. 主要内容:用2-3句话概括核心内容
2. 重点解释:详细解释重要概念和关键点
3. 知识延伸:与其他知识的联系
4. 应用场景:在实际中的应用示例

请用中文回答,语言要通俗易懂。"""

        response = client.chat.completions.create(
            model="gpt-3.5-turbo",  # switched to GPT-3.5-turbo
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=1000
        )
        return response.choices[0].message.content
    except Exception as e:
        error_msg = str(e)
        if "model_not_found" in error_msg:
            return "API 配置错误:无法访问指定的模型。请确保您的 OpenAI API Key 有正确的访问权限。"
        elif "invalid_request_error" in error_msg:
            return "API 请求错误:请检查 API Key 是否正确设置。"
        else:
            return f"内容分析出错: {error_msg}"


def chat_with_assistant(message, history, slide_text):
    """Chat with the AI assistant about the current slide."""
    if not message:
        return history
    try:
        context = f"""当前幻灯片内容:{slide_text}

请基于以上幻灯片内容,回答用户的问题。如果问题与幻灯片内容无关,也可以回答其他问题。"""

        messages = [
            {"role": "system", "content": "你是一位专业的课程助教,负责帮助学生理解课程内容。请用清晰易懂的中文回答问题。"},
            {"role": "user", "content": context}
        ]
        # Append the previous conversation turns
        for human, assistant in history:
            messages.append({"role": "user", "content": human})
            messages.append({"role": "assistant", "content": assistant})
        messages.append({"role": "user", "content": message})

        response = client.chat.completions.create(
            model="gpt-3.5-turbo",  # switched to GPT-3.5-turbo
            messages=messages,
            temperature=0.7
        )
        history.append((message, response.choices[0].message.content))
        return history
    except Exception as e:
        error_msg = str(e)
        if "model_not_found" in error_msg:
            error_response = "API 配置错误:无法访问指定的模型。请确保您的 OpenAI API Key 有正确的访问权限。"
        elif "invalid_request_error" in error_msg:
            error_response = "API 请求错误:请检查 API Key 是否正确设置。"
        else:
            error_response = f"回答出错: {error_msg}"
        history.append((message, error_response))
        return history


def check_api_key():
    """Check that the API key is set and test the connection."""
    api_key = os.getenv('OPENAI_API_KEY')
    if not api_key:
        return """

⚠️ OpenAI API Key 未设置

请在 Hugging Face Space 的 Settings 中设置 Repository Secrets:

  1. 进入 Space Settings
  2. 找到 Repository Secrets 部分
  3. 添加名为 OPENAI_API_KEY 的 Secret
  4. 将你的 OpenAI API Key 填入值中
  5. 保存后重新启动 Space
""" # 测试 API 连接 try: client.chat.completions.create( model="gpt-3.5-turbo", messages=[{"role": "user", "content": "测试连接"}], max_tokens=5 ) return None except Exception as e: return f"""

⚠️ API 连接测试失败

错误信息:{str(e)}

请检查:

  1. API Key 是否正确
  2. API Key 是否有效
  3. 是否有足够的额度
""" # 创建 Gradio 界面 with gr.Blocks(title="课程幻灯片理解助手") as demo: api_key_error = check_api_key() if api_key_error: gr.Markdown(api_key_error) else: gr.Markdown("# 📚 课程幻灯片理解助手") gr.Markdown("上传幻灯片图片,AI 将自动识别内容并提供详细讲解") # 存储当前识别的文字,用于对话上下文 current_text = gr.State("") with gr.Row(): with gr.Column(scale=1): image_input = gr.Image( label="上传幻灯片图片", type="pil", sources=["upload", "clipboard"] ) status_text = gr.Markdown("等待上传图片...") with gr.Column(scale=2): text_output = gr.Textbox( label="识别的文字内容", lines=3, placeholder="上传图片后将显示识别的文字内容..." ) analysis_output = gr.Textbox( label="AI 讲解分析", lines=10, placeholder="等待分析结果..." ) gr.Markdown("---") gr.Markdown("### 💬 与 AI 助手对话") chatbot = gr.Chatbot( label="对话历史", height=400 ) with gr.Row(): msg = gr.Textbox( label="输入你的问题", placeholder="请输入你的问题...", scale=4 ) clear = gr.Button("🗑️ 清除对话", scale=1) # 设置事件处理 def update_status(image): return "正在处理图片..." if image is not None else "等待上传图片..." image_input.change( fn=update_status, inputs=[image_input], outputs=[status_text] ).then( fn=process_image, inputs=[image_input], outputs=[text_output, analysis_output] ).then( fn=lambda x: x, inputs=[text_output], outputs=[current_text] ).then( fn=lambda: "处理完成", outputs=[status_text] ) def chat_and_clear(message, history, slide_text): """聊天并清除输入""" result = chat_with_assistant(message, history, slide_text) return result, "" # 返回对话历史和空字符串来清除输入 msg.submit( fn=chat_and_clear, inputs=[msg, chatbot, current_text], outputs=[chatbot, msg] # 添加 msg 作为输出来清除它 ) clear.click( fn=lambda: ([], ""), outputs=[chatbot, msg] ) # 启动应用 if __name__ == "__main__": demo.launch( share=True, max_threads=4, show_error=True )