Spaces:
Sleeping
Sleeping
import os
import tempfile
import traceback

import easyocr
import gradio as gr
import numpy as np
import torch
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image
# Pull variables from a .env file (if one is found) into os.environ.
load_dotenv()

# Build the shared OpenAI client. Start-up is aborted (the exception is
# re-raised) when no API key can be found or the client cannot be created.
try:
    # A key already present in the process environment wins.
    openai_api_key = os.getenv('OPENAI_API_KEY')
    if not openai_api_key and os.path.exists('.env'):
        # Second chance: explicitly load ./.env from the working directory.
        load_dotenv('.env')
        openai_api_key = os.getenv('OPENAI_API_KEY')
    if not openai_api_key:
        raise ValueError("No OpenAI API key found in environment variables or .env file")
    client = OpenAI(api_key=openai_api_key)
    print("Successfully initialized OpenAI client")
except Exception as e:
    print(f"Error initializing OpenAI client: {str(e)}")
    raise
# An empty CUDA_VISIBLE_DEVICES typically hides every GPU from CUDA in this
# process (not merely its warnings), so torch/EasyOCR end up on the CPU.
os.environ.update(CUDA_VISIBLE_DEVICES='')

# EasyOCR weights are kept in a "models" folder beside this script; make sure
# it exists before any reader tries to read from or download into it.
MODEL_CACHE_DIR = os.path.join(os.path.dirname(__file__), 'models')
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
# Minimal GPU probe.
def check_gpu():
    """Return ``'cuda'`` if torch reports an available CUDA device, else ``'cpu'``.

    Any exception raised while probing CUDA is reported and treated as
    "no GPU", so this never aborts start-up.
    """
    try:
        if torch.cuda.is_available():
            return 'cuda'
    except Exception as exc:  # not a bare except: Ctrl-C / SystemExit still propagate
        print(f"CUDA check failed, falling back to CPU: {exc}")
    return 'cpu'


# Resolve the device once at start-up and report it.
# NOTE(review): within this file `device` is only printed; both EasyOCR readers
# are created with gpu=False regardless of its value.
device = check_gpu()
print(f"Running on device: {device}")
# Pre-download the OCR models.
def download_models():
    """Ensure the EasyOCR weights for Simplified Chinese + English are cached.

    Looks in MODEL_CACHE_DIR for the detector and recognizer weight files. If
    any is missing, it builds a throw-away CPU ``easyocr.Reader`` with
    downloading enabled so EasyOCR fetches them. Errors are printed, not
    raised. This is a best-effort warm-up; the reader created later with
    downloads disabled will fail loudly if weights are still missing.
    """
    try:
        print("Checking for pre-downloaded models...")
        # CRAFT detector plus the generation-2 recognizers EasyOCR picks by
        # default for 'ch_sim' and 'en'. The list used to name the gen-1
        # 'chinese_sim.pth', which the default reader never downloads, so this
        # check could never pass and the slow path ran on every start-up.
        model_files = [
            os.path.join(MODEL_CACHE_DIR, name)
            for name in ('craft_mlt_25k.pth', 'zh_sim_g2.pth', 'english_g2.pth')
        ]
        if not all(os.path.exists(path) for path in model_files):
            print("Some models need to be downloaded...")
            # Constructing the Reader performs the download; the instance
            # itself is not needed, so it is not kept.
            easyocr.Reader(
                ['ch_sim', 'en'],
                gpu=False,  # force CPU while downloading
                model_storage_directory=MODEL_CACHE_DIR,
                download_enabled=True,
                verbose=True,
            )
            print("Model download completed")
        else:
            print("All models already downloaded")
    except Exception as e:
        print(f"Error during model download: {str(e)}")


