File size: 4,614 Bytes
356330f
8f17baa
aceb54a
 
9719116
aceb54a
 
 
 
 
 
 
 
 
 
 
 
 
 
5f9cb8e
dd6ac97
aceb54a
 
 
9719116
 
 
 
 
 
 
 
 
aceb54a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82adbca
aceb54a
 
82adbca
aceb54a
 
 
 
82adbca
aceb54a
82adbca
aceb54a
82adbca
486fb7a
 
 
 
 
 
 
 
 
 
aceb54a
 
486fb7a
 
 
aceb54a
82adbca
aceb54a
 
 
 
 
 
 
 
 
82adbca
da60b63
aceb54a
 
 
 
 
05f8689
aceb54a
 
 
 
 
 
 
 
 
 
 
 
 
 
8f17baa
 
aceb54a
 
 
 
82adbca
da60b63
aceb54a
82adbca
8f17baa
 
 
356330f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from fastapi import FastAPI, File, UploadFile, HTTPException, Request
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
import fitz  # PyMuPDF
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import numpy as np
import os
import base64
import io
import uuid
import tempfile
import time
import shutil
from pathlib import Path
import json
from starlette.requests import Request
import uvicorn
from bs4 import BeautifulSoup

app = FastAPI()

# CORS is left fully open (any origin, method and header, with credentials)
# to keep this demo service simple.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended before exposing the service publicly.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)

# Load the GOT-OCR2.0 tokenizer and model once, at import time, so every
# request reuses the same weights. trust_remote_code=True lets transformers
# run the model code published in the checkpoint repository.
_GOT_CHECKPOINT = 'ucaslcl/GOT-OCR2_0'
tokenizer = AutoTokenizer.from_pretrained(_GOT_CHECKPOINT, trust_remote_code=True)
model = AutoModel.from_pretrained(
    _GOT_CHECKPOINT,
    trust_remote_code=True,
    device_map='cuda',
    use_safetensors=True,
)
# Put the model in eval mode and make sure it lives on the GPU.
model = model.eval().cuda()

# Working folders for uploaded/rasterised inputs and rendered OCR output.
UPLOAD_FOLDER = "./uploads"
RESULTS_FOLDER = "./results"

# Ensure directories exist. exist_ok=True makes this idempotent and free of
# the check-then-create race when several workers start at the same time.
for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
    os.makedirs(folder, exist_ok=True)

def image_to_base64(image):
    """Encode an image as PNG and return the bytes as a base64 ASCII string.

    Args:
        image: Object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64 text of the PNG-encoded image (no data-URI prefix).
    """
    with io.BytesIO() as sink:
        image.save(sink, format="PNG")
        png_bytes = sink.getvalue()
    return base64.b64encode(png_bytes).decode()

def pdf_to_images(pdf_path):
    """Rasterise every page of a PDF into an RGB PIL image.

    Args:
        pdf_path: Filesystem path of the PDF to open with PyMuPDF.

    Returns:
        A list of PIL images, one per page in page order, rendered with
        ``page.get_pixmap()``'s default settings.
    """
    images = []
    # The context manager closes the document (and its file handle) even if
    # rendering a page fails. Image.frombytes copies the pixel data, so the
    # images stay valid after the document is closed.
    with fitz.open(pdf_path) as pdf_document:
        for page in pdf_document:
            pix = page.get_pixmap()
            images.append(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))
    return images

def _ocr_page(image, unique_id, page_number):
    """OCR a single page image with GOT and build its result entry.

    The page is written to UPLOAD_FOLDER as PNG and passed to
    ``model.chat_crop`` with ``ocr_type='format'`` and ``render=True``. The
    rendered HTML that chat_crop saves into RESULTS_FOLDER is read back and
    embedded as a base64 data URI.

    Args:
        image: Page image exposing a PIL-style ``save(path)`` method.
        unique_id: Per-request identifier used to namespace the temp files.
        page_number: 1-based page number.

    Returns:
        A dict with keys "page_number", "text" (raw chat_crop output),
        "html" (an <iframe> showing the rendered HTML) and "download_link"
        (an <a download> link to the same HTML).

    Both intermediate files are deleted before returning, including when
    saving, OCR, or reading the rendered HTML raises.
    """
    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}_page_{page_number}.png")
    result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}_page_{page_number}.html")
    try:
        image.save(image_path)
        res = model.chat_crop(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)

        # Read the rendered HTML content.
        # NOTE(review): opened with the platform default encoding because the
        # renderer's output encoding is not visible here — confirm they match.
        with open(result_path, 'r') as f:
            html_content = f.read()
    finally:
        for temp_path in (image_path, result_path):
            if os.path.exists(temp_path):
                os.remove(temp_path)

    # Encode the HTML content to base64 so the result is self-contained in the
    # response (the rendered file on disk has already been deleted).
    encoded_html = base64.b64encode(html_content.encode('utf-8')).decode('utf-8')
    iframe_src = f"data:text/html;base64,{encoded_html}"
    iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
    download_link = f'<a href="data:text/html;base64,{encoded_html}" download="result_{unique_id}_page_{page_number}.html">Download Full Result</a>'

    return {
        "page_number": page_number,
        "text": res,  # Directly use the output from model.chat_crop
        "html": iframe,
        "download_link": download_link,
    }


