import json import torch from fastapi import FastAPI, HTTPException from pydantic import BaseModel from typing import Union, Dict, Any from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, MarianMTModel, MarianTokenizer # 모델 로드 M2M100_MODEL_NAME = "facebook/m2m100_418M" OPUS_MT_MODEL_NAME = "Helsinki-NLP/opus-mt-tc-big-en-ko" # M2M100 (다국어 번역) m2m100_tokenizer = M2M100Tokenizer.from_pretrained(M2M100_MODEL_NAME) m2m100_model = M2M100ForConditionalGeneration.from_pretrained(M2M100_MODEL_NAME) # Helsinki-NLP Opus-MT (영어 ↔ 한국어 전용) opus_tokenizer = MarianTokenizer.from_pretrained(OPUS_MT_MODEL_NAME) opus_model = MarianMTModel.from_pretrained(OPUS_MT_MODEL_NAME) # CPU에서 실행 device = torch.device("cpu") m2m100_model.to(device) opus_model.to(device) # FastAPI 앱 app = FastAPI() # 요청 데이터 모델 class TranslationRequest(BaseModel): model: str # 사용할 모델 ("m2m100" 또는 "opus-mt") from_lang: str # 입력 언어 (예: "ko", "en", "fr") to: str # 출력 언어 (예: "ko", "fr") data: Dict[str, Any] # 번역할 JSON 객체 # M2M100 번역 함수 (모든 언어 지원) def translate_m2m100(text: str, src_lang: str, tgt_lang: str) -> str: if not text.strip(): return text # 빈 문자열이면 그대로 반환 m2m100_tokenizer.src_lang = src_lang encoded_text = m2m100_tokenizer(text, return_tensors="pt").to(device) generated_tokens = m2m100_model.generate( **encoded_text, forced_bos_token_id=m2m100_tokenizer.get_lang_id(tgt_lang) ) return m2m100_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0] # Helsinki-NLP Opus-MT 번역 함수 (영어 ↔ 한국어 전용) def translate_opus_mt(text: str, src_lang: str, tgt_lang: str) -> str: if not text.strip(): return text # 빈 문자열이면 그대로 반환 if src_lang == "en" and tgt_lang == "ko": model_name = "Helsinki-NLP/opus-mt-en-ko" elif src_lang == "ko" and tgt_lang == "en": model_name = "Helsinki-NLP/opus-mt-ko-en" else: raise HTTPException(status_code=400, detail="Opus-MT는 영어 ↔ 한국어만 지원합니다.") tokenizer = MarianTokenizer.from_pretrained(model_name) model = MarianMTModel.from_pretrained(model_name).to(device) encoded_text = tokenizer(text, return_tensors="pt", padding=True).to(device) generated_tokens = model.generate(**encoded_text) return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0] # 재귀적으로 JSON 번역 함수 def recursive_translate(json_obj: Union[Dict[str, Any], str], src_lang: str, tgt_lang: str, model_type: str): if isinstance(json_obj, str): # 단일 문자열이면 번역 if model_type == "m2m100": return translate_m2m100(json_obj, src_lang, tgt_lang) elif model_type == "opus-mt": return translate_opus_mt(json_obj, src_lang, tgt_lang) elif isinstance(json_obj, dict): # 딕셔너리면 재귀적으로 번역 return {key: recursive_translate(value, src_lang, tgt_lang, model_type) for key, value in json_obj.items()} else: return json_obj # 숫자, 리스트 등은 번역하지 않고 그대로 반환 @app.post("/translate") async def translate_json(request: TranslationRequest): """JSON 데이터를 번역하는 API""" model_type = request.model # "m2m100" 또는 "opus-mt" src_lang = request.from_lang tgt_lang = request.to input_data = request.data # 지원하는 언어 목록 (M2M100은 거의 모든 언어 지원) supported_langs = ["ko", "en", "fr", "es", "ja", "zh", "de", "it"] # 모델 선택 if model_type == "m2m100": if src_lang not in supported_langs or tgt_lang not in supported_langs: raise HTTPException(status_code=400, detail=f"지원되지 않는 언어 코드: {src_lang} → {tgt_lang}") elif model_type == "opus-mt": if not (src_lang in ["en", "ko"] and tgt_lang in ["en", "ko"]): raise HTTPException(status_code=400, detail="Opus-MT 모델은 영어 ↔ 한국어 번역만 지원합니다.") else: raise HTTPException(status_code=400, detail="지원되지 않는 모델 선택") # 재귀적으로 JSON 번역 실행 translated_data = recursive_translate(input_data, src_lang, tgt_lang, model_type) return translated_data