# Extraction header (not part of the program): SonyaX20 / new / d5760d5
import os
import gradio as gr
import easyocr
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image
import numpy as np
import torch
# Load environment variables via python-dotenv.
load_dotenv()

# Build the OpenAI client; start-up is aborted when no key can be found.
try:
    openai_api_key = os.getenv('OPENAI_API_KEY')
    # Nothing in the environment yet: retry with an explicit '.env' in the
    # current working directory, if one exists.
    if not openai_api_key and os.path.exists('.env'):
        load_dotenv('.env')
        openai_api_key = os.getenv('OPENAI_API_KEY')
    if not openai_api_key:
        raise ValueError("No OpenAI API key found in environment variables or .env file")
    client = OpenAI(api_key=openai_api_key)
    print("Successfully initialized OpenAI client")
except Exception as init_error:
    # Report the reason, then let the original exception stop the import.
    print(f"Error initializing OpenAI client: {str(init_error)}")
    raise

# Expose no CUDA devices (original intent: silence CUDA warnings).
os.environ['CUDA_VISIBLE_DEVICES'] = ''

# Directory next to this file where the OCR model weights are stored.
MODEL_CACHE_DIR = os.path.join(os.path.dirname(__file__), 'models')
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
# Simplified GPU check
def check_gpu():
    """Return the torch device name that is usable on this machine.

    Returns:
        str: ``'cuda'`` when ``torch.cuda.is_available()`` is true,
        otherwise ``'cpu'``. A failing probe is treated as "no GPU" so this
        check can never abort start-up.
    """
    try:
        if torch.cuda.is_available():
            return 'cuda'
    except Exception as exc:
        # Best effort only. ``except Exception`` (instead of a bare
        # ``except:``) lets KeyboardInterrupt/SystemExit propagate.
        print(f"GPU check failed, falling back to CPU: {exc}")
    return 'cpu'
# Pick the device label once at import time and report it.
# NOTE(review): `device` is not read again in the visible part of this file
# (both EasyOCR readers are created with gpu=False), so this looks
# informational only -- confirm before relying on it.
device = check_gpu()
print(f"Running on device: {device}")
# Pre-download the OCR models
def download_models():
    """Make sure the EasyOCR model files are present in MODEL_CACHE_DIR.

    Checks for three expected weight files. When any is missing, a
    throw-away ``easyocr.Reader`` is constructed in CPU mode with
    ``download_enabled=True``. Any exception is printed and swallowed, so
    this function always returns ``None``.
    """
    # NOTE(review): confirm these names match what the installed EasyOCR
    # version stores; a name that never exists makes the Reader below get
    # built on every start.
    expected_names = ('craft_mlt_25k.pth', 'chinese_sim.pth', 'english_g2.pth')
    try:
        print("Checking for pre-downloaded models...")
        something_missing = any(
            not os.path.exists(os.path.join(MODEL_CACHE_DIR, file_name))
            for file_name in expected_names
        )
        if not something_missing:
            print("All models already downloaded")
            return

        print("Some models need to be downloaded...")
        # Force CPU mode for the download pass; the instance is discarded.
        _throwaway_reader = easyocr.Reader(
            ['ch_sim', 'en'],
            gpu=False,
            model_storage_directory=MODEL_CACHE_DIR,
            download_enabled=True,
            verbose=True,
        )
        print("Model download completed")
    except Exception as download_error:
        print(f"Error during model download: {str(download_error)}")
