pdfreader / app.py
Shankarm08's picture
Create app.py
7340eaf verified
raw
history blame
1.32 kB
import gradio as gr
import torch
from transformers import BertTokenizer, BertModel
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# ASGI application object; the /classify route below is registered on it.
app = FastAPI()
class TextClassificationRequest(BaseModel):
    """JSON request body accepted by the ``/classify`` endpoint."""

    # The text that the endpoint tokenizes and feeds to the model.
    text: str
_MODEL_NAME = "bert-base-uncased"
# BERT's position embeddings cover at most 512 tokens (special tokens included).
_MAX_TOKENS = 512
# model name -> (tokenizer, model); filled lazily so import stays cheap.
_bert_cache: dict = {}


def _get_bert(model_name: str = _MODEL_NAME):
    """Return a cached ``(tokenizer, model)`` pair for *model_name*.

    Loading the weights is expensive, so it is done once per process on
    first use instead of on every request.
    """
    if model_name not in _bert_cache:
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = BertModel.from_pretrained(model_name)
        # Inference only: make sure dropout and friends are disabled.
        model.eval()
        _bert_cache[model_name] = (tokenizer, model)
    return _bert_cache[model_name]


@app.post("/classify")
async def classify_text(request: TextClassificationRequest):
    """Embed ``request.text`` with BERT and return its ``[CLS]`` features.

    Args:
        request: Request body carrying the text to embed.

    Returns:
        A dict ``{"features": [[...]]}`` holding the last-layer hidden state
        at position 0 (the ``[CLS]`` token) as a nested list of floats, one
        row per input sequence.
    """
    tokenizer, model = _get_bert()

    # truncation=True is required for max_length to take effect; without it
    # inputs longer than the model's limit are passed through and the
    # forward pass fails.
    inputs = tokenizer(
        request.text,
        add_special_tokens=True,
        max_length=_MAX_TOKENS,
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    # No gradients are needed for feature extraction; skipping autograd
    # bookkeeping saves memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # Position 0 of the sequence axis is the [CLS] token.
    features = outputs.last_hidden_state[:, 0, :]
    return {"features": features.tolist()}
# Build a Gradio UI that calls the same handler registered for POST /classify.
# NOTE(review): classify_text is declared to take a TextClassificationRequest
# and reads `request.text`; Gradio calls `fn` with the value produced by the
# input component instead, so this wiring looks broken — confirm and add an
# adapter that turns the upload into text.
# NOTE(review): "pdf" does not appear to be a built-in Gradio component
# shortcut (file uploads normally use "file" / gr.File) — verify against the
# installed Gradio version.
interface = gr.Interface(
fn=classify_text,
inputs="pdf",
outputs="text",
title="PDF Text Classification",
description="Upload a PDF file to classify its text"
)
# Launch the interface.
# NOTE(review): this runs at import time, so importing this module (e.g. to
# serve `app` with an ASGI server) would also start the Gradio server —
# confirm that is intended.
interface.launch()