# document-vqa-v2 / main.py
# Last commit: dea0d8a ("Update main.py") by MJobe. File size: 2.54 kB.
from io import BytesIO
from PIL import Image
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
from transformers import pipeline
from pytesseract import pytesseract
# Use a pipeline as a high-level helper.
# The model is loaded once, at import time, and shared by every request below.
nlp_qa = pipeline("document-question-answering", model="impira/layoutlm-document-qa")
description = """
## Image-based Document QA
This API extracts text from an uploaded image using OCR and performs document question answering using a LayoutLM-based model.
### Endpoints:
- **POST /uploadfile/:** Upload an image file to extract text and answer provided questions.
- **POST /pdfUpload/:** Provide a file to extract text and answer provided questions.
"""
# The single application instance: interactive API docs are served at "/"
# and show `description` above; routes below register on this object.
app = FastAPI(docs_url="/", description=description)
@app.post("/uploadfile/", description=description)
async def perform_document_qa(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    """Answer comma-separated questions about an uploaded document image.

    Args:
        file: Uploaded image file, in any format Pillow can open.
        questions: Comma-separated questions. Whitespace around each question
            is stripped and empty entries are skipped.

    Returns:
        A dict mapping each question to the model's top answer string (or
        ``None`` when the pipeline produced no candidate). On any failure,
        returns a 500 ``JSONResponse`` carrying the error message.
    """
    try:
        # Read the uploaded file and decode it as an image.
        contents = await file.read()
        image = Image.open(BytesIO(contents))

        question_list = [q.strip() for q in questions.split(",") if q.strip()]

        answers_dict = {}
        for question in question_list:
            # The document-question-answering pipeline expects the *image*
            # as its first argument. With no `word_boxes` it runs Tesseract
            # OCR itself to get the words and boxes LayoutLM needs. A string
            # here would be interpreted as an image path/URL and fail.
            # NOTE(review): OCR is repeated per question, and inference blocks
            # the event loop inside this async handler. Consider caching word
            # boxes or offloading work if latency matters.
            result = nlp_qa(image=image, question=question)
            # The pipeline returns a list of candidate dicts (best first).
            # Tolerate a bare dict as well.
            if isinstance(result, list):
                result = result[0] if result else {}
            answers_dict[question] = result.get("answer")
        return answers_dict
    except Exception as e:
        return JSONResponse(content=f"Error processing file: {str(e)}", status_code=500)
def _text_to_word_boxes(text):
    """Lay plain text out on a synthetic grid as ``(word, box)`` pairs.

    Each non-empty line gets an equal-height horizontal band and each word on
    that line an equal-width slot. Box coordinates are integers in the
    0-1000 range used for the pipeline's ``word_boxes`` input. The layout is
    artificial: it preserves reading order, not real page geometry.
    """
    lines = [line.split() for line in text.splitlines()]
    lines = [words for words in lines if words]
    word_boxes = []
    for row, words in enumerate(lines):
        top = row * 1000 // len(lines)
        bottom = (row + 1) * 1000 // len(lines)
        for col, word in enumerate(words):
            left = col * 1000 // len(words)
            right = (col + 1) * 1000 // len(words)
            word_boxes.append((word, [left, top, right, bottom]))
    return word_boxes


@app.post("/pdfUpload/", description=description)
async def load_file(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    """Answer comma-separated questions about an uploaded UTF-8 text file.

    Despite the route name, only UTF-8 text is accepted. The text is turned
    into synthetic word boxes (see ``_text_to_word_boxes``) and given to the
    LayoutLM pipeline in place of an image.

    Args:
        file: Uploaded file whose bytes must decode as UTF-8.
        questions: Comma-separated questions. Whitespace around each question
            is stripped and empty entries are skipped.

    Returns:
        A dict mapping each (stripped) question to the model's top answer
        (or ``None`` when the pipeline produced no candidate). Returns a 415
        ``JSONResponse`` for non-UTF-8 (e.g. PDF) uploads, a 400 for uploads
        with no words, and a 500 for any other failure.
    """
    try:
        contents = await file.read()
        try:
            text = contents.decode("utf-8")
        except UnicodeDecodeError:
            # TODO: real PDF support needs a PDF text extractor/rasterizer,
            # which this project does not currently depend on.
            return JSONResponse(
                content="Error processing file: expected UTF-8 text; binary uploads such as PDF are not supported",
                status_code=415,
            )

        word_boxes = _text_to_word_boxes(text)
        if not word_boxes:
            return JSONResponse(
                content="Error processing file: no text found in upload",
                status_code=400,
            )

        answers_dict = {}
        for question in (q.strip() for q in questions.split(",")):
            if not question:
                continue
            # No image: the pipeline uses the supplied word boxes instead of
            # running OCR.
            result = nlp_qa(image=None, question=question, word_boxes=word_boxes)
            # The pipeline returns a list of candidate dicts (best first).
            # Tolerate a bare dict as well.
            if isinstance(result, list):
                result = result[0] if result else {}
            answers_dict[question] = result.get("answer")
        return answers_dict
    except Exception as e:
        return JSONResponse(content=f"Error processing file: {str(e)}", status_code=500)