from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import uvicorn
# Create FastAPI app
app = FastAPI()
# Load the tokenizer and model
MODEL_NAME = "facebook/bart-large-cnn"  # BART fine-tuned on CNN/DailyMail for summarization
device = "cuda" if torch.cuda.is_available() else "cpu"  # Use the GPU when one is available
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
# Define input format
class InputText(BaseModel):
    text: str
@app.post("/summarize")
async def summarize_text(input_text: InputText):
    # Tokenize the request text, truncating to BART's 1024-token input limit
    inputs = tokenizer(input_text.text, return_tensors="pt", max_length=1024, truncation=True).to(device)
    # Generate with beam search (length_penalty only takes effect when num_beams > 1)
    with torch.no_grad():
        summary_ids = model.generate(
            inputs.input_ids, attention_mask=inputs.attention_mask,
            max_length=150, min_length=50, length_penalty=2.0, num_beams=4,
        )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}
# Ensure the application starts when running locally
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
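# Example client call for the /summarize endpoint (a sketch; assumes the server is
# running locally on port 7860 and that the `requests` package is installed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/summarize",
#       json={"text": "Long article text to summarize goes here..."},
#   )
#   print(resp.json()["summary"])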