# Download models (if needed) at import time, before the real reader is built.
download_models()
# Build the EasyOCR reader.
def initialize_easyocr():
    """Create the EasyOCR reader used for every OCR request.

    Loads Simplified Chinese + English models from MODEL_CACHE_DIR on the CPU
    with downloading disabled, so the weights must already be cached. Any
    failure is printed and re-raised, aborting start-up.
    """
    reader_options = dict(
        gpu=False,  # always run on the CPU
        model_storage_directory=MODEL_CACHE_DIR,
        download_enabled=False,  # never download here; rely on cached weights
        verbose=True,
    )
    try:
        print("Initializing EasyOCR...")
        ocr_reader = easyocr.Reader(['ch_sim', 'en'], **reader_options)
        print("EasyOCR initialization completed!")
        return ocr_reader
    except Exception as err:
        print(f"Error initializing EasyOCR: {str(err)}")
        raise


# Module-wide reader shared by all requests.
reader = initialize_easyocr()
def process_image(image):
    """OCR an uploaded slide image and ask the model to explain it.

    Args:
        image: The uploaded image (whatever the Gradio Image component
            passes), or None when nothing is uploaded.

    Returns:
        A ``(recognized_text, analysis)`` pair of user-facing strings. Errors
        are converted into messages rather than raised.
    """
    if image is None:
        return "请上传图片", "等待图片上传..."

    try:
        recognized = extract_text_from_image(image)
        if recognized.strip():
            return recognized, analyze_slide(recognized)
        # Nothing recognised: skip the (paid) analysis call.
        return "未能识别到文字内容,请尝试上传清晰的图片", "无法分析空白内容"
    except Exception as err:
        return f"处理出错: {str(err)}", "请重试或联系管理员"
def _join_confident_text(result, min_confidence=0.5):
    """Join the text of EasyOCR result items whose confidence exceeds *min_confidence*.

    Accepts ``[bbox, text]`` / ``[bbox, text, confidence]`` sequences (items
    without a confidence count as 1.0; EasyOCR's paragraph mode returns
    items without one) and dicts with ``'text'`` / ``'confidence'`` keys.
    Any other item is ignored.
    """
    fragments = []
    for item in result:
        if isinstance(item, dict):
            text = item.get('text', '')
            prob = item.get('confidence', 1.0)
        elif isinstance(item, (list, tuple)) and len(item) >= 2:
            text = item[1]
            prob = item[2] if len(item) >= 3 else 1.0
        else:
            continue
        if prob > min_confidence:
            fragments.append(text)
    return ' '.join(fragments)


def extract_text_from_image(image):
    """OCR *image* with the module-level EasyOCR ``reader`` and return its text.

    Args:
        image: A file path (``str``), a ``numpy.ndarray``, or a PIL image.

    Returns:
        Recognised fragments with confidence > 0.5, joined by single spaces.
        Returns ``""`` when nothing usable was recognised. The caller can then
        show its own "no text" message instead of sending a placeholder
        sentence to the language model.

    Raises:
        Exception: Any image-handling or OCR error is logged with a traceback
        and re-raised, so the caller reports it rather than analysing the
        error text as slide content.
    """
    temp_path = None
    try:
        if isinstance(image, str):
            image_path = image
        else:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            # Unique file per call. A fixed "temp_image.png" in the working
            # directory was shared by concurrent requests, which could
            # overwrite or delete each other's input.
            fd, temp_path = tempfile.mkstemp(suffix=".png")
            os.close(fd)
            image.save(temp_path)
            image_path = temp_path

        print("开始识别文字...")
        result = reader.readtext(
            image_path,
            detail=1,
            paragraph=True
        )
        print("文字识别完成")

        final_text = _join_confident_text(result)
        if not final_text.strip():
            return ""
        print(f"识别到的文字: {final_text[:100]}...")  # first 100 chars, for debugging
        return final_text
    except Exception as e:
        print(f"文字识别出错: {str(e)}")
        traceback.print_exc()
        raise
    finally:
        # Always remove our temp file, even when OCR fails.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
