File size: 2,174 Bytes
ad152ab
bae6852
ad152ab
bae6852
 
d39f3fd
0b8919f
d39f3fd
bae6852
d39f3fd
 
 
 
 
 
 
 
 
bae6852
6a6db87
bae6852
 
6a6db87
bae6852
ad152ab
bae6852
 
 
 
d39f3fd
6a6db87
bae6852
 
 
 
 
 
 
 
 
d39f3fd
bae6852
 
 
 
 
 
 
 
 
d39f3fd
bae6852
 
 
 
 
 
ad152ab
bae6852
ad152ab
 
 
 
 
 
bae6852
 
ad152ab
 
5c4a549
ad152ab
bae6852
 
31984b2
ad152ab
bae6852
5c4a549
bae6852
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import logging
from typing import List

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from IndicTransToolkit import IndicProcessor
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

app = FastAPI()

# Permissive CORS: any origin, method and header may call this API.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# website issue credentialed cross-origin requests. Confirm credentials are
# actually needed, or restrict allow_origins to known front-end hosts.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Hugging Face checkpoint shared by the model weights and its tokenizer.
MODEL_CHECKPOINT = "ai4bharat/indictrans2-indic-indic-1B"

# Prefer the CUDA device when one is available, otherwise run on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: trust_remote_code=True executes Python code shipped with the
# checkpoint repository; pin a revision if the hub repo is not fully trusted.
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_CHECKPOINT, trust_remote_code=True
).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT, trust_remote_code=True)

# IndicTrans2 text pre/post-processor, configured for inference.
ip = IndicProcessor(inference=True)


def translate_text(sentences: List[str], src_lang: str, target_lang: str) -> List[str]:
    """Translate a batch of sentences from ``src_lang`` to ``target_lang``.

    Runs the full IndicTrans2 inference pipeline:

    1. ``IndicProcessor`` preprocessing.
    2. Tokenization.
    3. Beam-search generation with the module-level ``model``.
    4. Decoding.
    5. ``IndicProcessor`` postprocessing back into target-language text.

    Args:
        sentences: Sentences to translate, all written in ``src_lang``.
        src_lang: Source language code understood by ``IndicProcessor``
            (presumably FLORES-style tags such as ``"hin_Deva"`` — verify
            against the toolkit).
        target_lang: Target language code, same format as ``src_lang``.

    Returns:
        One translated string per input sentence. An empty input returns an
        empty list without touching the model.

    Raises:
        Exception: Any failure in preprocessing, tokenization, generation or
            postprocessing (e.g. an unsupported language code) propagates to
            the caller instead of being returned as a string, so callers can
            distinguish errors from translations.
    """
    if not sentences:
        return []

    batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=target_lang)

    # Pad to the longest sentence in the batch. Over-long inputs are truncated
    # to the tokenizer's maximum length rather than rejected.
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)

    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,  # caps the length of each generated translation
            num_beams=5,
            num_return_sequences=1,  # exactly one output per input sentence
        )

    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

    # Undo the preprocessing step (including entity/placeholder replacement)
    # so callers receive final target-language text, not raw model output.
    return ip.postprocess_batch(decoded, lang=target_lang)


@app.get("/")
def read_root():
    """Return a fixed ``{"Hello": "World"}`` payload for the root path."""
    payload = dict(Hello="World")
    return payload


class TranslateRequest(BaseModel):
    """Request body for the ``/translate/`` endpoint."""

    # Sentences to translate; all share the single source language below.
    sentences: List[str]
    # Source language code, passed through to the IndicProcessor unvalidated.
    src_lang: str
    # Target language code, same format as src_lang.
    target_lang: str


@app.post("/translate/")
def translate(request: TranslateRequest):
    """Translate ``request.sentences`` from ``src_lang`` into ``target_lang``.

    Declared as a plain ``def`` so FastAPI runs it in its worker threadpool
    rather than blocking the event loop during model inference.

    Returns:
        The translated sentences produced by ``translate_text``.

    Raises:
        HTTPException: 500 when translation fails; ``detail`` carries the
            underlying error message. The full traceback is logged server-side.
    """
    try:
        result = translate_text(request.sentences, request.src_lang, request.target_lang)
    except Exception as e:
        # HTTPException responses are not logged with a traceback, so record
        # the real failure here before converting it into a 500.
        logging.getLogger(__name__).exception("Translation request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Guard against a helper that reports failures by returning the error
    # message as a plain string instead of raising: never send that back to
    # the client as a successful (200) translation result.
    if isinstance(result, str):
        raise HTTPException(status_code=500, detail=result)
    return result