jellychoco
solve version conflict
d466149
import json
from functools import lru_cache
from typing import Union, Dict, Any

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, MarianMTModel, MarianTokenizer
# λͺ¨λΈ λ‘œλ“œ
M2M100_MODEL_NAME = "facebook/m2m100_418M"
OPUS_MT_MODEL_NAME = "Helsinki-NLP/opus-mt-tc-big-en-ko"
# M2M100 (λ‹€κ΅­μ–΄ λ²ˆμ—­)
m2m100_tokenizer = M2M100Tokenizer.from_pretrained(M2M100_MODEL_NAME)
m2m100_model = M2M100ForConditionalGeneration.from_pretrained(M2M100_MODEL_NAME)
# Helsinki-NLP Opus-MT (μ˜μ–΄ ↔ ν•œκ΅­μ–΄ μ „μš©)
opus_tokenizer = MarianTokenizer.from_pretrained(OPUS_MT_MODEL_NAME)
opus_model = MarianMTModel.from_pretrained(OPUS_MT_MODEL_NAME)
# CPUμ—μ„œ μ‹€ν–‰
device = torch.device("cpu")
m2m100_model.to(device)
opus_model.to(device)
# FastAPI μ•±
app = FastAPI()
# Request body schema for the translation endpoint.
class TranslationRequest(BaseModel):
    model: str  # Model to use: "m2m100" or "opus-mt"
    from_lang: str  # Source language code (e.g. "ko", "en", "fr"); JSON key is "from_lang" (no alias configured)
    to: str  # Target language code (e.g. "ko", "fr")
    data: Dict[str, Any]  # JSON object whose string values are to be translated
def translate_m2m100(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate one string with the shared M2M100 model.

    Args:
        text: The string to translate.
        src_lang: M2M100 language code of *text*.
        tgt_lang: M2M100 language code to translate into.

    Returns:
        The first decoded translation. Empty or whitespace-only input is
        returned unchanged without touching the model.
    """
    if text.strip() == "":
        return text
    tokenizer = m2m100_tokenizer
    # Mutates the module-level tokenizer: the source language is shared state.
    tokenizer.src_lang = src_lang
    inputs = tokenizer(text, return_tensors="pt").to(device)
    # Force the decoder to start with the target-language token.
    target_bos = tokenizer.get_lang_id(tgt_lang)
    outputs = m2m100_model.generate(**inputs, forced_bos_token_id=target_bos)
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return decoded[0]
# Opus-MT checkpoint for each supported (source, target) direction.
_OPUS_MT_MODELS = {
    ("en", "ko"): "Helsinki-NLP/opus-mt-en-ko",
    ("ko", "en"): "Helsinki-NLP/opus-mt-ko-en",
}


@lru_cache(maxsize=None)  # key space is bounded by _OPUS_MT_MODELS
def _load_opus_mt(model_name: str):
    """Load the Marian tokenizer/model for *model_name* once and reuse it."""
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name).to(device)
    return tokenizer, model


def translate_opus_mt(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate one string with Helsinki-NLP Opus-MT (English <-> Korean only).

    The checkpoint for each direction is loaded on first use and cached, so
    translating many strings no longer reloads the model every call.

    Args:
        text: The string to translate.
        src_lang: Source language code, "en" or "ko".
        tgt_lang: Target language code, "ko" or "en".

    Returns:
        The first decoded translation. Empty or whitespace-only input is
        returned unchanged, even for an unsupported pair.

    Raises:
        HTTPException: 400 if (src_lang, tgt_lang) is not en->ko or ko->en.
    """
    if not text.strip():
        return text  # Nothing to translate.
    model_name = _OPUS_MT_MODELS.get((src_lang, tgt_lang))
    if model_name is None:
        raise HTTPException(status_code=400, detail="Opus-MTλŠ” μ˜μ–΄ ↔ ν•œκ΅­μ–΄λ§Œ μ§€μ›ν•©λ‹ˆλ‹€.")
    # NOTE(review): OPUS_MT_MODEL_NAME (opus-mt-tc-big-en-ko) is preloaded at
    # module level but unused here; confirm which en->ko checkpoint is intended.
    tokenizer, model = _load_opus_mt(model_name)
    encoded_text = tokenizer(text, return_tensors="pt", padding=True).to(device)
    generated_tokens = model.generate(**encoded_text)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
# μž¬κ·€μ μœΌλ‘œ JSON λ²ˆμ—­ ν•¨μˆ˜
def recursive_translate(json_obj: Union[Dict[str, Any], str], src_lang: str, tgt_lang: str, model_type: str):
if isinstance(json_obj, str): # 단일 λ¬Έμžμ—΄μ΄λ©΄ λ²ˆμ—­
if model_type == "m2m100":
return translate_m2m100(json_obj, src_lang, tgt_lang)
elif model_type == "opus-mt":
return translate_opus_mt(json_obj, src_lang, tgt_lang)
elif isinstance(json_obj, dict): # λ”•μ…”λ„ˆλ¦¬λ©΄ μž¬κ·€μ μœΌλ‘œ λ²ˆμ—­
return {key: recursive_translate(value, src_lang, tgt_lang, model_type) for key, value in json_obj.items()}
else:
return json_obj # 숫자, 리슀트 등은 λ²ˆμ—­ν•˜μ§€ μ•Šκ³  κ·ΈλŒ€λ‘œ λ°˜ν™˜
@app.post("/translate")
async def translate_json(request: TranslationRequest):
    """Translate the string values of a JSON object.

    The request selects the model ("m2m100" or "opus-mt"), the source
    language (``from_lang``), the target language (``to``) and the JSON
    object (``data``). String values are translated recursively through
    nested dicts; other values are returned unchanged.

    Raises:
        HTTPException: 400 if the model or the language pair is unsupported.
    """
    model_type = request.model  # "m2m100" or "opus-mt"
    src_lang = request.from_lang
    tgt_lang = request.to
    input_data = request.data

    # Languages accepted for M2M100 (the model itself covers many more).
    supported_langs = {"ko", "en", "fr", "es", "ja", "zh", "de", "it"}
    # Directions the Opus-MT path can actually translate; en->en / ko->ko are
    # rejected here instead of failing later (or not at all when data has no
    # strings) inside translate_opus_mt.
    opus_mt_pairs = {("en", "ko"), ("ko", "en")}

    # Validate the model selection and language pair.
    if model_type == "m2m100":
        if src_lang not in supported_langs or tgt_lang not in supported_langs:
            raise HTTPException(status_code=400, detail=f"μ§€μ›λ˜μ§€ μ•ŠλŠ” μ–Έμ–΄ μ½”λ“œ: {src_lang} β†’ {tgt_lang}")
    elif model_type == "opus-mt":
        if (src_lang, tgt_lang) not in opus_mt_pairs:
            raise HTTPException(status_code=400, detail="Opus-MT λͺ¨λΈμ€ μ˜μ–΄ ↔ ν•œκ΅­μ–΄ λ²ˆμ—­λ§Œ μ§€μ›ν•©λ‹ˆλ‹€.")
    else:
        raise HTTPException(status_code=400, detail="μ§€μ›λ˜μ§€ μ•ŠλŠ” λͺ¨λΈ 선택")

    # NOTE(review): inference below is synchronous CPU work inside an async
    # endpoint, so it blocks the event loop for the duration of the request.
    # Offloading it to a thread would also require guarding the shared
    # m2m100_tokenizer.src_lang mutation in translate_m2m100.
    return recursive_translate(input_data, src_lang, tgt_lang, model_type)