def analyze_slide(text):
    """Ask gpt-3.5-turbo for a structured Chinese explanation of slide text.

    Args:
        text: Text recognised from the slide.

    Returns:
        The model's reply. If the request fails, it returns a user-facing
        Chinese error message instead of raising. Known OpenAI error markers
        get a friendlier hint.
    """
    # (substring found in the error text, message shown to the user)
    error_hints = (
        ("model_not_found", "API 配置错误:无法访问指定的模型。请确保您的 OpenAI API Key 有正确的访问权限。"),
        ("invalid_request_error", "API 请求错误:请检查 API Key 是否正确设置。"),
    )
    try:
        prompt = f"""请分析以下幻灯片内容,并提供清晰的讲解:
内容:{text}
请按照以下结构组织回答:
1. 主要内容:用2-3句话概括核心内容
2. 重点解释:详细解释重要概念和关键点
3. 知识延伸:与其他知识的联系
4. 应用场景:在实际中的应用示例
请用中文回答,语言要通俗易懂。"""
        completion = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=1000,
        )
        return completion.choices[0].message.content
    except Exception as exc:
        detail = str(exc)
        for marker, hint in error_hints:
            if marker in detail:
                return hint
        return f"内容分析出错: {detail}"
def chat_with_assistant(message, history, slide_text):
    """Answer a chat question with gpt-3.5-turbo, using the slide text as context.

    Args:
        message: The user's new question. Empty or whitespace-only input is
            ignored and *history* is returned unchanged.
        history: Previous turns as ``(user, assistant)`` pairs (the pair
            format this Chatbot is fed). ``None`` is treated as empty.
        slide_text: Text recognised from the current slide.

    Returns:
        A history list with ``(message, reply)`` appended. If the API call
        fails, the reply is a Chinese error message; nothing is raised.
    """
    if not message or not message.strip():
        return history
    history = list(history or [])

    try:
        context = f"""当前幻灯片内容:{slide_text}
请基于以上幻灯片内容,回答用户的问题。如果问题与幻灯片内容无关,也可以回答其他问题。"""
        messages = [
            {"role": "system", "content": "你是一位专业的课程助教,负责帮助学生理解课程内容。请用清晰易懂的中文回答问题。"},
            {"role": "user", "content": context}
        ]
        # Replay earlier turns. Empty/None halves are skipped: the Chat
        # Completions API rejects null content on plain user/assistant
        # messages, and one such turn would break every later request in
        # this conversation.
        for human, assistant in history:
            if human:
                messages.append({"role": "user", "content": human})
            if assistant:
                messages.append({"role": "assistant", "content": assistant})
        messages.append({"role": "user", "content": message})

        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=messages,
            temperature=0.7
        )
        # `content` can be None; store "" so it is never replayed as null.
        reply = response.choices[0].message.content or ""
    except Exception as e:
        error_msg = str(e)
        if "model_not_found" in error_msg:
            reply = "API 配置错误:无法访问指定的模型。请确保您的 OpenAI API Key 有正确的访问权限。"
        elif "invalid_request_error" in error_msg:
            reply = "API 请求错误:请检查 API Key 是否正确设置。"
        else:
            reply = f"回答出错: {error_msg}"

    history.append((message, reply))
    return history
