Spaces:
Sleeping
Sleeping
import base64
import logging
import os
import tempfile
from datetime import datetime

import torch
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModel, AutoTokenizer
# Logging setup: INFO level; each record shows timestamp, level name and message.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The ASGI application that the middleware and handlers below attach to.
app = FastAPI()
# CORS: accept cross-origin requests from any origin, with any method and any header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is unusual.
# Browsers reject a literal "*" origin on credentialed requests, and Starlette
# works around this by echoing the caller's Origin back. Confirm that credentials
# (cookies/auth) are actually needed here; if not, consider allow_credentials=False.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Model initialisation: load the tokenizer and model once at import time.
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
model_name = "Mageia/GOT-OCR2_0"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# trust_remote_code=True lets transformers run the model code shipped in the hub
# repo. That code presumably provides the chat/chat_crop methods used by the OCR
# handler. The model is switched to eval mode and moved to the selected device.
model = (
    AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map=device)
    .eval()
    .to(device)
)
# OCR processing function
async def ocr_process(image_path, got_mode, ocr_color="", ocr_box=""):
    """Run GOT-OCR on an image file according to ``got_mode``.

    ``got_mode`` is matched by substring:
      - "plain" selects plain-text OCR (``ocr_type="ocr"``).
      - "format" selects formatted OCR, rendered to an HTML file.
      - "multi-crop" additionally switches to ``model.chat_crop``.
    "plain" is checked before "format".

    Args:
        image_path: Path of the image on disk, passed straight to the model.
        got_mode: Mode string as described above.
        ocr_color: Forwarded as ``ocr_color`` to ``model.chat``. The multi-crop
            path does not use it.
        ocr_box: Forwarded as ``ocr_box`` to ``model.chat``. The multi-crop path
            does not use it.

    Returns:
        Plain modes: the value returned by ``model.chat`` / ``model.chat_crop``.
        Format modes: ``{"html_content": <base64 of the rendered UTF-8 HTML>}``.
        Unknown mode, missing render output, or any exception:
        ``{"error": <message>}``.
    """
    # NOTE: model.chat / model.chat_crop are called without ``await``. Inference
    # therefore runs on the event-loop thread and blocks other requests until it
    # finishes.
    try:
        if "plain" in got_mode:
            if "multi-crop" in got_mode:
                return model.chat_crop(tokenizer, image_path, ocr_type="ocr")
            return model.chat(
                tokenizer,
                image_path,
                ocr_type="ocr",
                ocr_box=ocr_box,
                ocr_color=ocr_color,
            )

        if "format" in got_mode:
            # Path next to the input image where the model is asked to save its
            # rendered HTML (save_render_file).
            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
            try:
                if "multi-crop" in got_mode:
                    model.chat_crop(
                        tokenizer,
                        image_path,
                        ocr_type="format",
                        render=True,
                        save_render_file=result_path,
                    )
                else:
                    model.chat(
                        tokenizer,
                        image_path,
                        ocr_type="format",
                        ocr_box=ocr_box,
                        ocr_color=ocr_color,
                        render=True,
                        save_render_file=result_path,
                    )
                if not os.path.exists(result_path):
                    # Report the missing render explicitly, instead of falling
                    # through to the misleading "unknown mode" error below.
                    return {"error": "未生成格式化渲染结果"}
                with open(result_path, "r", encoding="utf-8") as f:
                    html_content = f.read()
            finally:
                # Delete the render file so results do not accumulate on disk.
                try:
                    os.remove(result_path)
                except FileNotFoundError:
                    pass
            encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
            return {"html_content": encoded_html}

        return {"error": "未知的OCR模式"}
    except Exception as e:
        # Keep the API contract (return an error dict, not raise), but record the
        # failure and its traceback in the server log.
        logger.exception("OCR处理失败")
        return {"error": str(e)}
async def ocr_api(request: Request, image: UploadFile = File(...), got_mode: str = Form(...), ocr_color: str = Form(""), ocr_box: str = Form("")):
    """Handle an OCR upload: log the request, run OCR on the image, return the result.

    NOTE(review): no route decorator is visible for this handler here. It is
    presumably registered elsewhere (e.g. ``@app.post(...)``); confirm.

    Args:
        request: Incoming request. Only its client address and User-Agent are
            used, for logging.
        image: Uploaded image file.
        got_mode: OCR mode string, forwarded to ``ocr_process``.
        ocr_color: Optional colour hint, forwarded to ``ocr_process``.
        ocr_box: Optional box hint, forwarded to ``ocr_process``.

    Returns:
        Whatever ``ocr_process`` returns: plain OCR output, an
        ``{"html_content": ...}`` dict, or an ``{"error": ...}`` dict.
    """
    # Record request information. request.client is None when the ASGI scope
    # carries no client address.
    client_host = request.client.host if request.client else "Unknown"
    user_agent = request.headers.get("user-agent", "Unknown")
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_message = f"""
时间: {current_time}
IP地址: {client_host}
User-Agent: {user_agent}
图片名称: {image.filename}
OCR模式: {got_mode}
OCR颜色: {ocr_color}
OCR边界框: {ocr_box}
"""
    logger.info(log_message)

    # Save the upload to a uniquely named temp file. The client-supplied filename
    # contributes only a short alphanumeric extension. Building the path from the
    # full filename let concurrent uploads with the same name overwrite or delete
    # each other's files, and let the client influence the on-disk path.
    suffix = os.path.splitext(image.filename or "")[1]
    if len(suffix) > 16 or not suffix[1:].isalnum():
        suffix = ""
    fd, image_path = tempfile.mkstemp(prefix="temp_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as buffer:
            buffer.write(await image.read())
        # Run OCR
        result = await ocr_process(image_path, got_mode, ocr_color, ocr_box)
    finally:
        # Always delete the temporary image, even if writing or OCR raised.
        os.remove(image_path)

    # Record the processing result
    logger.info(f"OCR处理结果: {result}")
    return result
async def read_root(request: Request):
    """Log a visit to the root path and return service metadata.

    NOTE(review): no route decorator is visible for this handler here. It is
    presumably registered elsewhere (e.g. ``@app.get("/")``); confirm.

    Returns:
        A dict containing:
          - a welcome message;
          - the caller's User-Agent;
          - the model name and device;
          - the list of advertised OCR mode strings.
    """
    # request.client is None when the ASGI scope carries no client address.
    client_host = request.client.host if request.client else "Unknown"
    user_agent = request.headers.get("user-agent", "Unknown")
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_message = f"""
时间: {current_time}
IP地址: {client_host}
User-Agent: {user_agent}
访问: 根路径
"""
    logger.info(log_message)
    return {
        "message": "欢迎使用OCR API",
        "user_agent": user_agent,
        "model": model_name,
        "device": device,
        "ocr_mode": [
            "plain texts OCR",
            "format texts OCR",
            "plain multi-crop OCR",
            "format multi-crop OCR",
            "plain fine-grained OCR",
            "format fine-grained OCR",
        ],
    }