# document-vqa-v2 / main.py
# Author: MJobe — commit 2181fee ("Update main.py"), 2.66 kB
import logging
from io import BytesIO
from typing import List

import fitz
import requests
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError
from pytesseract import pytesseract
from transformers import pipeline
# Markdown shown on the interactive docs page and on each endpoint.
description = """
## Image-based Document QA
This API extracts text from an uploaded image using OCR and performs document question answering using a BERT-based model.
### Endpoints:
- **POST /uploadfile/:** Upload an image file to extract text and answer provided questions.
- **POST /pdfUpload/:** Provide a file to extract text and answer provided questions.
"""

# The single application instance; the interactive docs are served at "/".
# (Previously a bare FastAPI() was created first and then silently replaced.)
app = FastAPI(docs_url="/", description=description)

# Load a BERT-based extractive question-answering pipeline once at import
# time; every request handler shares this module-level instance.
nlp_qa = pipeline('question-answering', model='bert-large-uncased-whole-word-masking-finetuned-squad')
@app.post("/uploadfile/", description=description)
async def perform_document_qa(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    """Answer comma-separated questions about the text OCR'd from an uploaded image.

    Args:
        file: Uploaded image, in any format Pillow can open.
        questions: Questions separated by commas. Surrounding whitespace is
            stripped and empty entries (e.g. from a trailing comma) are skipped.

    Returns:
        A dict mapping each stripped question to the ``'answer'`` field returned
        by ``nlp_qa``. On failure, a ``JSONResponse`` carrying an error string:
        400 if the upload is not a readable image, 422 if OCR finds no text,
        500 for any other error.
    """
    try:
        contents = await file.read()

        try:
            # `with` closes the image once OCR has consumed it.
            with Image.open(BytesIO(contents)) as image:
                text_content = pytesseract.image_to_string(image)
        except UnidentifiedImageError:
            # Bad input is the client's fault, not a server error.
            return JSONResponse(
                content="Error processing file: uploaded file is not a readable image",
                status_code=400,
            )

        # The QA pipeline rejects an empty context; report that explicitly
        # instead of surfacing the pipeline's internal error as a 500.
        if not text_content.strip():
            return JSONResponse(
                content="Error processing file: no text could be extracted from the image",
                status_code=422,
            )

        # Skip blank entries so input like "a, b," does not fail the request
        # (the pipeline also rejects an empty question).
        question_list = [q.strip() for q in questions.split(",") if q.strip()]

        # NOTE(review): OCR and model inference are blocking calls inside an
        # async handler, so they stall the event loop while they run.
        answers_dict = {}
        for question in question_list:
            result = nlp_qa({"question": question, "context": text_content})
            answers_dict[question] = result["answer"]
        return answers_dict
    except Exception as e:
        # Top-level boundary: keep the traceback in the logs, return a 500.
        logging.getLogger(__name__).exception("Document QA failed for /uploadfile/")
        return JSONResponse(content=f"Error processing file: {str(e)}", status_code=500)
@app.post("/pdfUpload/", description=description)
async def load_file(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    """Answer comma-separated questions about the text of an uploaded PDF.

    Uploads containing a ``%PDF`` header in their first 1 KiB have their text
    extracted page by page with PyMuPDF (``fitz``). Any other upload is decoded
    as UTF-8 text, which is what this endpoint previously did for every file.

    Args:
        file: Uploaded PDF, or plain UTF-8 text.
        questions: Questions separated by commas. Surrounding whitespace is
            stripped and empty entries are skipped.

    Returns:
        A dict mapping each stripped question to the ``'answer'`` field returned
        by ``nlp_qa``. On failure, a ``JSONResponse`` carrying an error string:
        422 if no text could be extracted, 500 for any other error (including
        a corrupt PDF or non-UTF-8 text).
    """
    try:
        contents = await file.read()

        # Build the context once rather than re-decoding it for every question.
        # PDF readers look for the header within the first 1024 bytes, so do
        # the same here.
        if b"%PDF" in contents[:1024]:
            with fitz.open(stream=contents, filetype="pdf") as doc:
                context = "\n".join(page.get_text() for page in doc)
        else:
            # Non-PDF uploads keep the original behavior: treat them as UTF-8 text.
            context = contents.decode("utf-8")

        # The QA pipeline rejects an empty context (e.g. a scanned PDF with
        # no text layer); report it instead of returning a generic 500.
        if not context.strip():
            return JSONResponse(
                content="Error processing file: no text could be extracted from the file",
                status_code=422,
            )

        # Strip and drop blank questions; key answers by the stripped text,
        # matching /uploadfile/.
        question_list = [q.strip() for q in questions.split(",") if q.strip()]

        answers_dict = {}
        for question in question_list:
            result = nlp_qa({"question": question, "context": context})
            answers_dict[question] = result["answer"]
        return answers_dict
    except Exception as e:
        # Top-level boundary: keep the traceback in the logs, return a 500.
        logging.getLogger(__name__).exception("Document QA failed for /pdfUpload/")
        return JSONResponse(content=f"Error processing file: {str(e)}", status_code=500)