def check_api_key():
    """Verify that OPENAI_API_KEY is set and that a tiny test request succeeds.

    The probe is a real chat completion (max_tokens=5), so it consumes a
    little API quota each time the UI is built.

    Returns:
        None when everything works. Otherwise an HTML snippet (Chinese)
        explaining how to fix the problem, meant to be rendered in the page.
    """
    if os.getenv('OPENAI_API_KEY'):
        try:
            client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "测试连接"}],
                max_tokens=5
            )
        except Exception as e:
            # NOTE(review): the exception text is inserted into the HTML unescaped.
            return f"""
<div style="padding: 1rem; background-color: #ffebee; border-radius: 0.5rem; margin: 1rem 0;">
<h2 style="color: #c62828;">⚠️ API 连接测试失败</h2>
<p>错误信息:{str(e)}</p>
<p>请检查:</p>
<ol>
<li>API Key 是否正确</li>
<li>API Key 是否有效</li>
<li>是否有足够的额度</li>
</ol>
</div>
"""
        return None

    # No key at all: explain how to add it as a Space secret.
    return """
<div style="padding: 1rem; background-color: #ffebee; border-radius: 0.5rem; margin: 1rem 0;">
<h2 style="color: #c62828;">⚠️ OpenAI API Key 未设置</h2>
<p>请在 Hugging Face Space 的 Settings 中设置 Repository Secrets:</p>
<ol>
<li>进入 Space Settings</li>
<li>找到 Repository Secrets 部分</li>
<li>添加名为 OPENAI_API_KEY 的 Secret</li>
<li>将你的 OpenAI API Key 填入值中</li>
<li>保存后重新启动 Space</li>
</ol>
</div>
"""
# Build the Gradio interface.
with gr.Blocks(title="课程幻灯片理解助手") as demo:
    # Probe the API once while building the page. On failure only the error
    # banner is rendered.
    # NOTE(review): original indentation was lost; this assumes the whole UI
    # lived in the `else:` branch.
    api_key_error = check_api_key()
    if api_key_error:
        gr.Markdown(api_key_error)
    else:
        gr.Markdown("# 📚 课程幻灯片理解助手")
        gr.Markdown("上传幻灯片图片,AI 将自动识别内容并提供详细讲解")

        # Text recognised from the current slide, used as chat context.
        current_text = gr.State("")

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="上传幻灯片图片",
                    type="pil",
                    sources=["upload", "clipboard"]
                )
                status_text = gr.Markdown("等待上传图片...")
            with gr.Column(scale=2):
                text_output = gr.Textbox(
                    label="识别的文字内容",
                    lines=3,
                    placeholder="上传图片后将显示识别的文字内容..."
                )
                analysis_output = gr.Textbox(
                    label="AI 讲解分析",
                    lines=10,
                    placeholder="等待分析结果..."
                )

        gr.Markdown("---")
        gr.Markdown("### 💬 与 AI 助手对话")
        chatbot = gr.Chatbot(
            label="对话历史",
            height=400
        )
        with gr.Row():
            msg = gr.Textbox(
                label="输入你的问题",
                placeholder="请输入你的问题...",
                scale=4
            )
            clear = gr.Button("🗑️ 清除对话", scale=1)

        # --- Event wiring ---
        def update_status(image):
            """Status shown before processing starts."""
            return "正在处理图片..." if image is not None else "等待上传图片..."

        def final_status(image):
            """Status after processing. Removing the image must not report "done"."""
            return "处理完成" if image is not None else "等待上传图片..."

        def slide_context(image, text):
            """Chat context for the current slide. It is empty once the image is removed,
            so the "please upload" placeholder is not used as slide content."""
            return text if image is not None else ""

        image_input.change(
            fn=update_status,
            inputs=[image_input],
            outputs=[status_text]
        ).then(
            fn=process_image,
            inputs=[image_input],
            outputs=[text_output, analysis_output]
        ).then(
            fn=slide_context,
            inputs=[image_input, text_output],
            outputs=[current_text]
        ).then(
            fn=final_status,
            inputs=[image_input],
            outputs=[status_text]
        )

        def chat_and_clear(message, history, slide_text):
            """Answer the question and clear the input box."""
            result = chat_with_assistant(message, history, slide_text)
            return result, ""  # "" empties the msg textbox

        msg.submit(
            fn=chat_and_clear,
            inputs=[msg, chatbot, current_text],
            outputs=[chatbot, msg]  # msg is an output so it can be cleared
        )
        clear.click(
            fn=lambda: ([], ""),
            outputs=[chatbot, msg]
        )
# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    launch_kwargs = {
        "share": True,  # NOTE(review): presumably ignored when hosted on Hugging Face Spaces
        "max_threads": 4,
        "show_error": True,  # surface handler errors in the browser UI
    }
    demo.launch(**launch_kwargs)