Spaces:
Sleeping
Sleeping
import json | |
import torch | |
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from typing import Union, Dict, Any | |
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, MarianMTModel, MarianTokenizer | |
# Checkpoint names of the pretrained translation models.
M2M100_MODEL_NAME = "facebook/m2m100_418M"
OPUS_MT_MODEL_NAME = "Helsinki-NLP/opus-mt-tc-big-en-ko"

# Everything runs on the CPU.
device = torch.device("cpu")

# M2M100 (multilingual translation).
m2m100_tokenizer = M2M100Tokenizer.from_pretrained(M2M100_MODEL_NAME)
m2m100_model = M2M100ForConditionalGeneration.from_pretrained(M2M100_MODEL_NAME).to(device)

# Helsinki-NLP Opus-MT (English / Korean only).
# NOTE(review): none of the functions visible in this part of the file use
# `opus_tokenizer` / `opus_model`; `translate_opus_mt` loads its own
# checkpoints by name. Confirm whether this preload is still needed.
opus_tokenizer = MarianTokenizer.from_pretrained(OPUS_MT_MODEL_NAME)
opus_model = MarianMTModel.from_pretrained(OPUS_MT_MODEL_NAME).to(device)

# FastAPI application object.
app = FastAPI()
# Request body model for a translation call.
class TranslationRequest(BaseModel):
    # Which model to use ("m2m100" or "opus-mt").
    model: str
    # Source language code (e.g. "ko", "en", "fr").
    from_lang: str
    # Target language code (e.g. "ko", "fr").
    to: str
    # JSON object whose string values are to be translated.
    data: Dict[str, Any]
