# indic-trans-api / app.py
# darshankr's picture
# Update app.py
# 5c4a549 verified
from fastapi import FastAPI, HTTPException
from typing import List
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
from fastapi.middleware.cors import CORSMiddleware
import torch
# Application instance; the route decorators further down register against it.
app = FastAPI()

# Cross-origin policy: every origin, HTTP method and request header is accepted,
# and credentialed requests are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the API publicly.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Hugging Face Hub id shared by the model weights and the tokenizer.
_MODEL_ID = "ai4bharat/indictrans2-indic-indic-1B"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the seq2seq model once at import time and place it on the chosen device.
# trust_remote_code=True is passed for both the model and the tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_ID, trust_remote_code=True).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, trust_remote_code=True)

# Pre-processing helper from IndicTransToolkit, created in inference mode.
ip = IndicProcessor(inference=True)
def translate_text(sentences: List[str], src_lang: str, target_lang: str) -> List[str]:
    """Translate a batch of sentences from ``src_lang`` into ``target_lang``.

    The batch is pre-processed with the module-level ``IndicProcessor``,
    tokenized, run through ``model.generate`` with beam search, decoded, and
    finally post-processed with the same processor.

    Args:
        sentences: Sentences to translate, all in ``src_lang``.
        src_lang: Source language code, forwarded to ``ip.preprocess_batch``.
        target_lang: Target language code, forwarded to ``ip.preprocess_batch``
            and ``ip.postprocess_batch``.

    Returns:
        One translated string per input sentence; an empty list when
        ``sentences`` is empty.

    Raises:
        Exception: Any error raised by the processor, tokenizer or model is
            propagated so the caller can turn it into a proper error
            response instead of mistaking an error message for a result.
    """
    # Nothing to translate: skip tokenization and generation entirely.
    if not sentences:
        return []

    batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=target_lang)

    # Pad to the longest sentence in the batch and move tensors to the model's device.
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode with the tokenizer switched to its target-side mode.
    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
        )

    # The toolkit's documented flow pairs every preprocess_batch call with a
    # postprocess_batch call on the decoded text for the target language.
    return ip.postprocess_batch(decoded, lang=target_lang)
@app.get("/")
def read_root():
    # Root route: responds with a fixed greeting payload.
    greeting = {"Hello": "World"}
    return greeting
class TranslateRequest(BaseModel):
    # Request body accepted by the POST /translate/ endpoint.

    # Sentences to translate; passed as one batch to translate_text.
    sentences: List[str]
    # Source language code, forwarded to the processor's preprocess_batch.
    # NOTE(review): presumably an IndicTrans2-style code -- confirm the accepted values.
    src_lang: str
    # Target language code, forwarded to the processor as tgt_lang.
    target_lang: str
@app.post("/translate/")
def translate(request: TranslateRequest):
    """Translate the sentences in the request body and return the translations."""
    # Keep the try body to the single call that can fail, so the HTTPException
    # raised below for a string result is not caught and re-wrapped.
    try:
        result = translate_text(request.sentences, request.src_lang, request.target_lang)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # A successful translation is a list of strings. A bare string result is an
    # error message, so report it as a server error rather than a 200 response.
    if isinstance(result, str):
        raise HTTPException(status_code=500, detail=result)

    return result