# gcs / app.py
# Hjgugugjhuhjggg's picture
# Update app.py
# 399f6a8 verified
# raw
# history blame
# 13.2 kB
import asyncio
import io
import json
import logging
import os
import tempfile
from threading import Thread

import huggingface_hub
import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from google.auth.exceptions import DefaultCredentialsError
from google.cloud import storage
from huggingface_hub import login
from pydantic import BaseModel, field_validator
from safetensors.torch import load as safe_load_bytes
from safetensors.torch import load_file as safe_load
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TextIteratorStreamer,
)
# ---------------------------------------------------------------------------
# Module start-up: read settings from the environment, authenticate against
# Hugging Face, open the GCS bucket and create the FastAPI application.
# ---------------------------------------------------------------------------
load_dotenv()

GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HUGGINGFACE_HUB_TOKEN = os.getenv("HF_API_TOKEN")

# Configure the git credential helper first so the single login call below
# can also store the token for git operations.
os.system("git config --global credential.helper store")
if HUGGINGFACE_HUB_TOKEN:
    login(token=HUGGINGFACE_HUB_TOKEN, add_to_git_credential=True)

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

try:
    # Fail with a clear message instead of letting json.loads(None) raise a
    # TypeError that the handler below would not catch.
    if not GOOGLE_APPLICATION_CREDENTIALS_JSON:
        raise ValueError("GOOGLE_APPLICATION_CREDENTIALS_JSON is not set.")
    if not GCS_BUCKET_NAME:
        raise ValueError("GCS_BUCKET_NAME is not set.")
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    client = storage.Client.from_service_account_info(credentials_info)
    bucket = client.get_bucket(GCS_BUCKET_NAME)
    logger.info(f"Connection to Google Cloud Storage successful. Bucket: {GCS_BUCKET_NAME}")
