import os from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse from pydantic import BaseModel, field_validator from transformers import ( AutoConfig, GenerationConfig, AutoTokenizer, AutoModelForCausalLM, ) from google.cloud import storage from google.auth.exceptions import DefaultCredentialsError import uvicorn import asyncio import json import logging from huggingface_hub import login from dotenv import load_dotenv import huggingface_hub load_dotenv() GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME") GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON") HUGGINGFACE_HUB_TOKEN = os.getenv("HF_API_TOKEN") if HUGGINGFACE_HUB_TOKEN: login(token=HUGGINGFACE_HUB_TOKEN) os.system("git config --global credential.helper store") huggingface_hub.login(token=HUGGINGFACE_HUB_TOKEN, add_to_git_credential=True) logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) try: credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON) client = storage.Client.from_service_account_info(credentials_info) bucket = client.get_bucket(GCS_BUCKET_NAME) logger.info(f"Connection to Google Cloud Storage successful. Bucket: {GCS_BUCKET_NAME}") except (DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError) as e: logger.error(f"Error loading credentials or bucket: {e}") raise RuntimeError(f"Error loading credentials or bucket: {e}") app = FastAPI() class GenerateRequest(BaseModel): model_name: str input_text: str task_type: str temperature: float = 1.0 max_new_tokens: int = 20 stream: bool = False top_p: float = 1.0 top_k: int = 50 repetition_penalty: float = 1.0 num_return_sequences: int = 1 do_sample: bool = False chunk_delay: float = 0.1 stop_sequences: list = [] @field_validator("model_name") def model_name_cannot_be_empty(cls, v): if not v: raise ValueError("model_name cannot be empty.") return v @field_validator("task_type") def task_type_must_be_valid(cls, v): valid_types = ["text-to-text"] if v not in valid_types: raise ValueError(f"task_type must be one of: {valid_types}") return v class GCSModelLoader: def __init__(self, bucket): self.bucket = bucket def _get_gcs_uri(self, model_name): return f"{model_name}" def _blob_exists(self, blob_path): blob = self.bucket.blob(blob_path) return blob.exists() def _download_content(self, blob_path): blob = self.bucket.blob(blob_path) if blob.exists(): return blob.download_as_bytes() return None def _upload_content(self, content, blob_path): blob = self.bucket.blob(blob_path) blob.upload_from_string(content) def _create_model_folder(self, model_name): gcs_model_folder = self._get_gcs_uri(model_name) if not self._blob_exists(f"{gcs_model_folder}/.touch"): blob = self.bucket.blob(f"{gcs_model_folder}/.touch") blob.upload_from_string("") logger.info(f"Created folder '{gcs_model_folder}' in GCS.") def load_config(self, model_name): gcs_config_path = f"{self._get_gcs_uri(model_name)}/config.json" config_content = self._download_content(gcs_config_path) if config_content: try: return AutoConfig.from_pretrained(pretrained_model_name_or_path=None, trust_remote_code=True, config_dict=json.loads(config_content), token=HUGGINGFACE_HUB_TOKEN) except Exception as e: logger.error(f"Error loading config from GCS: {e}") return None else: try: config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN) gcs_model_folder = self._get_gcs_uri(model_name) self._create_model_folder(model_name) self._upload_content(json.dumps(config.to_dict()).encode('utf-8'), f"{gcs_model_folder}/config.json") return config except Exception as e: logger.error(f"Error loading config from Hugging Face and saving to GCS: {e}") return None def load_tokenizer(self, model_name): gcs_tokenizer_path = self._get_gcs_uri(model_name) tokenizer_files = ["tokenizer_config.json", "vocab.json", "merges.txt", "tokenizer.json", "special_tokens_map.json"] gcs_files_exist = all(self._blob_exists(f"{gcs_tokenizer_path}/{f}") for f in tokenizer_files) if gcs_files_exist: try: return AutoTokenizer.from_pretrained(gcs_tokenizer_path, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN) except Exception as e: logger.error(f"Error loading tokenizer from GCS: {e}") return None else: try: tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN) gcs_model_folder = self._get_gcs_uri(model_name) self._create_model_folder(model_name) tokenizer.save_pretrained(gcs_model_folder) return tokenizer except Exception as e: logger.error(f"Error loading tokenizer from Hugging Face and saving to GCS: {e}") return None def load_model(self, model_name, config): gcs_model_path = self._get_gcs_uri(model_name) blobs = self.bucket.list_blobs(prefix=gcs_model_path) model_files_present = any(blob.name.endswith((".bin", ".safetensors")) for blob in blobs) if model_files_present: try: return AutoModelForCausalLM.from_pretrained(gcs_model_path, config=config, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN) except Exception as e: logger.error(f"Error loading model from GCS: {e}") raise HTTPException(status_code=500, detail=f"Error loading model from GCS: {e}") else: try: model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN) gcs_model_folder = self._get_gcs_uri(model_name) self._create_model_folder(model_name) model.save_pretrained(gcs_model_folder) return model except Exception as e: logger.error(f"Error loading model from Hugging Face and saving to GCS: {e}") raise HTTPException(status_code=500, detail="Failed to load model") model_loader = GCSModelLoader(bucket) async def generate_stream(model, tokenizer, input_text, generation_config): inputs = tokenizer(input_text, return_tensors="pt").to(model.device) async def token_stream(): generation_stream = model.generate( **inputs, generation_config=generation_config, stream=True, ) async for output in generation_stream: token_id = output[-1] token = tokenizer.decode(token_id, skip_special_tokens=True) yield {"token": token} return token_stream() def generate_non_stream(model, tokenizer, input_text, generation_config): inputs = tokenizer(input_text, return_tensors="pt").to(model.device) outputs = model.generate(**inputs, generation_config=generation_config) return tokenizer.decode(outputs[0], skip_special_tokens=True) @app.post("/generate") async def generate(request: GenerateRequest): model_name = request.model_name input_text = request.input_text task_type = request.task_type stream = request.stream generation_params = request.model_dump( exclude_none=True, exclude={'model_name', 'input_text', 'task_type', 'stream', 'chunk_delay'} ) try: gcs_model_folder_uri = model_loader._get_gcs_uri(model_name) if not model_loader._blob_exists(f"{gcs_model_folder_uri}/config.json"): logger.info(f"Model '{model_name}' not found in GCS, checking Hugging Face.") config = model_loader.load_config(model_name) if not config: raise HTTPException(status_code=400, detail="Model configuration could not be loaded.") tokenizer = model_loader.load_tokenizer(model_name) if not tokenizer: raise HTTPException(status_code=400, detail="Tokenizer could not be loaded.") generation_config_kwargs = generation_params.copy() generation_config_kwargs['pad_token_id'] = tokenizer.pad_token_id generation_config_kwargs['eos_token_id'] = tokenizer.eos_token_id generation_config_kwargs['sep_token_id'] = tokenizer.sep_token_id generation_config_kwargs['unk_token_id'] = tokenizer.unk_token_id model = model_loader.load_model(model_name, config) if not model: raise HTTPException(status_code=400, detail="Model could not be loaded.") generation_config = GenerationConfig.from_pretrained( model_name, trust_remote_code=True, **generation_config_kwargs ) if task_type == "text-to-text": if stream: async def event_stream(): async for output in generate_stream(model, tokenizer, input_text, generation_config): yield f"data: {json.dumps(output)}\n\n" await asyncio.sleep(request.chunk_delay) return StreamingResponse(event_stream(), media_type="text/event-stream") else: text_result = generate_non_stream(model, tokenizer, input_text, generation_config) return {"text": text_result} else: raise HTTPException(status_code=400, detail=f"Task type not supported: {task_type}") except HTTPException as e: raise e except Exception as e: logger.error(f"Internal server error: {e}") raise HTTPException(status_code=500, detail=f"Internal server error: {e}") if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=7860)