# Fetch the OCR models at import time; the reader created further down is
# built with download_enabled=False.
download_models()
# Initialise EasyOCR
def initialize_easyocr():
    """Create the EasyOCR reader for Simplified Chinese and English.

    The reader is built in CPU mode, uses MODEL_CACHE_DIR as its model
    directory, and has automatic downloads disabled.

    Returns:
        The constructed ``easyocr.Reader``.

    Raises:
        Exception: whatever the construction raised, re-raised after the
        message has been printed.
    """
    try:
        print("Initializing EasyOCR...")
        reader_options = {
            'gpu': False,  # always run on CPU
            'model_storage_directory': MODEL_CACHE_DIR,
            'download_enabled': False,  # no automatic download
            'verbose': True,
        }
        ocr_reader = easyocr.Reader(['ch_sim', 'en'], **reader_options)
        print("EasyOCR initialization completed!")
    except Exception as init_error:
        print(f"Error initializing EasyOCR: {str(init_error)}")
        raise
    return ocr_reader
# Create the module-level OCR reader used by extract_text_from_image.
reader = initialize_easyocr()
def process_image(image):
    """Run OCR on an uploaded slide and return the text plus its analysis.

    Args:
        image: The uploaded image in any form ``extract_text_from_image``
            accepts, or ``None`` when nothing has been uploaded.

    Returns:
        tuple[str, str]: ``(recognised_text, analysis)``. When a step fails,
        the two strings are user-facing messages instead.
    """
    if image is None:
        return "请上传图片", "等待图片上传..."

    # extract_text_from_image reports its own failures by *returning* one of
    # these messages rather than raising. They must not be passed on to the
    # language model as if they were slide content.
    no_text_prefix = "未能识别到清晰的文字"
    ocr_error_prefix = "图片处理出错:"
    try:
        # Extract the text
        text = extract_text_from_image(image)
        if not text.strip():
            return "未能识别到文字内容,请尝试上传清晰的图片", "无法分析空白内容"
        if text.startswith(no_text_prefix):
            return text, "无法分析空白内容"
        if text.startswith(ocr_error_prefix):
            return text, "请重试或联系管理员"
        # Analyse the content
        analysis = analyze_slide(text)
        return text, analysis
    except Exception as e:
        return f"处理出错: {str(e)}", "请重试或联系管理员"
def extract_text_from_image(image):
    """Run EasyOCR on a slide image and return the recognised text.

    Args:
        image: A file path (``str``), a NumPy array (converted with
            ``Image.fromarray``), or any object with a ``save(path)``
            method such as a PIL image.

    Returns:
        str: The recognised fragments joined by single spaces. Failures are
        reported as a returned message, never raised:
        "未能识别到清晰的文字,..." when nothing passes the confidence filter,
        and "图片处理出错: <error>" when an exception occurred.
    """
    import tempfile
    import traceback

    temp_path = None  # set only when this call created a temporary file
    try:
        if isinstance(image, str):
            image_path = image
        else:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            # A unique file per call: concurrent requests no longer overwrite
            # each other's image, and a caller-supplied path is never deleted.
            fd, temp_path = tempfile.mkstemp(suffix=".png")
            os.close(fd)
            image.save(temp_path)
            image_path = temp_path

        print("开始识别文字...")
        result = reader.readtext(
            image_path,
            detail=1,
            paragraph=True
        )
        print("文字识别完成")

        sorted_text = []
        for item in result:
            # Accept both sequence results (bbox, text[, confidence]) and
            # dict results; anything else is ignored.
            if isinstance(item, (list, tuple)):
                if len(item) < 2:
                    continue
                text = item[1]
                # No confidence in the item: treat it as fully trusted.
                prob = item[2] if len(item) >= 3 else 1.0
            elif isinstance(item, dict):
                text = item.get('text', '')
                prob = item.get('confidence', 1.0)
            else:
                continue
            if prob > 0.5:  # keep only results with confidence above 0.5
                sorted_text.append(text)

        final_text = ' '.join(sorted_text)
        if not final_text.strip():
            return "未能识别到清晰的文字,请尝试上传更清晰的图片"
        print(f"识别到的文字: {final_text[:100]}...")  # first 100 chars, for debugging
        return final_text
    except Exception as e:
        print(f"文字识别出错: {str(e)}")
        traceback.print_exc()  # full details for the server log
        return f"图片处理出错: {str(e)}"
    finally:
        # Remove the temporary file on every path, including OCR failures.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass
