Spaces:
Sleeping
Sleeping
File size: 4,425 Bytes
4244176 d466149 4244176 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import json
from functools import lru_cache
from typing import Union, Dict, Any

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, MarianMTModel, MarianTokenizer
# --- Runtime configuration ---------------------------------------------------
# Every model and every encoded input is placed on the CPU.
device = torch.device("cpu")

# Hugging Face Hub checkpoint identifiers.
M2M100_MODEL_NAME = "facebook/m2m100_418M"
OPUS_MT_MODEL_NAME = "Helsinki-NLP/opus-mt-tc-big-en-ko"

# M2M100: many-to-many multilingual translation (any supported pair).
m2m100_tokenizer = M2M100Tokenizer.from_pretrained(M2M100_MODEL_NAME)
m2m100_model = M2M100ForConditionalGeneration.from_pretrained(M2M100_MODEL_NAME).to(device)

# Helsinki-NLP Opus-MT English -> Korean checkpoint, loaded eagerly at import.
# NOTE(review): nothing in this file reads opus_tokenizer / opus_model
# (translate_opus_mt loads checkpoints by its own names); confirm whether this
# preload is still needed, since it costs startup time and memory.
opus_tokenizer = MarianTokenizer.from_pretrained(OPUS_MT_MODEL_NAME)
opus_model = MarianMTModel.from_pretrained(OPUS_MT_MODEL_NAME).to(device)

# FastAPI application instance that the route decorators register on.
app = FastAPI()
# Request body model for the /translate endpoint.
class TranslationRequest(BaseModel):
    model: str  # which model to use: "m2m100" or "opus-mt"
    from_lang: str  # source language code (e.g. "ko", "en", "fr")
    to: str  # target language code (e.g. "ko", "fr")
    data: Dict[str, Any]  # JSON object whose string values will be translated
def translate_m2m100(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate ``text`` from ``src_lang`` to ``tgt_lang`` using M2M100.

    Empty or whitespace-only input is returned unchanged without invoking
    the model.

    Note: this assigns ``src_lang`` onto the shared module-level
    ``m2m100_tokenizer`` before encoding, so the tokenizer's source language
    persists after the call.
    """
    if text.strip() == "":
        # Nothing worth translating; hand the input straight back.
        return text

    m2m100_tokenizer.src_lang = src_lang
    inputs = m2m100_tokenizer(text, return_tensors="pt").to(device)

    # M2M100 picks the output language from the forced first decoder token.
    target_token_id = m2m100_tokenizer.get_lang_id(tgt_lang)
    output_ids = m2m100_model.generate(**inputs, forced_bos_token_id=target_token_id)

    decoded = m2m100_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]
# Helsinki-NLP Opus-MT λ²μ ν¨μ (μμ΄ β νκ΅μ΄ μ μ©)
def translate_opus_mt(text: str, src_lang: str, tgt_lang: str) -> str:
if not text.strip():
return text # λΉ λ¬Έμμ΄μ΄λ©΄ κ·Έλλ‘ λ°ν
if src_lang == "en" and tgt_lang == "ko":
model_name = "Helsinki-NLP/opus-mt-en-ko"
elif src_lang == "ko" and tgt_lang == "en":
model_name = "Helsinki-NLP/opus-mt-ko-en"
else:
raise HTTPException(status_code=400, detail="Opus-MTλ μμ΄ β νκ΅μ΄λ§ μ§μν©λλ€.")
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name).to(device)
encoded_text = tokenizer(text, return_tensors="pt", padding=True).to(device)
generated_tokens = model.generate(**encoded_text)
return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
def recursive_translate(json_obj: Any, src_lang: str, tgt_lang: str, model_type: str) -> Any:
    """Recursively translate every string value inside ``json_obj``.

    - ``str``: translated with the model selected by ``model_type``
      ("m2m100" or "opus-mt").
    - ``dict``: rebuilt with the same keys (keys are not translated) and
      recursively translated values.
    - anything else (numbers, booleans, None, lists, ...): returned unchanged.

    Raises:
        ValueError: if a string value is reached and ``model_type`` is not a
            known model, rather than silently replacing the string with None.
    """
    if isinstance(json_obj, str):
        if model_type == "m2m100":
            return translate_m2m100(json_obj, src_lang, tgt_lang)
        if model_type == "opus-mt":
            return translate_opus_mt(json_obj, src_lang, tgt_lang)
        raise ValueError(f"Unsupported model_type: {model_type!r}")
    if isinstance(json_obj, dict):
        return {
            key: recursive_translate(value, src_lang, tgt_lang, model_type)
            for key, value in json_obj.items()
        }
    # Lists and scalars are deliberately passed through untranslated.
    return json_obj
@app.post("/translate")
async def translate_json(request: TranslationRequest):
    """Translate every string value in the request's JSON ``data`` object.

    The model/language combination is validated first; unsupported choices
    are rejected with HTTP 400. On success the translated object (same keys,
    same structure) is returned.
    """
    # NOTE(review): this handler is ``async`` but performs blocking model
    # inference on the event loop. Declaring it with plain ``def`` would let
    # FastAPI run it in a threadpool, but translate_m2m100 mutates the shared
    # tokenizer's ``src_lang``, so concurrent requests would then need a lock.
    model_type = request.model  # "m2m100" or "opus-mt"
    src_lang = request.from_lang
    tgt_lang = request.to

    # Languages accepted for M2M100 (the model itself covers many more).
    m2m100_langs = {"ko", "en", "fr", "es", "ja", "zh", "de", "it"}

    if model_type == "m2m100":
        if src_lang not in m2m100_langs or tgt_lang not in m2m100_langs:
            raise HTTPException(status_code=400, detail=f"지원되지 않는 언어 코드: {src_lang} → {tgt_lang}")
    elif model_type == "opus-mt":
        # Only the directed pairs en->ko and ko->en exist, so identical
        # source/target codes (en->en, ko->ko) are rejected here as well.
        if {src_lang, tgt_lang} != {"en", "ko"}:
            raise HTTPException(status_code=400, detail="Opus-MT 모델은 영어 ↔ 한국어 번역만 지원합니다.")
    else:
        raise HTTPException(status_code=400, detail="지원되지 않는 모델 선택")

    return recursive_translate(request.data, src_lang, tgt_lang, model_type)
|