def run_GOT(pdf_file):
    """Run GOT-OCR2.0 on every page of a PDF and collect the rendered results.

    The PDF is copied into UPLOAD_FOLDER under a fresh UUID name, rasterised
    with ``pdf_to_images``, and each page is OCR'd by ``_ocr_page``.

    Args:
        pdf_file: Path to the source PDF. It is copied, never modified.

    Returns:
        ``(json_string, results)`` on success. ``results`` is a list of dicts
        with keys "page_number", "text", "html" and "download_link", and
        ``json_string`` is that list serialised with ``json.dumps(indent=4)``.
        If OCR of any page raises, returns ``("Error: <message>", None)``
        instead.

    Raises:
        Any exception from copying or rasterising the PDF propagates. The
        working copy of the PDF is still removed.
    """
    unique_id = str(uuid.uuid4())
    pdf_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.pdf")
    shutil.copy(pdf_file, pdf_path)

    try:
        # Rasterisation errors propagate unchanged. Only page OCR errors are
        # converted into the ("Error: ...", None) return value below.
        images = pdf_to_images(pdf_path)
        results = []
        try:
            for page_number, image in enumerate(images, start=1):
                results.append(_ocr_page(image, unique_id, page_number))
        except Exception as e:
            return f"Error: {str(e)}", None
    finally:
        # Always remove the working copy, including when rasterisation fails.
        if os.path.exists(pdf_path):
            os.remove(pdf_path)

    return json.dumps(results, indent=4), results

def cleanup_old_files(max_age_seconds=3600, folders=None):
    """Delete stale files from the working folders.

    Args:
        max_age_seconds: Files last modified more than this many seconds ago
            are removed. Defaults to 3600 (one hour).
        folders: Iterable of directories to sweep. Defaults to UPLOAD_FOLDER
            and RESULTS_FOLDER.

    Only regular files directly inside each folder are deleted; subdirectories
    are left alone. A file that disappears between listing and deletion (for
    example, removed concurrently by a request's own cleanup) is skipped
    instead of aborting the sweep.
    """
    if folders is None:
        folders = (UPLOAD_FOLDER, RESULTS_FOLDER)
    cutoff = time.time() - max_age_seconds
    for folder in folders:
        for file_path in Path(folder).glob('*'):
            try:
                if file_path.is_file() and file_path.stat().st_mtime < cutoff:
                    file_path.unlink()
            except FileNotFoundError:
                continue

# Remove upload/result files older than one hour from the working folders.
cleanup_old_files()

# Serve the contents of ./static under the /static URL prefix.
_static_files = StaticFiles(directory="static")
app.mount("/static", _static_files, name="static")

# HTML templates are loaded from ./templates.
templates = Jinja2Templates(directory="templates")

@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render and return the ``index.html`` template for the root URL."""
    context = {"request": request}
    return templates.TemplateResponse(name="index.html", context=context)

@app.post("/uploadfile/", response_class=JSONResponse)
async def upload_file(file: UploadFile = File(...)):
    """Accept an uploaded PDF, OCR every page and return per-page results.

    Returns:
        The list of per-page result dicts produced by ``run_GOT``.

    Raises:
        HTTPException: 500 when ``run_GOT`` reports an OCR failure. It returns
            an error string instead of results in that case.
    """
    # The client-supplied filename is deliberately not used as a path: joining
    # it into the temp dir would allow traversal (e.g. "../../x") and fails
    # when it is None. run_GOT copies the file to a UUID-named path anyway.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_pdf_path = os.path.join(temp_dir, "upload.pdf")
        with open(temp_pdf_path, "wb") as buffer:
            buffer.write(await file.read())
        # NOTE(review): run_GOT does blocking model inference and file I/O on
        # the event-loop thread, so other requests wait until it finishes.
        json_output, results = run_GOT(temp_pdf_path)

    if results is None:
        raise HTTPException(status_code=500, detail=json_output)
    return results

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)