SonyaX20 commited on
Commit
a284143
·
1 Parent(s): cc4ded4
Files changed (2) hide show
  1. app.py +51 -48
  2. requirements.txt +1 -2
app.py CHANGED
@@ -30,66 +30,72 @@ except Exception as e:
30
  print(f"Error initializing OpenAI client: {str(e)}")
31
  raise
32
 
 
 
 
33
  # 设置模型下载目录
34
  MODEL_CACHE_DIR = os.path.join(os.path.dirname(__file__), 'models')
35
  os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
36
 
37
- # 检查 GPU 环境
38
  def check_gpu():
39
  try:
40
  if torch.cuda.is_available():
41
- # 获取 GPU 信息
42
- gpu_name = torch.cuda.get_device_name(0)
43
- gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3 # 转换为 GB
44
- print(f"Found GPU: {gpu_name} with {gpu_memory:.2f}GB memory")
45
  return 'cuda'
46
- else:
47
- print("No CUDA GPU available")
48
- return 'cpu'
49
- except Exception as e:
50
- print(f"Error checking GPU: {str(e)}")
51
- return 'cpu'
52
 
53
  # 初始化设备
54
  device = check_gpu()
55
  print(f"Running on device: {device}")
56
 
57
- # 初始化 EasyOCR(针对 T4 GPU 优化)
58
- def initialize_easyocr():
59
  try:
60
- print("Initializing EasyOCR and loading models...")
61
- if device == 'cuda':
62
- # 为 T4 GPU 设置较小的批处理大小和内存限制
63
- torch.cuda.empty_cache() # 清理 GPU 内存
64
- reader = easyocr.Reader(
65
- ['ch_sim', 'en'],
66
- gpu=True,
67
- download_enabled=True,
68
- verbose=True,
69
- model_storage_directory=MODEL_CACHE_DIR,
70
- recog_batch_size=8, # 减小批处理大小
71
- detector_batch_size=2
72
- )
73
- else:
74
- reader = easyocr.Reader(
75
  ['ch_sim', 'en'],
76
  gpu=False,
 
77
  download_enabled=True,
78
- verbose=True,
79
- model_storage_directory=MODEL_CACHE_DIR
80
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  print("EasyOCR initialization completed!")
82
  return reader
83
  except Exception as e:
84
  print(f"Error initializing EasyOCR: {str(e)}")
85
- print("Falling back to CPU mode...")
86
- return easyocr.Reader(
87
- ['ch_sim', 'en'],
88
- gpu=False,
89
- download_enabled=True,
90
- verbose=True,
91
- model_storage_directory=MODEL_CACHE_DIR
92
- )
93
 
94
  # 初始化 reader
95
  reader = initialize_easyocr()
@@ -124,15 +130,10 @@ def extract_text_from_image(image):
124
  image.save(image_path)
125
 
126
  print("开始识别文字...")
127
- if device == 'cuda':
128
- torch.cuda.empty_cache() # 清理 GPU 内存
129
-
130
- # 使用 EasyOCR 识别文字
131
  result = reader.readtext(
132
  image_path,
133
  detail=1,
134
- paragraph=True,
135
- batch_size=8 # 控制批处理大小
136
  )
137
  print("文字识别完成")
138
 
@@ -252,8 +253,7 @@ with gr.Blocks(title="课程幻灯片理解助手") as demo:
252
  if api_key_error:
253
  gr.Markdown(api_key_error)
254
  else:
255
- gpu_info = f"GPU (T4)" if device == 'cuda' else "CPU"
256
- gr.Markdown(f"# 📚 课程幻灯片理解助手 ({gpu_info} 模式)")
257
  gr.Markdown("上传幻灯片图片,AI 将自动识别内容并提供详细讲解")
258
 
259
  # 存储当前识别的文字,用于对话上下文
@@ -333,5 +333,8 @@ with gr.Blocks(title="课程幻灯片理解助手") as demo:
333
 
334
  # 启动应用
335
  if __name__ == "__main__":
336
- # 设置较小的并行处理数
337
- demo.launch(share=True, max_threads=4)
 
 
 
 
30
  print(f"Error initializing OpenAI client: {str(e)}")
31
  raise
32
 
33
+ # 设置环境变量以禁用 CUDA 警告
34
+ os.environ['CUDA_VISIBLE_DEVICES'] = ''
35
+
36
  # 设置模型下载目录
37
  MODEL_CACHE_DIR = os.path.join(os.path.dirname(__file__), 'models')
38
  os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
39
 
40
+ # 简化 GPU 检查
41
  def check_gpu():
42
  try:
43
  if torch.cuda.is_available():
 
 
 
 
44
  return 'cuda'
45
+ except:
46
+ pass
47
+ return 'cpu'
 
 
 
48
 
49
  # 初始化设备
50
  device = check_gpu()
51
  print(f"Running on device: {device}")
52
 
53
+ # 预下载模型
54
+ def download_models():
55
  try:
56
+ print("Checking for pre-downloaded models...")
57
+ model_files = [
58
+ os.path.join(MODEL_CACHE_DIR, 'craft_mlt_25k.pth'),
59
+ os.path.join(MODEL_CACHE_DIR, 'chinese_sim.pth'),
60
+ os.path.join(MODEL_CACHE_DIR, 'english_g2.pth')
61
+ ]
62
+
63
+ all_models_exist = all(os.path.exists(f) for f in model_files)
64
+ if not all_models_exist:
65
+ print("Some models need to be downloaded...")
66
+ # 强制在 CPU 模式下下载模型
67
+ temp_reader = easyocr.Reader(
 
 
 
68
  ['ch_sim', 'en'],
69
  gpu=False,
70
+ model_storage_directory=MODEL_CACHE_DIR,
71
  download_enabled=True,
72
+ verbose=True
 
73
  )
74
+ print("Model download completed")
75
+ else:
76
+ print("All models already downloaded")
77
+ except Exception as e:
78
+ print(f"Error during model download: {str(e)}")
79
+
80
+ # 下载模型
81
+ download_models()
82
+
83
+ # 初始化 EasyOCR
84
+ def initialize_easyocr():
85
+ try:
86
+ print("Initializing EasyOCR...")
87
+ reader = easyocr.Reader(
88
+ ['ch_sim', 'en'],
89
+ gpu=False, # 强制使用 CPU 模式
90
+ model_storage_directory=MODEL_CACHE_DIR,
91
+ download_enabled=False, # 禁用自动下载
92
+ verbose=True
93
+ )
94
  print("EasyOCR initialization completed!")
95
  return reader
96
  except Exception as e:
97
  print(f"Error initializing EasyOCR: {str(e)}")
98
+ raise
 
 
 
 
 
 
 
99
 
100
  # 初始化 reader
101
  reader = initialize_easyocr()
 
130
  image.save(image_path)
131
 
132
  print("开始识别文字...")
 
 
 
 
133
  result = reader.readtext(
134
  image_path,
135
  detail=1,
136
+ paragraph=True
 
137
  )
138
  print("文字识别完成")
139
 
 
253
  if api_key_error:
254
  gr.Markdown(api_key_error)
255
  else:
256
+ gr.Markdown("# 📚 课程幻灯片理解助手")
 
257
  gr.Markdown("上传幻灯片图片,AI 将自动识别内容并提供详细讲解")
258
 
259
  # 存储当前识别的文字,用于对话上下文
 
333
 
334
  # 启动应用
335
  if __name__ == "__main__":
336
+ demo.launch(
337
+ share=True,
338
+ max_threads=4,
339
+ show_error=True
340
+ )
requirements.txt CHANGED
@@ -5,5 +5,4 @@ python-dotenv>=1.0.0
5
  openai>=1.0.0
6
  Pillow>=10.0.0
7
  numpy>=1.24.0
8
- torch>=2.0.0
9
- torchvision>=0.15.0
 
5
  openai>=1.0.0
6
  Pillow>=10.0.0
7
  numpy>=1.24.0
8
+ torch>=2.0.0