# GOT-OCR-API / app.py
# Last commit by Mageia: f99597a (unverified)
# feat(app): 添加 API 根目录信息 (add API root-path information)
import base64
import logging
import os
import tempfile
from datetime import datetime

import torch
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModel, AutoTokenizer

# --- Logging: INFO level, each record prefixed with timestamp and level. ---
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

app = FastAPI()

# --- CORS: any origin, any method, any header; credentials allowed. ---
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# --- Model: Hub model id, loaded onto the GPU when CUDA is available. ---
model_name = "Mageia/GOT-OCR2_0"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): trust_remote_code=True runs model code downloaded from the
# Hub; consider pinning a `revision` so that code cannot change unnoticed.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = (
    AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map=device)
    .eval()
    .to(device)
)
# OCR processing function
async def ocr_process(image_path, got_mode, ocr_color="", ocr_box=""):
    """Run GOT-OCR on an image file in the requested mode.

    ``got_mode`` is matched by substring. A mode containing "plain" runs text
    OCR (``ocr_type="ocr"``). Otherwise, a mode containing "format" runs
    formatted OCR (``ocr_type="format"``) with HTML rendering. Both kinds use
    ``model.chat_crop`` when the mode also contains "multi-crop", and
    ``model.chat`` (which also receives ``ocr_box``/``ocr_color``) otherwise.

    Args:
        image_path: Path to the image file on local disk.
        got_mode: Mode label, e.g. "plain texts OCR" or "format multi-crop OCR".
        ocr_color: Passed to ``model.chat`` as ``ocr_color``; not passed for multi-crop.
        ocr_box: Passed to ``model.chat`` as ``ocr_box``; not passed for multi-crop.

    Returns:
        Plain modes: the value returned by the model call, unchanged.
        Format modes: ``{"html_content": <base64 of the UTF-8 rendered HTML>}``.
        Otherwise ``{"error": <message>}``. This covers an unrecognised mode,
        a render file that was not produced, and any exception raised while
        processing.
    """
    # NOTE(review): the model calls are synchronous (not awaited), so they
    # block the event loop, and every other request, for the whole inference.
    try:
        multi_crop = "multi-crop" in got_mode

        if "plain" in got_mode:
            if multi_crop:
                return model.chat_crop(tokenizer, image_path, ocr_type="ocr")
            return model.chat(
                tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color
            )

        if "format" in got_mode:
            # The rendered HTML is requested at a path next to the input image.
            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
            try:
                if multi_crop:
                    model.chat_crop(
                        tokenizer,
                        image_path,
                        ocr_type="format",
                        render=True,
                        save_render_file=result_path,
                    )
                else:
                    model.chat(
                        tokenizer,
                        image_path,
                        ocr_type="format",
                        ocr_box=ocr_box,
                        ocr_color=ocr_color,
                        render=True,
                        save_render_file=result_path,
                    )
                if not os.path.exists(result_path):
                    # The mode is valid; report the missing render output
                    # instead of the misleading "unknown mode" error.
                    return {"error": "渲染结果文件未生成"}
                with open(result_path, "r", encoding="utf-8") as f:
                    html_content = f.read()
            finally:
                # The render file is only needed to build this response.
                # Remove it (also after a failure) so files do not pile up.
                if os.path.exists(result_path):
                    os.remove(result_path)
            encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
            return {"html_content": encoded_html}

        return {"error": "未知的OCR模式"}
    except Exception as e:
        # Boundary for the HTTP endpoint: keep the traceback in the logs and
        # report the error to the client in the response body.
        logger.exception("OCR处理失败")
        return {"error": str(e)}
@app.post("/ocr")
async def ocr_api(
    request: Request,
    image: UploadFile = File(...),
    got_mode: str = Form(...),
    ocr_color: str = Form(""),
    ocr_box: str = Form(""),
):
    """Accept an uploaded image, run OCR on it and return the result.

    Form fields:
        image: The image file to recognise.
        got_mode: OCR mode label, passed through to ``ocr_process``.
        ocr_color: Optional colour argument, passed through to ``ocr_process``.
        ocr_box: Optional box argument, passed through to ``ocr_process``.

    Returns:
        Whatever ``ocr_process`` returns for the given mode.
    """
    # Log request metadata. request.client may be None (no peer address in
    # the ASGI scope), so fall back to "Unknown" instead of crashing.
    client_host = request.client.host if request.client else "Unknown"
    user_agent = request.headers.get("user-agent", "Unknown")
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_message = f"""
时间: {current_time}
IP地址: {client_host}
User-Agent: {user_agent}
图片名称: {image.filename}
OCR模式: {got_mode}
OCR颜色: {ocr_color}
OCR边界框: {ocr_box}
"""
    logger.info(log_message)

    # Save the upload inside a private temporary directory instead of a path
    # built from the client-supplied filename. That filename may contain path
    # separators, and concurrent uploads with the same name would overwrite
    # each other. Only its extension is kept.
    # ocr_process derives its render-file path from image_path, so that file
    # lands in the same directory. The directory and everything in it are
    # removed on exit, even if processing raises.
    suffix = os.path.splitext(os.path.basename(image.filename or ""))[1]
    with tempfile.TemporaryDirectory() as workdir:
        image_path = os.path.join(workdir, f"upload{suffix}")
        with open(image_path, "wb") as buffer:
            buffer.write(await image.read())
        result = await ocr_process(image_path, got_mode, ocr_color, ocr_box)

    # Log the processing result
    logger.info(f"OCR处理结果: {result}")
    return result
@app.get("/")
async def read_root(request: Request):
    """Describe the API and log the access.

    Returns:
        A dict with a welcome message, the caller's User-Agent, the model id,
        the device the model was loaded on, and the supported OCR mode labels.
    """
    # request.client may be None (no peer address in the ASGI scope), so fall
    # back to "Unknown" instead of raising AttributeError.
    client_host = request.client.host if request.client else "Unknown"
    user_agent = request.headers.get("user-agent", "Unknown")
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_message = f"""
时间: {current_time}
IP地址: {client_host}
User-Agent: {user_agent}
访问: 根路径
"""
    logger.info(log_message)
    return {
        "message": "欢迎使用OCR API",
        "user_agent": user_agent,
        "model": model_name,
        "device": device,
        "ocr_mode": [
            "plain texts OCR",
            "format texts OCR",
            "plain multi-crop OCR",
            "format multi-crop OCR",
            "plain fine-grained OCR",
            "format fine-grained OCR",
        ],
    }