File size: 4,425 Bytes
4244176
 
 
 
 
 
 
 
 
d466149
4244176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import functools
import json
from typing import Union, Dict, Any

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, MarianMTModel, MarianTokenizer

# λͺ¨λΈ λ‘œλ“œ
M2M100_MODEL_NAME = "facebook/m2m100_418M"
OPUS_MT_MODEL_NAME = "Helsinki-NLP/opus-mt-tc-big-en-ko"

# M2M100 (λ‹€κ΅­μ–΄ λ²ˆμ—­)
m2m100_tokenizer = M2M100Tokenizer.from_pretrained(M2M100_MODEL_NAME)
m2m100_model = M2M100ForConditionalGeneration.from_pretrained(M2M100_MODEL_NAME)

# Helsinki-NLP Opus-MT (μ˜μ–΄ ↔ ν•œκ΅­μ–΄ μ „μš©)
opus_tokenizer = MarianTokenizer.from_pretrained(OPUS_MT_MODEL_NAME)
opus_model = MarianMTModel.from_pretrained(OPUS_MT_MODEL_NAME)

# CPUμ—μ„œ μ‹€ν–‰
device = torch.device("cpu")
m2m100_model.to(device)
opus_model.to(device)

# FastAPI μ•±
app = FastAPI()

# μš”μ²­ 데이터 λͺ¨λΈ
class TranslationRequest(BaseModel):
    model: str  # μ‚¬μš©ν•  λͺ¨λΈ ("m2m100" λ˜λŠ” "opus-mt")
    from_lang: str  # μž…λ ₯ μ–Έμ–΄ (예: "ko", "en", "fr")
    to: str  # 좜λ ₯ μ–Έμ–΄ (예: "ko", "fr")
    data: Dict[str, Any]  # λ²ˆμ—­ν•  JSON 객체

# M2M100 translation (multilingual).
def translate_m2m100(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate ``text`` from ``src_lang`` to ``tgt_lang`` with M2M100.

    Empty or whitespace-only input is returned unchanged, without running
    the model.

    Args:
        text: Text to translate.
        src_lang: M2M100 language code of ``text`` (e.g. "ko", "en").
        tgt_lang: M2M100 language code to translate into.

    Returns:
        The decoded translation (the first sequence of the generated batch).
    """
    if text.strip() == "":
        # Nothing to translate; echo the input back as-is.
        return text

    tokenizer = m2m100_tokenizer
    # Select the source language on the tokenizer. Note that this mutates the
    # shared, module-level tokenizer object.
    tokenizer.src_lang = src_lang
    inputs = tokenizer(text, return_tensors="pt").to(device)
    # Force the first generated token to be the target-language token.
    target_token_id = tokenizer.get_lang_id(tgt_lang)
    outputs = m2m100_model.generate(**inputs, forced_bos_token_id=target_token_id)
    translations = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return translations[0]

# Helsinki-NLP Opus-MT λ²ˆμ—­ ν•¨μˆ˜ (μ˜μ–΄ ↔ ν•œκ΅­μ–΄ μ „μš©)
def translate_opus_mt(text: str, src_lang: str, tgt_lang: str) -> str:
    if not text.strip():
        return text  # 빈 λ¬Έμžμ—΄μ΄λ©΄ κ·ΈλŒ€λ‘œ λ°˜ν™˜

    if src_lang == "en" and tgt_lang == "ko":
        model_name = "Helsinki-NLP/opus-mt-en-ko"
    elif src_lang == "ko" and tgt_lang == "en":
        model_name = "Helsinki-NLP/opus-mt-ko-en"
    else:
        raise HTTPException(status_code=400, detail="Opus-MTλŠ” μ˜μ–΄ ↔ ν•œκ΅­μ–΄λ§Œ μ§€μ›ν•©λ‹ˆλ‹€.")

    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name).to(device)

    encoded_text = tokenizer(text, return_tensors="pt", padding=True).to(device)
    generated_tokens = model.generate(**encoded_text)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

# μž¬κ·€μ μœΌλ‘œ JSON λ²ˆμ—­ ν•¨μˆ˜
def recursive_translate(json_obj: Union[Dict[str, Any], str], src_lang: str, tgt_lang: str, model_type: str):
    if isinstance(json_obj, str):  # 단일 λ¬Έμžμ—΄μ΄λ©΄ λ²ˆμ—­
        if model_type == "m2m100":
            return translate_m2m100(json_obj, src_lang, tgt_lang)
        elif model_type == "opus-mt":
            return translate_opus_mt(json_obj, src_lang, tgt_lang)
    elif isinstance(json_obj, dict):  # λ”•μ…”λ„ˆλ¦¬λ©΄ μž¬κ·€μ μœΌλ‘œ λ²ˆμ—­
        return {key: recursive_translate(value, src_lang, tgt_lang, model_type) for key, value in json_obj.items()}
    else:
        return json_obj  # 숫자, 리슀트 등은 λ²ˆμ—­ν•˜μ§€ μ•Šκ³  κ·ΈλŒ€λ‘œ λ°˜ν™˜

@app.post("/translate")
async def translate_json(request: TranslationRequest):
    """Translate the string values of ``request.data`` and return the result.

    The model name and language pair are validated first. Anything outside
    the supported set is rejected with HTTP 400 before any translation runs.
    """
    # NOTE(review): translation runs synchronously inside this coroutine, so
    # the event loop is blocked for the whole request. Moving it to a worker
    # thread would also require guarding the shared m2m100 tokenizer, whose
    # `src_lang` attribute translate_m2m100 mutates on every call.
    model_type = request.model
    src_lang, tgt_lang = request.from_lang, request.to

    # Language codes accepted for M2M100. The model itself covers many more;
    # this endpoint restricts requests to this set.
    m2m100_langs = ["ko", "en", "fr", "es", "ja", "zh", "de", "it"]

    if model_type == "opus-mt":
        if src_lang not in ["en", "ko"] or tgt_lang not in ["en", "ko"]:
            raise HTTPException(status_code=400, detail="Opus-MT λͺ¨λΈμ€ μ˜μ–΄ ↔ ν•œκ΅­μ–΄ λ²ˆμ—­λ§Œ μ§€μ›ν•©λ‹ˆλ‹€.")
    elif model_type == "m2m100":
        pair_supported = src_lang in m2m100_langs and tgt_lang in m2m100_langs
        if not pair_supported:
            raise HTTPException(status_code=400, detail=f"μ§€μ›λ˜μ§€ μ•ŠλŠ” μ–Έμ–΄ μ½”λ“œ: {src_lang} β†’ {tgt_lang}")
    else:
        raise HTTPException(status_code=400, detail="μ§€μ›λ˜μ§€ μ•ŠλŠ” λͺ¨λΈ 선택")

    # Walk the JSON payload and translate every string value.
    return recursive_translate(request.data, src_lang, tgt_lang, model_type)