Spaces:
Running
Running
File size: 2,174 Bytes
ad152ab bae6852 ad152ab bae6852 d39f3fd 0b8919f d39f3fd bae6852 d39f3fd bae6852 6a6db87 bae6852 6a6db87 bae6852 ad152ab bae6852 d39f3fd 6a6db87 bae6852 d39f3fd bae6852 d39f3fd bae6852 ad152ab bae6852 ad152ab bae6852 ad152ab 5c4a549 ad152ab bae6852 31984b2 ad152ab bae6852 5c4a549 bae6852 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
from fastapi import FastAPI, HTTPException
from typing import List
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
from fastapi.middleware.cors import CORSMiddleware
import torch
# ASGI application object; the route decorators in this module register on it.
app = FastAPI()

# Enable CORS for every origin, HTTP method and request header.
# NOTE(review): the FastAPI CORS documentation says wildcard ("*") values
# should not be combined with allow_credentials=True. Confirm whether
# credentialed cross-origin requests are actually needed; if not, drop
# allow_credentials, otherwise list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Use the GPU when torch reports one as available; otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the translation checkpoint once at import time and place it on DEVICE.
# trust_remote_code=True is passed so transformers may run the code that is
# distributed together with the checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "ai4bharat/indictrans2-indic-indic-1B",
    trust_remote_code=True,
).to(DEVICE)

# Tokenizer belonging to the same checkpoint as the model above.
tokenizer = AutoTokenizer.from_pretrained(
    "ai4bharat/indictrans2-indic-indic-1B",
    trust_remote_code=True,
)

# Pre-processing helper used by translate_text before tokenization.
ip = IndicProcessor(inference=True)
def translate_text(sentences: List[str], src_lang: str, target_lang: str):
    """Translate a batch of sentences from ``src_lang`` to ``target_lang``.

    The batch is pre-processed with the module-level ``IndicProcessor``,
    tokenized, run through the model with beam search, decoded, and finally
    post-processed for the target language.

    Args:
        sentences: Sentences written in the source language.
        src_lang: Source language tag forwarded to ``ip.preprocess_batch``
            (presumably a FLORES-style tag such as ``"hin_Deva"`` -- confirm
            against the toolkit documentation).
        target_lang: Target language tag, forwarded to both
            ``ip.preprocess_batch`` and ``ip.postprocess_batch``.

    Returns:
        A list of translated strings, one per input sentence. An empty input
        list yields an empty list without touching the tokenizer or model.

    Raises:
        Exception: Any failure during pre-processing, tokenization,
            generation, decoding or post-processing propagates to the caller.
            Errors are deliberately not converted into a return value, so a
            failure can never be mistaken for a translation.
    """
    # Nothing to translate: avoid handing an empty batch to the tokenizer.
    if not sentences:
        return []

    batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=target_lang)

    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)

    # Inference only: gradients are not needed.
    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode with the tokenizer switched to its target-side mode.
    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
        )

    # Counterpart of preprocess_batch: turns the raw decoded model output
    # into final text for the target language.
    return ip.postprocess_batch(decoded, lang=target_lang)
@app.get("/")
def read_root():
    """Return the fixed greeting payload served at the root path."""
    greeting = {"Hello": "World"}
    return greeting
# Request body accepted by the POST /translate/ endpoint.
class TranslateRequest(BaseModel):
    # Sentences to translate, all written in the source language.
    sentences: List[str]
    # Language tag of the input sentences.
    src_lang: str
    # Language tag the sentences should be translated into.
    target_lang: str
@app.post("/translate/")
def translate(request: TranslateRequest):
    """Translate the sentences in ``request`` and return them as a list.

    Args:
        request: Parsed request body holding the sentences and the source and
            target language tags.

    Returns:
        The list produced by ``translate_text``; an empty list when the
        request contains no sentences.

    Raises:
        HTTPException: With status 500 when translation fails, whether the
            failure surfaces as a raised exception or as an error string
            returned by ``translate_text``.
    """
    # An empty batch needs no model call.
    if not request.sentences:
        return []

    try:
        result = translate_text(request.sentences, request.src_lang, request.target_lang)
    except Exception as e:
        # Chain the cause so the original traceback stays visible in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # translate_text may signal failure by returning the error text instead of
    # a list; report that as a server error rather than a successful response.
    if isinstance(result, str):
        raise HTTPException(status_code=500, detail=result)

    return result
|