import os
import asyncio
import json
import logging
import tempfile
from threading import Thread

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, field_validator, model_validator
from transformers import (
    AutoConfig,
    GenerationConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
from google.cloud import storage
from google.auth.exceptions import DefaultCredentialsError
from huggingface_hub import login
import uvicorn

GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HUGGINGFACE_HUB_TOKEN = os.getenv("HF_API_TOKEN")

if HUGGINGFACE_HUB_TOKEN:
    login(token=HUGGINGFACE_HUB_TOKEN)

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

try:
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    client = storage.Client.from_service_account_info(credentials_info)
    bucket = client.get_bucket(GCS_BUCKET_NAME)
    logger.info(f"Connection to Google Cloud Storage successful. Bucket: {GCS_BUCKET_NAME}")
except (DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError, TypeError) as e:
    # TypeError covers json.loads(None) when GOOGLE_APPLICATION_CREDENTIALS_JSON is unset.
    logger.error(f"Error loading credentials or bucket: {e}")
    raise RuntimeError(f"Error loading credentials or bucket: {e}") from e

app = FastAPI()


class GenerateRequest(BaseModel):
    model_name: str
    input_text: str
    task_type: str
    temperature: float = 1.0
    max_new_tokens: int = 20
    stream: bool = False
    top_p: float = 1.0
    top_k: int = 50
    repetition_penalty: float = 1.0
    num_return_sequences: int = 1
    do_sample: bool = False
    chunk_delay: float = 0.1
    stop_sequences: list[str] = []
    pad_token_id: int | None = None
    eos_token_id: int | None = None
    sep_token_id: int | None = None
    unk_token_id: int | None = None

    @field_validator("model_name")
    @classmethod
    def model_name_cannot_be_empty(cls, v):
        if not v:
            raise ValueError("model_name cannot be empty.")
        return v

    @field_validator("task_type")
    @classmethod
    def task_type_must_be_valid(cls, v):
        valid_types = ["text-to-text"]
        if v not in valid_types:
            raise ValueError(f"task_type must be one of: {valid_types}")
        return v

    @model_validator(mode="before")
    @classmethod
    def set_default_token_ids(cls, values):
        # Token ids default to None and are filled in from the tokenizer at request time.
        if isinstance(values, dict):
            for key in ("pad_token_id", "eos_token_id", "sep_token_id", "unk_token_id"):
                values.setdefault(key, None)
        return values


class GCSModelLoader:
    """Caches Hugging Face artifacts in a GCS bucket and reuses them when present."""

    def __init__(self, bucket):
        self.bucket = bucket

    def _get_gcs_uri(self, model_name):
        # The model name doubles as the folder prefix inside the bucket.
        return f"{model_name}"

    def _blob_exists(self, blob_path):
        return self.bucket.blob(blob_path).exists()

    def _download_content(self, blob_path):
        blob = self.bucket.blob(blob_path)
        if blob.exists():
            return blob.download_as_bytes()
        return None

    def _upload_content(self, content, blob_path):
        self.bucket.blob(blob_path).upload_from_string(content)

    def _upload_folder(self, local_dir, gcs_folder):
        # Upload every file produced by save_pretrained() to the model's folder in the bucket.
        for filename in os.listdir(local_dir):
            local_path = os.path.join(local_dir, filename)
            if os.path.isfile(local_path):
                self.bucket.blob(f"{gcs_folder}/{filename}").upload_from_filename(local_path)

    def _download_folder(self, gcs_folder, local_dir):
        # Stage all blobs under the prefix into a local directory, because transformers
        # cannot read directly from GCS.
        for blob in self.bucket.list_blobs(prefix=f"{gcs_folder}/"):
            filename = os.path.basename(blob.name)
            if filename:
                blob.download_to_filename(os.path.join(local_dir, filename))

    def load_config(self, model_name):
        gcs_config_path = f"{self._get_gcs_uri(model_name)}/config.json"
        config_content = self._download_content(gcs_config_path)
        if config_content:
            try:
                # AutoConfig cannot be built from a raw dict, so write the cached
                # config.json to a temporary directory and load it from there.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    with open(os.path.join(tmp_dir, "config.json"), "wb") as f:
                        f.write(config_content)
                    return AutoConfig.from_pretrained(tmp_dir, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
            except Exception as e:
                logger.error(f"Error loading config from GCS: {e}")
                return None
        try:
            config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
            gcs_model_folder = self._get_gcs_uri(model_name)
            self._upload_content(json.dumps(config.to_dict()).encode("utf-8"), f"{gcs_model_folder}/config.json")
            return config
        except Exception as e:
            logger.error(f"Error loading config from Hugging Face and saving to GCS: {e}")
            return None

    def load_tokenizer(self, model_name):
        gcs_tokenizer_path = self._get_gcs_uri(model_name)
        tokenizer_files = [
            "tokenizer_config.json",
            "vocab.json",
            "merges.txt",
            "tokenizer.json",
            "special_tokens_map.json",
        ]
        gcs_files_exist = all(self._blob_exists(f"{gcs_tokenizer_path}/{f}") for f in tokenizer_files)
        if gcs_files_exist:
            try:
                # Stage the cached tokenizer files locally and load from the local copy.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    for filename in tokenizer_files:
                        self.bucket.blob(f"{gcs_tokenizer_path}/{filename}").download_to_filename(
                            os.path.join(tmp_dir, filename)
                        )
                    return AutoTokenizer.from_pretrained(tmp_dir, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
            except Exception as e:
                logger.error(f"Error loading tokenizer from GCS: {e}")
                return None
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
            gcs_model_folder = self._get_gcs_uri(model_name)
            # save_pretrained() needs a real directory; stage the files locally, then upload them.
            with tempfile.TemporaryDirectory() as tmp_dir:
                tokenizer.save_pretrained(tmp_dir)
                self._upload_folder(tmp_dir, gcs_model_folder)
            return tokenizer
        except Exception as e:
            logger.error(f"Error loading tokenizer from Hugging Face and saving to GCS: {e}")
            return None

    def load_model(self, model_name, config):
        gcs_model_path = self._get_gcs_uri(model_name)
        blobs = self.bucket.list_blobs(prefix=gcs_model_path)
        model_files_present = any(blob.name.endswith((".bin", ".safetensors")) for blob in blobs)
        if model_files_present:
            try:
                # Download the cached weights (and config) to a local directory before loading.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    self._download_folder(gcs_model_path, tmp_dir)
                    return AutoModelForCausalLM.from_pretrained(
                        tmp_dir, config=config, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN
                    )
            except Exception as e:
                logger.error(f"Error loading model from GCS: {e}")
                raise HTTPException(status_code=500, detail=f"Error loading model from GCS: {e}")
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_name, config=config, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN
            )
            gcs_model_folder = self._get_gcs_uri(model_name)
            with tempfile.TemporaryDirectory() as tmp_dir:
                model.save_pretrained(tmp_dir)
                self._upload_folder(tmp_dir, gcs_model_folder)
            return model
        except Exception as e:
            logger.error(f"Error loading model from Hugging Face and saving to GCS: {e}")
            raise HTTPException(status_code=500, detail="Failed to load model")


model_loader = GCSModelLoader(bucket)


async def generate_stream(model, tokenizer, input_text, generation_config, chunk_delay):
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, generation_config=generation_config, streamer=streamer)
    # model.generate() is blocking, so run it in a background thread and consume
    # the streamer from the async generator below.
    Thread(target=model.generate, kwargs=generation_kwargs, daemon=True).start()

    async def event_stream():
        for token in streamer:
            # StreamingResponse expects str/bytes, so emit Server-Sent Events frames.
            yield f"data: {json.dumps({'token': token})}\n\n"
            await asyncio.sleep(chunk_delay)

    return event_stream()


def generate_non_stream(model, tokenizer, input_text, generation_config):
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, generation_config=generation_config)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


