Spaces:
Sleeping
Sleeping
SonyaX20
committed on
Commit
·
cc4ded4
1
Parent(s):
581381b
new
Browse files
- .gitignore +4 -1
- README.md +11 -8
- app.py +101 -17
- requirements.txt +3 -2
.gitignore
CHANGED
@@ -19,4 +19,7 @@ temp_image.png
|
|
19 |
# Distribution / packaging
|
20 |
dist/
|
21 |
build/
|
22 |
-
*.egg-info/
|
|
|
|
|
|
|
|
19 |
# Distribution / packaging
|
20 |
dist/
|
21 |
build/
|
22 |
+
*.egg-info/
|
23 |
+
|
24 |
+
# Model cache
|
25 |
+
models/
|
README.md
CHANGED
@@ -43,11 +43,14 @@ MIT License
|
|
43 |
## Hugging Face Spaces 部署说明
|
44 |
|
45 |
1. Fork 这个项目到你的 Hugging Face Space
|
46 |
-
2. 在 Space
|
47 |
-
-
|
48 |
-
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
43 |
## Hugging Face Spaces 部署说明
|
44 |
|
45 |
1. Fork 这个项目到你的 Hugging Face Space
|
46 |
+
2. 在 Space 设置中:
|
47 |
+
- Hardware: 选择 CPU (免费) 或 GPU (付费)
|
48 |
+
- Python packages: 确保所有依赖都已列在 requirements.txt 中
|
49 |
+
3. 添加 Repository Secrets:
|
50 |
+
- 名称:`OPENAI_API_KEY`
|
51 |
+
- 值:你的 OpenAI API Key
|
52 |
+
|
53 |
+
注意:
|
54 |
+
- 首次运行时会下载必要的模型文件,可能需要几分钟
|
55 |
+
- CPU 模式下识别速度较慢,但功能完整
|
56 |
+
- 如果需要更快的识别速度,建议使用 GPU 环境
|
app.py
CHANGED
@@ -10,24 +10,89 @@ import torch
|
|
10 |
# 加载环境变量
|
11 |
load_dotenv()
|
12 |
|
13 |
-
#
|
14 |
try:
|
|
|
15 |
openai_api_key = os.getenv('OPENAI_API_KEY')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
if not openai_api_key:
|
17 |
-
raise ValueError("No OpenAI API key found")
|
|
|
18 |
client = OpenAI(api_key=openai_api_key)
|
|
|
19 |
except Exception as e:
|
20 |
print(f"Error initializing OpenAI client: {str(e)}")
|
21 |
raise
|
22 |
|
23 |
-
#
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
print(f"Running on device: {device}")
|
26 |
|
27 |
-
# 初始化 EasyOCR
|
28 |
-
|
29 |
-
|
30 |
-
print("EasyOCR
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
def process_image(image):
|
33 |
"""处理上传的图片并返回识别结果和分析"""
|
@@ -59,11 +124,15 @@ def extract_text_from_image(image):
|
|
59 |
image.save(image_path)
|
60 |
|
61 |
print("开始识别文字...")
|
|
|
|
|
|
|
62 |
# 使用 EasyOCR 识别文字
|
63 |
result = reader.readtext(
|
64 |
image_path,
|
65 |
detail=1,
|
66 |
-
paragraph=True
|
|
|
67 |
)
|
68 |
print("文字识别完成")
|
69 |
|
@@ -71,18 +140,32 @@ def extract_text_from_image(image):
|
|
71 |
if image_path == "temp_image.png" and os.path.exists(image_path):
|
72 |
os.remove(image_path)
|
73 |
|
74 |
-
#
|
75 |
sorted_text = []
|
76 |
-
for
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
final_text = ' '.join(sorted_text)
|
81 |
if not final_text.strip():
|
82 |
return "未能识别到清晰的文字,请尝试上传更清晰的图片"
|
|
|
|
|
83 |
return final_text
|
84 |
except Exception as e:
|
85 |
print(f"文字识别出错: {str(e)}")
|
|
|
|
|
86 |
return f"图片处理出错: {str(e)}"
|
87 |
|
88 |
def analyze_slide(text):
|
@@ -169,7 +252,8 @@ with gr.Blocks(title="课程幻灯片理解助手") as demo:
|
|
169 |
if api_key_error:
|
170 |
gr.Markdown(api_key_error)
|
171 |
else:
|
172 |
-
|
|
|
173 |
gr.Markdown("上传幻灯片图片,AI 将自动识别内容并提供详细讲解")
|
174 |
|
175 |
# 存储当前识别的文字,用于对话上下文
|
@@ -200,8 +284,7 @@ with gr.Blocks(title="课程幻灯片理解助手") as demo:
|
|
200 |
gr.Markdown("### 💬 与 AI 助手对话")
|
201 |
chatbot = gr.Chatbot(
|
202 |
label="对话历史",
|
203 |
-
height=400
|
204 |
-
placeholder="在这里可以看到对话历史..."
|
205 |
)
|
206 |
with gr.Row():
|
207 |
msg = gr.Textbox(
|
@@ -250,4 +333,5 @@ with gr.Blocks(title="课程幻灯片理解助手") as demo:
|
|
250 |
|
251 |
# 启动应用
|
252 |
if __name__ == "__main__":
|
253 |
-
|
|
|
|
10 |
# 加载环境变量
|
11 |
load_dotenv()
|
12 |
|
13 |
+
# 初始化 OpenAI 客户端
|
14 |
try:
    # Prefer a key that is already present in the process environment.
    openai_api_key = os.getenv('OPENAI_API_KEY')

    if not openai_api_key:
        # Not in the environment: if a .env file exists in the current
        # working directory, load it and read the variable again.
        if os.path.exists('.env'):
            load_dotenv('.env')
            openai_api_key = os.getenv('OPENAI_API_KEY')

    if not openai_api_key:
        raise ValueError("No OpenAI API key found in environment variables or .env file")

    client = OpenAI(api_key=openai_api_key)
    print("Successfully initialized OpenAI client")
except Exception as e:
    # Report the failure, then re-raise the original exception unchanged.
    print(f"Error initializing OpenAI client: {str(e)}")
    raise
|
32 |
|
33 |
+
# 设置模型下载目录
|
34 |
+
MODEL_CACHE_DIR = os.path.join(os.path.dirname(__file__), 'models')
|
35 |
+
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
|
36 |
+
|
37 |
+
# 检查 GPU 环境
|
38 |
+
def check_gpu() -> str:
    """Choose the device string the application should run on.

    Returns:
        ``'cuda'`` when ``torch.cuda.is_available()`` is true, otherwise
        ``'cpu'``. If probing the GPU raises, the error is printed and
        ``'cpu'`` is returned as well.
    """
    try:
        if torch.cuda.is_available():
            # Print name and total memory of device 0 for the startup log.
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3  # scale to GB
            print(f"Found GPU: {gpu_name} with {gpu_memory:.2f}GB memory")
            return 'cuda'
        else:
            print("No CUDA GPU available")
            return 'cpu'
    except Exception as e:
        # Best effort: any failure while probing means we run on the CPU.
        print(f"Error checking GPU: {str(e)}")
        return 'cpu'
|
52 |
+
|
53 |
+
# 初始化设备
|
54 |
+
device = check_gpu()
|
55 |
print(f"Running on device: {device}")
|
56 |
|
57 |
+
# 初始化 EasyOCR(针对 T4 GPU 优化)
|
58 |
+
def initialize_easyocr():
    """Create the EasyOCR reader for Simplified Chinese and English.

    The reader runs on the GPU when the module-level ``device`` is
    ``'cuda'`` and on the CPU otherwise. Model files are stored in
    ``MODEL_CACHE_DIR`` and downloaded on demand.

    Returns:
        The constructed ``easyocr.Reader``.

    Raises:
        Exception: whatever ``easyocr.Reader`` raises if the CPU fallback
            reader cannot be created either.
    """
    use_gpu = device == 'cuda'
    # Options shared by the primary attempt and the CPU fallback.
    # Reader() takes no recog_batch_size / detector_batch_size arguments;
    # passing them raised TypeError and silently forced the CPU fallback
    # on every GPU start. Batch size is an argument of readtext() instead.
    reader_options = {
        'download_enabled': True,
        'verbose': True,
        'model_storage_directory': MODEL_CACHE_DIR,
    }
    try:
        print("Initializing EasyOCR and loading models...")
        if use_gpu:
            torch.cuda.empty_cache()  # free cached GPU memory before loading models
        reader = easyocr.Reader(['ch_sim', 'en'], gpu=use_gpu, **reader_options)
        print("EasyOCR initialization completed!")
        return reader
    except Exception as e:
        print(f"Error initializing EasyOCR: {str(e)}")
        print("Falling back to CPU mode...")
        # NOTE(review): the module-level `device` is not updated here, so
        # code that checks `device == 'cuda'` still sees 'cuda' after a
        # fallback - confirm whether that is intended.
        return easyocr.Reader(['ch_sim', 'en'], gpu=False, **reader_options)
|
93 |
+
|
94 |
+
# 初始化 reader
|
95 |
+
reader = initialize_easyocr()
|
96 |
|
97 |
def process_image(image):
|
98 |
"""处理上传的图片并返回识别结果和分析"""
|
|
|
124 |
image.save(image_path)
|
125 |
|
126 |
print("开始识别文字...")
|
127 |
+
if device == 'cuda':
|
128 |
+
torch.cuda.empty_cache() # 清理 GPU 内存
|
129 |
+
|
130 |
# 使用 EasyOCR 识别文字
|
131 |
result = reader.readtext(
|
132 |
image_path,
|
133 |
detail=1,
|
134 |
+
paragraph=True,
|
135 |
+
batch_size=8 # 控制批处理大小
|
136 |
)
|
137 |
print("文字识别完成")
|
138 |
|
|
|
140 |
if image_path == "temp_image.png" and os.path.exists(image_path):
|
141 |
os.remove(image_path)
|
142 |
|
143 |
+
# 修改文字提取逻辑
|
144 |
sorted_text = []
|
145 |
+
for item in result:
|
146 |
+
# 检查返回结果的格式
|
147 |
+
if isinstance(item, (list, tuple)):
|
148 |
+
if len(item) >= 2: # 确保至少有 bbox 和 text
|
149 |
+
text = item[1] if len(item) >= 2 else ""
|
150 |
+
prob = item[2] if len(item) >= 3 else 1.0
|
151 |
+
if prob > 0.5: # 只保留置信度大于 0.5 的结果
|
152 |
+
sorted_text.append(text)
|
153 |
+
elif isinstance(item, dict): # 处理可能的字典格式
|
154 |
+
text = item.get('text', '')
|
155 |
+
prob = item.get('confidence', 1.0)
|
156 |
+
if prob > 0.5:
|
157 |
+
sorted_text.append(text)
|
158 |
|
159 |
final_text = ' '.join(sorted_text)
|
160 |
if not final_text.strip():
|
161 |
return "未能识别到清晰的文字,请尝试上传更清晰的图片"
|
162 |
+
|
163 |
+
print(f"识别到的文字: {final_text[:100]}...") # 打印前100个字符用于调试
|
164 |
return final_text
|
165 |
except Exception as e:
|
166 |
print(f"文字识别出错: {str(e)}")
|
167 |
+
import traceback
|
168 |
+
traceback.print_exc() # 打印详细错误信息
|
169 |
return f"图片处理出错: {str(e)}"
|
170 |
|
171 |
def analyze_slide(text):
|
|
|
252 |
if api_key_error:
|
253 |
gr.Markdown(api_key_error)
|
254 |
else:
|
255 |
+
gpu_info = f"GPU (T4)" if device == 'cuda' else "CPU"
|
256 |
+
gr.Markdown(f"# 📚 课程幻灯片理解助手 ({gpu_info} 模式)")
|
257 |
gr.Markdown("上传幻灯片图片,AI 将自动识别内容并提供详细讲解")
|
258 |
|
259 |
# 存储当前识别的文字,用于对话上下文
|
|
|
284 |
gr.Markdown("### 💬 与 AI 助手对话")
|
285 |
chatbot = gr.Chatbot(
|
286 |
label="对话历史",
|
287 |
+
height=400
|
|
|
288 |
)
|
289 |
with gr.Row():
|
290 |
msg = gr.Textbox(
|
|
|
333 |
|
334 |
# 启动应用
|
335 |
if __name__ == "__main__":
    # Keep the number of worker threads small.
    # NOTE(review): share=True is reportedly ignored (with a warning) when
    # the app runs inside a Hugging Face Space - confirm it is still wanted.
    demo.launch(share=True, max_threads=4)
|
requirements.txt
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
huggingface_hub==0.25.2
|
2 |
-
gradio==4.
|
3 |
easyocr>=1.7.1
|
4 |
python-dotenv>=1.0.0
|
5 |
openai>=1.0.0
|
6 |
Pillow>=10.0.0
|
7 |
numpy>=1.24.0
|
8 |
-
torch>=2.0.0
|
|
|
|
1 |
huggingface_hub==0.25.2
|
2 |
+
gradio==4.44.1
|
3 |
easyocr>=1.7.1
|
4 |
python-dotenv>=1.0.0
|
5 |
openai>=1.0.0
|
6 |
Pillow>=10.0.0
|
7 |
numpy>=1.24.0
|
8 |
+
torch>=2.0.0
|
9 |
+
torchvision>=0.15.0
|