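# FastAPI inference service: serves Hugging Face causal LM models over HTTP,
# using a Google Cloud Storage bucket as a persistent cache for configs,
# tokenizers, and weight files.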
import asyncio
import io
import json
import logging
import os
import tempfile
from threading import Thread

import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from google.auth.exceptions import DefaultCredentialsError
from google.cloud import storage
from huggingface_hub import login
from pydantic import BaseModel, field_validator
from safetensors.torch import load as safetensors_load
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TextIteratorStreamer,
)

load_dotenv()

GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HUGGINGFACE_HUB_TOKEN = os.getenv("HF_API_TOKEN")

# Log in to the Hugging Face Hub once; storing the token as a git credential
# lets git-based downloads of gated models reuse it.
if HUGGINGFACE_HUB_TOKEN:
    os.system("git config --global credential.helper store")
    login(token=HUGGINGFACE_HUB_TOKEN, add_to_git_credential=True)

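# The GCS client is built from an inline service-account key supplied via the
# GOOGLE_APPLICATION_CREDENTIALS_JSON environment variable.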
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

try:
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    client = storage.Client.from_service_account_info(credentials_info)
    bucket = client.get_bucket(GCS_BUCKET_NAME)
    logger.info(f"Connection to Google Cloud Storage successful. Bucket: {GCS_BUCKET_NAME}")
except (DefaultCredentialsError, json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
    # TypeError covers a missing GOOGLE_APPLICATION_CREDENTIALS_JSON (json.loads(None)).
    logger.error(f"Error loading credentials or bucket: {e}")
    raise RuntimeError(f"Error loading credentials or bucket: {e}") from e

app = FastAPI()

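# Request body for /generate. The sampling fields mirror common
# transformers.GenerationConfig parameters.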
class GenerateRequest(BaseModel):
    model_name: str
    input_text: str
    task_type: str
    temperature: float = 1.0
    max_new_tokens: int = 20
    stream: bool = False
    top_p: float = 1.0
    top_k: int = 50
    repetition_penalty: float = 1.0
    num_return_sequences: int = 1
    do_sample: bool = False
    chunk_delay: float = 0.1
    stop_sequences: list = []

    @field_validator("model_name")
    @classmethod
    def model_name_cannot_be_empty(cls, v):
        if not v:
            raise ValueError("model_name cannot be empty.")
        return v

    @field_validator("task_type")
    @classmethod
    def task_type_must_be_valid(cls, v):
        valid_types = ["text-to-text"]
        if v not in valid_types:
            raise ValueError(f"task_type must be one of: {valid_types}")
        return v

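# Loads model artifacts from the GCS bucket, falling back to the Hugging Face
# Hub when a file is missing and re-uploading the downloaded artifacts to the
# bucket so later requests hit the cache.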
class GCSModelLoader:
    def __init__(self, bucket):
        self.bucket = bucket

    def _get_gcs_uri(self, model_name):
        return f"{model_name}"

    def _blob_exists(self, blob_path):
        blob = self.bucket.blob(blob_path)
        return blob.exists()

    def _download_content(self, blob_path):
        blob = self.bucket.blob(blob_path)
        try:
            return blob.download_as_bytes()
        except Exception as e:
            logger.error(f"Error downloading {blob_path}: {e}")
            return None

    def _upload_content(self, content, blob_path):
        blob = self.bucket.blob(blob_path)
        blob.upload_from_string(content)

    def _create_model_folder(self, model_name):
        gcs_model_folder = self._get_gcs_uri(model_name)
        if not self._blob_exists(f"{gcs_model_folder}/.touch"):
            blob = self.bucket.blob(f"{gcs_model_folder}/.touch")
            blob.upload_from_string("")
            logger.info(f"Created folder '{gcs_model_folder}' in GCS.")

    def load_config(self, model_name):
        gcs_config_path = f"{self._get_gcs_uri(model_name)}/config.json"
        if self._blob_exists(gcs_config_path):
            try:
                config_content = self._download_content(gcs_config_path)
                config_dict = json.loads(config_content)
                # Rebuild the config from the cached dict; the model_type key
                # selects the concrete config class.
                return AutoConfig.for_model(config_dict.pop("model_type"), **config_dict)
            except Exception as e:
                logger.error(f"Error loading config from GCS: {e}")
        try:
            logger.info(f"Downloading config from Hugging Face for {model_name}")
            config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
            gcs_model_folder = self._get_gcs_uri(model_name)
            self._create_model_folder(model_name)
            self._upload_content(json.dumps(config.to_dict()).encode('utf-8'), f"{gcs_model_folder}/config.json")
            return config
        except Exception as e:
            logger.error(f"Error loading config from Hugging Face: {e}")
            return None

    def load_tokenizer(self, model_name):
        gcs_tokenizer_path = self._get_gcs_uri(model_name)
        tokenizer_files = ["tokenizer_config.json", "vocab.json", "merges.txt", "tokenizer.json", "special_tokens_map.json"]
        gcs_files_exist = all(self._blob_exists(f"{gcs_tokenizer_path}/{f}") for f in tokenizer_files)

        if gcs_files_exist:
            try:
                # from_pretrained cannot read GCS paths directly, so materialize
                # the cached files in a temporary directory first.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    for filename in tokenizer_files:
                        content = self._download_content(f"{gcs_tokenizer_path}/{filename}")
                        with open(os.path.join(tmp_dir, filename), "wb") as f:
                            f.write(content)
                    return AutoTokenizer.from_pretrained(tmp_dir, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
            except Exception as e:
                logger.error(f"Error loading tokenizer from GCS: {e}")
                return None
        else:
            try:
                logger.info(f"Downloading tokenizer from Hugging Face for {model_name}")
                tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
                gcs_model_folder = self._get_gcs_uri(model_name)
                self._create_model_folder(model_name)
                # Save locally, then copy the files into the GCS cache.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    tokenizer.save_pretrained(tmp_dir)
                    for filename in os.listdir(tmp_dir):
                        self.bucket.blob(f"{gcs_model_folder}/{filename}").upload_from_filename(os.path.join(tmp_dir, filename))
                return tokenizer
            except Exception as e:
                logger.error(f"Error loading tokenizer from Hugging Face: {e}")
                return None

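    # Weight loading: prefer weight files already cached in GCS; otherwise pull
    # the model from the Hub and populate the cache for subsequent requests.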
    def load_model(self, model_name, config):
        gcs_model_path = self._get_gcs_uri(model_name)
        logger.info(f"Attempting to load model '{model_name}' from GCS.")
        blobs = self.bucket.list_blobs(prefix=gcs_model_path)
        weight_files = [blob.name for blob in blobs if blob.name.endswith((".bin", ".safetensors"))]

        if not weight_files:
            logger.info(f"No weight files found in GCS for '{model_name}'. Downloading from Hugging Face.")
            try:
                model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
                gcs_model_folder = self._get_gcs_uri(model_name)
                self._create_model_folder(model_name)
                # Serialize the weights locally, then upload them to the GCS cache.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    model.save_pretrained(tmp_dir)
                    for filename in os.listdir(tmp_dir):
                        if filename.endswith((".bin", ".safetensors")):
                            blob = self.bucket.blob(f"{gcs_model_folder}/{filename}")
                            blob.upload_from_filename(os.path.join(tmp_dir, filename))
                logger.info(f"Model '{model_name}' downloaded from Hugging Face and saved to GCS.")
                return model
            except Exception as e:
                logger.error(f"Error downloading model from Hugging Face: {e}")
                raise HTTPException(status_code=500, detail=f"Error downloading model from Hugging Face: {e}")

logger.info(f"Found weight files in GCS for '{model_name}': {weight_files}") |
|
|
|
loaded_state_dict = {} |
|
error_occurred = False |
|
for weight_file in weight_files: |
|
logger.info(f"Streaming weight file from GCS: {weight_file}") |
|
blob = self.bucket.blob(weight_file) |
|
try: |
|
blob_content = blob.download_as_bytes() |
|
if weight_file.endswith(".safetensors"): |
|
loaded_state_dict.update(safe_load(blob_content)) |
|
else: |
|
loaded_state_dict.update(torch.load(io.BytesIO(blob_content), map_location="cpu")) |
|
except Exception as e: |
|
logger.error(f"Error streaming and loading weights from GCS {weight_file}: {e}") |
|
error_occurred = True |
|
break |
|
|
|
        if error_occurred:
            logger.info(f"Attempting to reload model '{model_name}' from Hugging Face due to loading error.")
            try:
                model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
                gcs_model_folder = self._get_gcs_uri(model_name)
                self._create_model_folder(model_name)
                with tempfile.TemporaryDirectory() as tmp_dir:
                    model.save_pretrained(tmp_dir)
                    for filename in os.listdir(tmp_dir):
                        if filename.endswith((".bin", ".safetensors")):
                            upload_blob = self.bucket.blob(f"{gcs_model_folder}/{filename}")
                            upload_blob.upload_from_filename(os.path.join(tmp_dir, filename))
                logger.info(f"Model '{model_name}' reloaded from Hugging Face and saved to GCS.")
                return model
            except Exception as redownload_error:
                logger.error(f"Error redownloading model from Hugging Face: {redownload_error}")
                raise HTTPException(status_code=500, detail=f"Failed to load or redownload model: {redownload_error}")

        try:
            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
            model.load_state_dict(loaded_state_dict, strict=False)
            logger.info(f"Model '{model_name}' successfully loaded from GCS.")
            return model
        except Exception as e:
            logger.error(f"Error loading state dict: {e}")
            raise HTTPException(status_code=500, detail=f"Error loading state dict: {e}")

model_loader = GCSModelLoader(bucket)

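# Generation helpers. model.generate has no native streaming flag, so
# generate_stream runs generation in a background thread and yields text pieces
# from a TextIteratorStreamer as they are produced.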
async def generate_stream(model, tokenizer, input_text, generation_config):
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(**inputs, generation_config=generation_config, streamer=streamer)
    # Run generation in a worker thread; the streamer yields decoded text
    # pieces on this side as tokens arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    for token in streamer:
        yield {"token": token}
    thread.join()

def generate_non_stream(model, tokenizer, input_text, generation_config):
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, generation_config=generation_config)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

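# POST /generate: load (or cache) the requested model, build a GenerationConfig
# from the request fields, and either stream tokens as server-sent events or
# return the full completion.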
@app.post("/generate")
async def generate(request: GenerateRequest):
    model_name = request.model_name
    input_text = request.input_text
    task_type = request.task_type
    stream = request.stream

    generation_params = request.model_dump(
        exclude_none=True,
        exclude={'model_name', 'input_text', 'task_type', 'stream', 'chunk_delay'}
    )

    try:
        config = model_loader.load_config(model_name)
        if not config:
            raise HTTPException(status_code=400, detail="Model configuration could not be loaded.")

        tokenizer = model_loader.load_tokenizer(model_name)
        if not tokenizer:
            raise HTTPException(status_code=400, detail="Tokenizer could not be loaded.")

        model = model_loader.load_model(model_name, config)
        if not model:
            raise HTTPException(status_code=400, detail="Model could not be loaded.")

        # GenerationConfig accepts arbitrary keyword arguments, so the request
        # parameters can be passed through directly; stop_sequences is not a
        # GenerationConfig field and is dropped here.
        generation_config_kwargs = {k: v for k, v in generation_params.items() if k != "stop_sequences"}
        generation_config_kwargs.setdefault('pad_token_id', tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id)
        generation_config_kwargs.setdefault('eos_token_id', tokenizer.eos_token_id)
        if getattr(tokenizer, 'sep_token_id', None) is not None:
            generation_config_kwargs.setdefault('sep_token_id', tokenizer.sep_token_id)
        if getattr(tokenizer, 'unk_token_id', None) is not None:
            generation_config_kwargs.setdefault('unk_token_id', tokenizer.unk_token_id)

        try:
            generation_config = GenerationConfig.from_pretrained(
                model_name,
                token=HUGGINGFACE_HUB_TOKEN,
                **generation_config_kwargs
            )
        except Exception:
            # Not every repo ships a generation_config.json; fall back to defaults.
            generation_config = GenerationConfig(**generation_config_kwargs)

        model.eval()

        if task_type == "text-to-text":
            if stream:
                async def token_streamer():
                    async for item in generate_stream(model, tokenizer, input_text, generation_config):
                        yield f"data: {json.dumps(item)}\n\n"
                        await asyncio.sleep(request.chunk_delay)
                return StreamingResponse(token_streamer(), media_type="text/event-stream")
            else:
                text_result = generate_non_stream(model, tokenizer, input_text, generation_config)
                return {"text": text_result}
        else:
            raise HTTPException(status_code=400, detail=f"Task type not supported: {task_type}")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Internal server error: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)