except (DefaultCredentialsError, json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
    logger.error(f"Error loading credentials or bucket: {e}")
    # Chain the cause so the original traceback is preserved.
    raise RuntimeError(f"Error loading credentials or bucket: {e}") from e

app = FastAPI()
class GenerateRequest(BaseModel):
    """Request body accepted by the ``/generate`` endpoint."""

    # ``model_name`` starts with pydantic's reserved ``model_`` prefix; clearing
    # the protected namespaces avoids the warning pydantic v2 can emit for it.
    model_config = {"protected_namespaces": ()}

    # Identifier of the model to load; must be non-empty (validated below).
    model_name: str
    # Prompt text passed to the tokenizer.
    input_text: str
    # Kind of task to run; only "text-to-text" is accepted (validated below).
    task_type: str
    temperature: float = 1.0
    max_new_tokens: int = 20
    # When true, the endpoint answers with a server-sent event stream.
    stream: bool = False
    top_p: float = 1.0
    top_k: int = 50
    repetition_penalty: float = 1.0
    num_return_sequences: int = 1
    do_sample: bool = False
    # Seconds to sleep between streamed chunks.
    chunk_delay: float = 0.1
    # NOTE(review): accepted by the API; confirm the endpoint actually applies it.
    stop_sequences: list = []

    @field_validator("model_name")
    @classmethod
    def model_name_cannot_be_empty(cls, v):
        """Reject an empty ``model_name``."""
        if not v:
            raise ValueError("model_name cannot be empty.")
        return v

    @field_validator("task_type")
    @classmethod
    def task_type_must_be_valid(cls, v):
        """Reject any ``task_type`` outside the supported list."""
        valid_types = ["text-to-text"]
        if v not in valid_types:
            raise ValueError(f"task_type must be one of: {valid_types}")
        return v
class GCSModelLoader:
    """Load Hugging Face configs, tokenizers and models with a GCS bucket as cache.

    Every model gets a folder named after it inside the bucket. Each ``load_*``
    method first tries the files stored there and falls back to downloading
    from the Hugging Face Hub, after which the downloaded artefacts are
    uploaded to the bucket for later calls.
    """

    # File suffixes that identify weight files inside a model folder.
    _WEIGHT_SUFFIXES = (".bin", ".safetensors")
    # Presence of this file is taken as "a tokenizer is cached in the bucket".
    _TOKENIZER_MARKER = "tokenizer_config.json"

    def __init__(self, bucket):
        """Store the bucket object used for all blob operations."""
        self.bucket = bucket

    def _get_gcs_uri(self, model_name):
        """Return the bucket-relative folder used for ``model_name``."""
        return f"{model_name}"

    def _blob_exists(self, blob_path):
        """Return whether a blob exists at ``blob_path``."""
        return self.bucket.blob(blob_path).exists()

    def _download_content(self, blob_path):
        """Return the blob's bytes, or ``None`` (after logging) on any failure."""
        blob = self.bucket.blob(blob_path)
        try:
            return blob.download_as_bytes()
        except Exception as e:
            logger.error(f"Error downloading {blob_path}: {e}")
            return None

    def _upload_content(self, content, blob_path):
        """Upload ``content`` (str or bytes) to ``blob_path``."""
        self.bucket.blob(blob_path).upload_from_string(content)

    def _create_model_folder(self, model_name):
        """Create the ``.touch`` placeholder blob for the model folder if missing."""
        gcs_model_folder = self._get_gcs_uri(model_name)
        marker_path = f"{gcs_model_folder}/.touch"
        if not self._blob_exists(marker_path):
            self.bucket.blob(marker_path).upload_from_string("")
            logger.info(f"Created folder '{gcs_model_folder}' in GCS.")

    def _list_blob_names(self, model_name):
        """Return the names of all blobs inside the model's folder.

        The prefix ends with ``/`` so that a model called ``gpt2`` does not
        also match blobs that belong to ``gpt2-large``.
        """
        prefix = f"{self._get_gcs_uri(model_name)}/"
        return [blob.name for blob in self.bucket.list_blobs(prefix=prefix)]

    def _upload_directory(self, local_dir, model_name, suffixes=None):
        """Upload the regular files of ``local_dir`` into the model's folder.

        When ``suffixes`` is given, only file names ending with one of them
        are uploaded.
        """
        gcs_model_folder = self._get_gcs_uri(model_name)
        self._create_model_folder(model_name)
        for filename in sorted(os.listdir(local_dir)):
            local_path = os.path.join(local_dir, filename)
            if not os.path.isfile(local_path):
                continue
            if suffixes is not None and not filename.endswith(suffixes):
                continue
            self.bucket.blob(f"{gcs_model_folder}/{filename}").upload_from_filename(local_path)

    def _download_auxiliary_files(self, model_name, local_dir):
        """Download every non-weight file of the model folder into ``local_dir``."""
        prefix = f"{self._get_gcs_uri(model_name)}/"
        for blob_name in self._list_blob_names(model_name):
            filename = blob_name[len(prefix):]
            # Skip the folder placeholder, nested paths and the weight files.
            if not filename or "/" in filename or filename == ".touch":
                continue
            if filename.endswith(self._WEIGHT_SUFFIXES):
                continue
            self.bucket.blob(blob_name).download_to_filename(os.path.join(local_dir, filename))

    def _cache_pretrained(self, pretrained, model_name, suffixes=None):
        """Save ``pretrained`` to a temporary directory and upload it to GCS.

        Caching is best effort: a failure is logged and reported through the
        return value instead of being raised.

        Returns:
            ``True`` when the files were uploaded, ``False`` otherwise.
        """
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                pretrained.save_pretrained(tmp_dir)
                self._upload_directory(tmp_dir, model_name, suffixes=suffixes)
            return True
        except Exception as e:
            logger.warning(f"Could not cache files for '{model_name}' in GCS: {e}")
            return False

    def _read_weight_blob(self, weight_file):
        """Download one weight blob and return the tensors it contains."""
        blob_content = self.bucket.blob(weight_file).download_as_bytes()
        if weight_file.endswith(".safetensors"):
            # ``load`` parses in-memory bytes; ``load_file`` expects a path.
            return safe_load_bytes(blob_content)
        # SECURITY: ``.bin`` checkpoints are pickle files. ``weights_only=True``
        # limits unpickling to tensors and plain containers, so a tampered
        # blob cannot execute arbitrary code.
        return torch.load(io.BytesIO(blob_content), map_location="cpu", weights_only=True)

    def _download_model_and_cache(self, model_name, config):
        """Download the model from the Hub and try to store its weights in GCS.

        Returns:
            A ``(model, cached)`` tuple; ``cached`` tells whether the upload
            to GCS succeeded.
        """
        model = AutoModelForCausalLM.from_pretrained(
            model_name, config=config, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN
        )
        cached = self._cache_pretrained(model, model_name, suffixes=self._WEIGHT_SUFFIXES)
        return model, cached

    def load_config(self, model_name):
        """Return the model configuration, or ``None`` if it cannot be loaded.

        A ``config.json`` cached in GCS is preferred; otherwise the
        configuration is downloaded from the Hub and uploaded to GCS.
        """
        gcs_config_path = f"{self._get_gcs_uri(model_name)}/config.json"
        if self._blob_exists(gcs_config_path):
            try:
                config_content = self._download_content(gcs_config_path)
                if config_content is None:
                    raise ValueError(f"Could not download {gcs_config_path}")
                # from_pretrained reads from a directory, so materialise the
                # cached JSON on disk for the duration of the call.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    with open(os.path.join(tmp_dir, "config.json"), "wb") as config_file:
                        config_file.write(config_content)
                    return AutoConfig.from_pretrained(
                        tmp_dir, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN
                    )
            except Exception as e:
                # Fall through to the Hub download below.
                logger.error(f"Error loading config from GCS: {e}")
        try:
            logger.info(f"Downloading config from Hugging Face for {model_name}")
            config = AutoConfig.from_pretrained(
                model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN
            )
        except Exception as e:
            logger.error(f"Error loading config from Hugging Face: {e}")
            return None
        try:
            self._create_model_folder(model_name)
            self._upload_content(json.dumps(config.to_dict()).encode('utf-8'), gcs_config_path)
        except Exception as e:
            # A failed cache write should not discard a usable config.
            logger.warning(f"Could not cache config for '{model_name}' in GCS: {e}")
        return config

    def load_tokenizer(self, model_name):
        """Return the tokenizer, or ``None`` if it cannot be loaded.

        A tokenizer cached in GCS is preferred; otherwise it is downloaded
        from the Hub and its files are uploaded to GCS.
        """
        gcs_tokenizer_path = self._get_gcs_uri(model_name)
        if self._blob_exists(f"{gcs_tokenizer_path}/{self._TOKENIZER_MARKER}"):
            try:
                with tempfile.TemporaryDirectory() as tmp_dir:
                    self._download_auxiliary_files(model_name, tmp_dir)
                    return AutoTokenizer.from_pretrained(
                        tmp_dir, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN
                    )
            except Exception as e:
                # Fall through to the Hub download below.
                logger.error(f"Error loading tokenizer from GCS: {e}")
        try:
            logger.info(f"Downloading tokenizer from Hugging Face for {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(
                model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN
            )
        except Exception as e:
            logger.error(f"Error loading tokenizer from Hugging Face: {e}")
            return None
        # Save into a temporary directory and upload from there; saving into a
        # local folder named after the model would never reach GCS and could
        # shadow the Hub identifier on later calls.
        self._cache_pretrained(tokenizer, model_name)
        return tokenizer

    def load_model(self, model_name, config):
        """Return a causal language model for ``model_name``.

        Weight files cached in GCS are loaded into a model built from
        ``config``. If there are none, or they cannot be read, the model is
        downloaded from the Hub and its weights are uploaded to GCS.

        Raises:
            HTTPException: status 500 when the model can be neither loaded
                from GCS nor downloaded from the Hub.
        """
        logger.info(f"Attempting to load model '{model_name}' from GCS.")
        weight_files = [
            name for name in self._list_blob_names(model_name)
            if name.endswith(self._WEIGHT_SUFFIXES)
        ]

        if not weight_files:
            logger.info(f"No weight files found in GCS for '{model_name}'. Downloading from Hugging Face.")
            try:
                model, cached = self._download_model_and_cache(model_name, config)
            except Exception as e:
                logger.error(f"Error downloading model from Hugging Face: {e}")
                raise HTTPException(status_code=500, detail=f"Error downloading model from Hugging Face: {e}") from e
            if cached:
                logger.info(f"Model '{model_name}' downloaded from Hugging Face and saved to GCS.")
            return model

        logger.info(f"Found weight files in GCS for '{model_name}': {weight_files}")
        loaded_state_dict = {}
        error_occurred = False
        for weight_file in weight_files:
            logger.info(f"Streaming weight file from GCS: {weight_file}")
            try:
                loaded_state_dict.update(self._read_weight_blob(weight_file))
            except Exception as e:
                logger.error(f"Error streaming and loading weights from GCS {weight_file}: {e}")
                error_occurred = True
                break

        if error_occurred:
            logger.info(f"Attempting to reload model '{model_name}' from Hugging Face due to loading error.")
            try:
                model, cached = self._download_model_and_cache(model_name, config)
            except Exception as redownload_error:
                logger.error(f"Error redownloading model from Hugging Face: {redownload_error}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to load or redownload model: {redownload_error}",
                ) from redownload_error
            if cached:
                logger.info(f"Model '{model_name}' reloaded from Hugging Face and saved to GCS.")
            return model

        try:
            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
            load_result = model.load_state_dict(loaded_state_dict, strict=False)
            # strict=False hides mismatches, so surface them in the log.
            if load_result.missing_keys or load_result.unexpected_keys:
                logger.warning(
                    f"State dict mismatch for '{model_name}': "
                    f"missing={load_result.missing_keys}, unexpected={load_result.unexpected_keys}"
                )
            logger.info(f"Model '{model_name}' successfully loaded from GCS.")
            return model
        except Exception as e:
            logger.error(f"Error loading state dict: {e}")
            raise HTTPException(status_code=500, detail=f"Error loading state dict: {e}") from e
# Module-level loader bound to the bucket opened at import time; the
# /generate endpoint uses it to fetch configs, tokenizers and models.
model_loader = GCSModelLoader(bucket)
async def generate_stream(model, tokenizer, input_text, generation_config):
    """Asynchronously yield generated text for ``input_text`` piece by piece.

    ``model.generate`` has no ``stream`` argument and is not an async
    iterator, so generation runs in a worker thread that feeds a
    ``TextIteratorStreamer``; the streamer is drained without blocking the
    event loop.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        input_text: Prompt text.
        generation_config: Generation settings forwarded to ``model.generate``.

    Yields:
        ``{"token": text}`` dictionaries, one per decoded chunk. The prompt
        itself is not echoed.

    Raises:
        Exception: whatever ``model.generate`` raised in the worker thread,
            re-raised once the stream has been drained.
    """
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    failures = []

    def _run_generation():
        try:
            model.generate(**inputs, generation_config=generation_config, streamer=streamer)
        except Exception as exc:
            failures.append(exc)
            # Signal end-of-stream so the consumer below does not wait forever.
            streamer.end()

    worker = Thread(target=_run_generation, daemon=True)
    worker.start()

    loop = asyncio.get_running_loop()
    chunks = iter(streamer)
    exhausted = object()
    while True:
        # next() blocks on the streamer's queue, so run it in the executor.
        # The default value avoids StopIteration crossing the Future boundary.
        text = await loop.run_in_executor(None, next, chunks, exhausted)
        if text is exhausted:
            break
        if text:
            yield {"token": text}

    worker.join()
    if failures:
        raise failures[0]
def generate_non_stream(model, tokenizer, input_text, generation_config):
    """Run one blocking generation pass and return the decoded first sequence.

    The prompt is tokenized and moved to the model's device, passed to
    ``model.generate`` together with ``generation_config``, and the first
    returned sequence is decoded with special tokens skipped.
    """
    encoded_prompt = tokenizer(input_text, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(model.device)
    generated = model.generate(generation_config=generation_config, **encoded_prompt)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
def _build_generation_kwargs(request, tokenizer):
    """Translate the request fields into ``GenerationConfig`` keyword arguments."""
    kwargs = {
        "max_new_tokens": request.max_new_tokens,
        "repetition_penalty": request.repetition_penalty,
        "num_return_sequences": request.num_return_sequences,
        "do_sample": request.do_sample,
    }
    # Sampling knobs only matter when sampling is enabled; forwarding them
    # for greedy decoding merely triggers configuration warnings.
    if request.do_sample:
        kwargs["temperature"] = request.temperature
        kwargs["top_p"] = request.top_p
        kwargs["top_k"] = request.top_k

    eos_token_id = tokenizer.eos_token_id
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # Tokenizers without a pad token fall back to the EOS token.
        pad_token_id = eos_token_id
    kwargs.setdefault('pad_token_id', pad_token_id)
    kwargs.setdefault('eos_token_id', eos_token_id)
    if hasattr(tokenizer, 'sep_token_id') and tokenizer.sep_token_id is not None:
        kwargs.setdefault('sep_token_id', tokenizer.sep_token_id)
    if hasattr(tokenizer, 'unk_token_id') and tokenizer.unk_token_id is not None:
        kwargs.setdefault('unk_token_id', tokenizer.unk_token_id)
    # NOTE(review): ``request.stop_sequences`` is not forwarded. It was
    # silently dropped before as well; honouring it needs stop-string support
    # in the generation call.
    return kwargs


@app.post("/generate")
async def generate(request: GenerateRequest):
    """Generate text for the request, either as JSON or as an event stream.

    Loads the configuration, tokenizer and model through ``model_loader``,
    builds a ``GenerationConfig`` from the request and runs generation.

    Returns:
        ``{"text": ...}`` for non-streaming requests, or a
        ``StreamingResponse`` of server-sent events when ``request.stream``
        is true.

    Raises:
        HTTPException: 400 when the config, tokenizer or model cannot be
            loaded or the task type is unsupported; 500 for any other error.
    """
    model_name = request.model_name
    input_text = request.input_text
    task_type = request.task_type
    stream = request.stream
    try:
        config = model_loader.load_config(model_name)
        if config is None:
            raise HTTPException(status_code=400, detail="Model configuration could not be loaded.")
        tokenizer = model_loader.load_tokenizer(model_name)
        if tokenizer is None:
            raise HTTPException(status_code=400, detail="Tokenizer could not be loaded.")
        model = model_loader.load_model(model_name, config)
        if model is None:
            raise HTTPException(status_code=400, detail="Model could not be loaded.")

        # Build the kwargs explicitly: GenerationConfig.__init__ only takes
        # **kwargs, so filtering on its co_varnames dropped every request
        # parameter.
        generation_config_kwargs = _build_generation_kwargs(request, tokenizer)
        try:
            generation_config = GenerationConfig.from_pretrained(
                model_name,
                trust_remote_code=True,
                token=HUGGINGFACE_HUB_TOKEN,
                **generation_config_kwargs
            )
        except OSError:
            # Repositories without a generation_config.json are still usable.
            generation_config = GenerationConfig(**generation_config_kwargs)

        model.eval()
        if task_type == "text-to-text":
            if stream:
                async def token_streamer():
                    async for item in generate_stream(model, tokenizer, input_text, generation_config):
                        yield f"data: {json.dumps(item)}\n\n"
                        await asyncio.sleep(request.chunk_delay)
                return StreamingResponse(token_streamer(), media_type="text/event-stream")
            text_result = generate_non_stream(model, tokenizer, input_text, generation_config)
            return {"text": text_result}
        raise HTTPException(status_code=400, detail=f"Task type not supported: {task_type}")
    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except Exception as e:
        logger.error(f"Internal server error: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}") from e
if __name__ == "__main__":
    # Executed directly: serve the API on every interface, port 7860.
    uvicorn.run(
        app,
        port=7860,
        host="0.0.0.0",
    )