# M2M100 translation helper (multilingual).
def translate_m2m100(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate ``text`` from ``src_lang`` to ``tgt_lang`` with the M2M100 model.

    Args:
        text: The string to translate.
        src_lang: Source language code, assigned to the tokenizer's ``src_lang``.
        tgt_lang: Target language code, resolved with ``get_lang_id`` and
            passed to ``generate`` as ``forced_bos_token_id``.

    Returns:
        The first decoded sequence produced by the model, or ``text``
        unchanged when it is empty or whitespace-only.
    """
    # Empty / whitespace-only input is handed back as-is without touching the model.
    if not text.strip():
        return text

    tok = m2m100_tokenizer
    # NOTE: this mutates the module-level tokenizer shared by every call.
    tok.src_lang = src_lang

    model_inputs = tok(text, return_tensors="pt").to(device)
    target_lang_id = tok.get_lang_id(tgt_lang)
    output_ids = m2m100_model.generate(**model_inputs, forced_bos_token_id=target_lang_id)

    decoded = tok.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]
# Helsinki-NLP Opus-MT translation helper (English / Korean only).

# Loaded (tokenizer, model) pairs keyed by checkpoint name, so that each
# checkpoint is loaded at most once per process rather than once per string.
_OPUS_MT_CACHE = {}


def _load_opus_mt(model_name):
    """Return the (tokenizer, model) pair for ``model_name``, loading it on first use."""
    pair = _OPUS_MT_CACHE.get(model_name)
    if pair is None:
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name).to(device)
        pair = (tokenizer, model)
        _OPUS_MT_CACHE[model_name] = pair
    return pair


def translate_opus_mt(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate ``text`` between English and Korean with an Opus-MT checkpoint.

    Args:
        text: The string to translate.
        src_lang: Source language code; only "en" and "ko" are accepted.
        tgt_lang: Target language code; must be the other one of "en" / "ko".

    Returns:
        The first decoded sequence produced by the model, or ``text``
        unchanged when it is empty or whitespace-only (this check happens
        before the language pair is validated).

    Raises:
        HTTPException: status 400 when the pair is neither en->ko nor ko->en.
    """
    if not text.strip():
        return text  # nothing to translate

    if src_lang == "en" and tgt_lang == "ko":
        model_name = "Helsinki-NLP/opus-mt-en-ko"
    elif src_lang == "ko" and tgt_lang == "en":
        model_name = "Helsinki-NLP/opus-mt-ko-en"
    else:
        # NOTE(review): the message text looks mis-encoded in this view; it is
        # runtime text, so it is left exactly as found.
        raise HTTPException(status_code=400, detail="Opus-MTλ μμ΄ β νκ΅μ΄λ§ μ§μν©λλ€.")

    # NOTE(review): these checkpoint names differ from the module-level
    # OPUS_MT_MODEL_NAME preload -- confirm which checkpoint is intended.
    tokenizer, model = _load_opus_mt(model_name)

    encoded_text = tokenizer(text, return_tensors="pt", padding=True).to(device)
    generated_tokens = model.generate(**encoded_text)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
# Recursive JSON translation helper.
def recursive_translate(json_obj: Union[Dict[str, Any], str], src_lang: str, tgt_lang: str, model_type: str):
    """Translate every string value reachable through nested dicts.

    Args:
        json_obj: A string, a dict (walked recursively), or any other value.
        src_lang: Source language code passed to the translation function.
        tgt_lang: Target language code passed to the translation function.
        model_type: "m2m100" or "opus-mt"; selects the translation function.

    Returns:
        A translated string, a new dict with the same keys and translated
        values, or ``json_obj`` itself for any other type. Numbers, lists and
        the like are returned untouched, so strings inside lists are NOT
        translated.

    Raises:
        ValueError: when a string has to be translated but ``model_type`` is
            not a known model. Previously this case fell through and returned
            ``None``, silently replacing the text.
    """
    if isinstance(json_obj, str):
        if model_type == "m2m100":
            return translate_m2m100(json_obj, src_lang, tgt_lang)
        if model_type == "opus-mt":
            return translate_opus_mt(json_obj, src_lang, tgt_lang)
        raise ValueError(f"Unsupported model type: {model_type!r}")

    if isinstance(json_obj, dict):
        # Keys are kept as-is; only values are translated.
        return {
            key: recursive_translate(value, src_lang, tgt_lang, model_type)
            for key, value in json_obj.items()
        }

    return json_obj
# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
# handler in this view -- confirm it is registered with ``app`` somewhere.
async def translate_json(request: TranslationRequest):
    """Validate a translation request and translate the strings in its JSON payload.

    Args:
        request: Carries the model name (``model``), the source and target
            language codes (``from_lang`` / ``to``) and the payload (``data``).

    Returns:
        The result of ``recursive_translate`` applied to ``request.data``.

    Raises:
        HTTPException: status 400 when the model name is unknown, when an
            M2M100 language code is not in the accepted list, or when the
            Opus-MT pair is not exactly en->ko or ko->en.
    """
    model_type = request.model  # "m2m100" or "opus-mt"
    src_lang = request.from_lang
    tgt_lang = request.to

    # Language codes this endpoint accepts for the M2M100 model.
    supported_langs = ("ko", "en", "fr", "es", "ja", "zh", "de", "it")

    # NOTE(review): the message texts below look mis-encoded in this view; they
    # are runtime text, so they are left exactly as found.
    if model_type == "m2m100":
        if src_lang not in supported_langs or tgt_lang not in supported_langs:
            raise HTTPException(status_code=400, detail=f"μ§μλμ§ μλ μΈμ΄ μ½λ: {src_lang} β {tgt_lang}")
    elif model_type == "opus-mt":
        # Reject same-language pairs (en->en, ko->ko) here as well:
        # translate_opus_mt only handles en->ko and ko->en, so letting them
        # through would only fail later, on the first non-empty string.
        if (src_lang, tgt_lang) not in (("en", "ko"), ("ko", "en")):
            raise HTTPException(status_code=400, detail="Opus-MT λͺ¨λΈμ μμ΄ β νκ΅μ΄ λ²μλ§ μ§μν©λλ€.")
    else:
        raise HTTPException(status_code=400, detail="μ§μλμ§ μλ λͺ¨λΈ μ ν")

    # Walk the payload and translate each string value.
    return recursive_translate(request.data, src_lang, tgt_lang, model_type)