# Source: app.py by ADKU, "Create app.py" (commit b2c1b30, verified, 607 bytes).
import os

import torch
from fastapi import FastAPI
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
app = FastAPI()

# Hugging Face Hub repo id (or local path) of the fine-tuned checkpoint.
# "your-username/..." is a template placeholder; set the MODEL_ID environment
# variable to point at the real checkpoint without editing this file.
MODEL_ID = os.environ.get("MODEL_ID", "your-username/distilbert-recommendation")

# Loaded once at import time and shared by every request. Both come from the
# same MODEL_ID so the tokenizer always matches the model it feeds.
model = DistilBertForSequenceClassification.from_pretrained(MODEL_ID)
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_ID)
@app.post("/predict/")
async def predict(text: str):
    """Classify ``text`` with the fine-tuned DistilBERT model.

    ``text`` is declared as a plain ``str`` parameter, so FastAPI binds it from
    the query string (``POST /predict/?text=...``), not from the request body.

    Returns:
        ``{"prediction": <int>}``: the index of the highest-scoring class.
        What each index means depends on the label mapping the checkpoint was
        trained with (presumably ``model.config.id2label``; confirm).
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    # Inference only: disable autograd tracking so no gradient graph or saved
    # activations are built for each request.
    with torch.inference_mode():
        outputs = model(**inputs)
    # A single input string is a batch of one, so argmax over the last axis
    # yields one element and .item() turns it into a plain Python int.
    prediction = torch.argmax(outputs.logits, dim=-1).item()
    # NOTE(review): tokenization and the forward pass are synchronous calls made
    # directly inside this ``async def``, so they run on the event-loop thread
    # and stall other requests while they execute. Before offloading them to a
    # thread pool, confirm the shared fast tokenizer is safe to call concurrently.
    return {"prediction": prediction}