def analyze_slide(text):
    """Ask GPT-3.5-turbo for a structured explanation of slide text.

    Args:
        text: The text recognised on the slide.

    Returns:
        The model's reply, or a user-facing error message when the request
        raised. Two known error markers map to fixed hints; any other error
        is echoed after "内容分析出错: ".
    """
    known_failures = (
        ("model_not_found",
         "API 配置错误:无法访问指定的模型。请确保您的 OpenAI API Key 有正确的访问权限。"),
        ("invalid_request_error",
         "API 请求错误:请检查 API Key 是否正确设置。"),
    )
    try:
        prompt = f"""请分析以下幻灯片内容,并提供清晰的讲解:
内容:{text}
请按照以下结构组织回答:
1. 主要内容:用2-3句话概括核心内容
2. 重点解释:详细解释重要概念和关键点
3. 知识延伸:与其他知识的联系
4. 应用场景:在实际中的应用示例
请用中文回答,语言要通俗易懂。"""
        user_turn = {"role": "user", "content": prompt}
        completion = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[user_turn],
            temperature=0.7,
            max_tokens=1000,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as request_error:
        detail = str(request_error)
        # First matching marker wins, in the order listed above.
        for marker, friendly_message in known_failures:
            if marker in detail:
                return friendly_message
        return f"内容分析出错: {detail}"
def chat_with_assistant(message, history, slide_text):
    """Answer a user question using the current slide text as context.

    Args:
        message: The user's new question. An empty value is a no-op.
        history: Earlier ``(user, assistant)`` pairs, or ``None`` for an
            empty conversation. A list passed in is extended in place.
        slide_text: Text recognised on the current slide.

    Returns:
        list: The history with ``(message, reply)`` appended. When the
        request fails, the reply is a user-facing error message.
    """
    # Without this, a None history made the error handler below fail with
    # AttributeError instead of reporting the problem in the chat.
    if history is None:
        history = []
    if not message:
        return history
    try:
        context = f"""当前幻灯片内容:{slide_text}
请基于以上幻灯片内容,回答用户的问题。如果问题与幻灯片内容无关,也可以回答其他问题。"""
        messages = [
            {"role": "system", "content": "你是一位专业的课程助教,负责帮助学生理解课程内容。请用清晰易懂的中文回答问题。"},
            {"role": "user", "content": context}
        ]
        # Replay the earlier turns, skipping sides that hold no text so a
        # None is never sent as message content.
        for human, assistant in history:
            if human is not None:
                messages.append({"role": "user", "content": human})
            if assistant is not None:
                messages.append({"role": "assistant", "content": assistant})
        messages.append({"role": "user", "content": message})
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=messages,
            temperature=0.7
        )
        history.append((message, response.choices[0].message.content))
        return history
    except Exception as e:
        error_msg = str(e)
        if "model_not_found" in error_msg:
            error_response = "API 配置错误:无法访问指定的模型。请确保您的 OpenAI API Key 有正确的访问权限。"
        elif "invalid_request_error" in error_msg:
            error_response = "API 请求错误:请检查 API Key 是否正确设置。"
        else:
            error_response = f"回答出错: {error_msg}"
        history.append((message, error_response))
        return history