@app.post("/generate")
async def generate(request: GenerateRequest):
    model_name = request.model_name
    input_text = request.input_text
    task_type = request.task_type
    stream = request.stream

    # Everything that is not routing/streaming metadata becomes a generation parameter.
    # The token ids are excluded here because they are passed explicitly below.
    generation_params = request.model_dump(
        exclude_none=True,
        exclude={
            "model_name", "input_text", "task_type", "stream", "chunk_delay",
            "pad_token_id", "eos_token_id", "sep_token_id", "unk_token_id",
        },
    )

    try:
        gcs_model_folder_uri = model_loader._get_gcs_uri(model_name)
bucket.blob(f"{gcs_model_folder_uri}/config.json").exists(): logger.info(f"Model '{model_name}' not found in GCS, creating placeholder.") bucket.blob(f"{gcs_model_folder_uri}/.placeholder").upload_from_string("") config = model_loader.load_config(model_name) if not config: raise HTTPException(status_code=400, detail="Model configuration could not be loaded.") tokenizer = model_loader.load_tokenizer(model_name) if not tokenizer: raise HTTPException(status_code=400, detail="Tokenizer could not be loaded.") if request.pad_token_id is None: request.pad_token_id = tokenizer.pad_token_id if request.eos_token_id is None: request.eos_token_id = tokenizer.eos_token_id if request.sep_token_id is None: request.sep_token_id = tokenizer.sep_token_id if request.unk_token_id is None: request.unk_token_id = tokenizer.unk_token_id model = model_loader.load_model(model_name, config) if not model: raise HTTPException(status_code=400, detail="Model could not be loaded.") generation_config = GenerationConfig.from_pretrained( model_name, trust_remote_code=True, **generation_params, pad_token_id=request.pad_token_id, eos_token_id=request.eos_token_id, sep_token_id=request.sep_token_id, unk_token_id=request.unk_token_id ) if task_type == "text-to-text": if stream: event_generator = await generate_stream(model, tokenizer, input_text, generation_config, request.chunk_delay) return StreamingResponse(event_generator(), media_type="text/event-stream") else: text_result = generate_non_stream(model, tokenizer, input_text, generation_config) return {"text": text_result} else: raise HTTPException(status_code=400, detail=f"Task type not supported: {task_type}") except HTTPException as e: raise e except Exception as e: logger.error(f"Internal server error: {e}") raise HTTPException(status_code=500, detail=f"Internal server error: {e}") if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=7860)