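from typing import List

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from IndicTransToolkit import IndicProcessor
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_NAME = "ai4bharat/indictrans2-en-indic-1B"
SRC_LANG = "eng_Latn"  # this checkpoint translates from English into Indic languages
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

app = FastAPI()

# Accept requests from any origin (convenient during development; restrict in production).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the tokenizer, model and IndicTrans pre/post-processor once, at startup.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, trust_remote_code=True).to(DEVICE)
ip = IndicProcessor(inference=True)


def translate_text(sentences: List[str], target_lang: str) -> List[str]:
    """Translate English sentences into `target_lang` (an IndicTrans2 code such as "hin_Deva").

    Exceptions propagate to the caller instead of being returned as a string,
    so the endpoint can turn them into a proper HTTP error.
    """
    # Normalize the input and tag it with the source/target language codes.
    batch = ip.preprocess_batch(sentences, src_lang=SRC_LANG, tgt_lang=target_lang)

    # Tokenize, padding to the longest sentence in the batch.
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)

    # Beam-search decoding without gradient tracking.
    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode the generated ids with the target-side vocabulary.
    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

    # Postprocess the raw output (restores replaced entities, applies
    # language-specific cleanup); this step was missing before.
    return ip.postprocess_batch(decoded, lang=target_lang)


@app.get("/")
def read_root():
    return {"Hello": "World"}


class TranslateRequest(BaseModel):
    sentences: List[str]
    target_lang: str


# A plain `def` endpoint runs in FastAPI's threadpool, so the blocking
# model.generate() call does not stall the event loop.
@app.post("/translate/", response_model=List[str])
def translate(request: TranslateRequest) -> List[str]:
    try:
        return translate_text(request.sentences, request.target_lang)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e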
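A client calls `/translate/` with a JSON body that matches `TranslateRequest` (`sentences` plus `target_lang`) and gets back a JSON list of strings. The sketch below uses only the standard library. It is illustrative, not part of the service. The base URL `http://127.0.0.1:8000` (uvicorn's default host and port), the target code `hin_Deva`, and the helper names `build_translate_request` / `translate_remote` are assumptions.

```python
import json
import urllib.request
from typing import List

# Assumed address: uvicorn's default host and port. Adjust for your deployment.
BASE_URL = "http://127.0.0.1:8000"


def build_translate_request(sentences: List[str], target_lang: str,
                            base_url: str = BASE_URL) -> urllib.request.Request:
    """Hypothetical helper: build a POST whose JSON body matches TranslateRequest."""
    body = json.dumps({"sentences": sentences, "target_lang": target_lang}).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/translate/",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def translate_remote(sentences: List[str], target_lang: str) -> List[str]:
    """Hypothetical helper: send the request to a running server and return the translations.

    A 500 response from the endpoint surfaces here as urllib.error.HTTPError.
    """
    req = build_translate_request(sentences, target_lang)
    with urllib.request.urlopen(req, timeout=60) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

With the server running (for example `uvicorn main:app`, assuming the module above is saved as `main.py`), `translate_remote(["Hello, how are you?"], "hin_Deva")` should return the list produced by the endpoint. Building the request separately from sending it keeps the payload logic testable without a live server.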