def check_api_key():
    """Check that an OpenAI API key is set and that a test request succeeds.

    Reads ``OPENAI_API_KEY`` from the environment and, when present, sends a
    minimal chat completion (``max_tokens=5``) through the module's client.

    Returns:
        str | None: ``None`` when the test request succeeds; otherwise an
        HTML snippet describing the missing key or the failed connection.
    """
    import html

    api_key = os.getenv('OPENAI_API_KEY')
    if not api_key:
        return """
<div style="padding: 1rem; background-color: #ffebee; border-radius: 0.5rem; margin: 1rem 0;">
<h2 style="color: #c62828;">⚠️ OpenAI API Key 未设置</h2>
<p>请在 Hugging Face Space 的 Settings 中设置 Repository Secrets:</p>
<ol>
<li>进入 Space Settings</li>
<li>找到 Repository Secrets 部分</li>
<li>添加名为 OPENAI_API_KEY 的 Secret</li>
<li>将你的 OpenAI API Key 填入值中</li>
<li>保存后重新启动 Space</li>
</ol>
</div>
"""
    # Test the API connection
    try:
        client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "测试连接"}],
            max_tokens=5
        )
        return None
    except Exception as e:
        # The message goes into HTML, so escape it: characters such as
        # '<' or '&' in the error text must not be read as markup.
        error_text = html.escape(str(e))
        return f"""
<div style="padding: 1rem; background-color: #ffebee; border-radius: 0.5rem; margin: 1rem 0;">
<h2 style="color: #c62828;">⚠️ API 连接测试失败</h2>
<p>错误信息:{error_text}</p>
<p>请检查:</p>
<ol>
<li>API Key 是否正确</li>
<li>API Key 是否有效</li>
<li>是否有足够的额度</li>
</ol>
</div>
"""
# Build the Gradio interface
with gr.Blocks(title="课程幻灯片理解助手") as demo:
    api_key_error = check_api_key()
    if api_key_error:
        # Key missing or test request failed: show only the explanation.
        gr.Markdown(api_key_error)
    else:
        gr.Markdown("# 📚 课程幻灯片理解助手")
        gr.Markdown("上传幻灯片图片,AI 将自动识别内容并提供详细讲解")

        # Holds the most recently recognised text as chat context.
        current_text = gr.State("")

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="上传幻灯片图片",
                    type="pil",
                    sources=["upload", "clipboard"],
                )
                status_text = gr.Markdown("等待上传图片...")
            with gr.Column(scale=2):
                text_output = gr.Textbox(
                    label="识别的文字内容",
                    lines=3,
                    placeholder="上传图片后将显示识别的文字内容...",
                )
                analysis_output = gr.Textbox(
                    label="AI 讲解分析",
                    lines=10,
                    placeholder="等待分析结果...",
                )

        gr.Markdown("---")
        gr.Markdown("### 💬 与 AI 助手对话")
        chatbot = gr.Chatbot(label="对话历史", height=400)
        with gr.Row():
            msg = gr.Textbox(
                label="输入你的问题",
                placeholder="请输入你的问题...",
                scale=4,
            )
            clear = gr.Button("🗑️ 清除对话", scale=1)

        # Event wiring
        def update_status(image):
            """Return the status line to show when the image input changes."""
            if image is None:
                return "等待上传图片..."
            return "正在处理图片..."

        # Image change: status -> OCR + analysis -> remember text -> done.
        status_event = image_input.change(
            fn=update_status,
            inputs=[image_input],
            outputs=[status_text],
        )
        ocr_event = status_event.then(
            fn=process_image,
            inputs=[image_input],
            outputs=[text_output, analysis_output],
        )
        remember_event = ocr_event.then(
            fn=lambda x: x,
            inputs=[text_output],
            outputs=[current_text],
        )
        remember_event.then(
            fn=lambda: "处理完成",
            outputs=[status_text],
        )

        def chat_and_clear(message, history, slide_text):
            """Answer the question and return "" to blank the input box."""
            return chat_with_assistant(message, history, slide_text), ""

        msg.submit(
            fn=chat_and_clear,
            inputs=[msg, chatbot, current_text],
            outputs=[chatbot, msg],  # msg is an output so it gets cleared
        )
        clear.click(
            fn=lambda: ([], ""),
            outputs=[chatbot, msg],
        )
# Launch the app when this file is run as a script
if __name__ == "__main__":
    launch_options = {
        "share": True,
        "max_threads": 4,
        "show_error": True,
    }
    demo.launch(**launch_options)