Hjgugugjhuhjggg committed
Commit d9a044e · verified · 1 Parent(s): bc2f34d

Update app.py

Files changed (1):
  app.py  +13 -20
app.py CHANGED
@@ -73,7 +73,6 @@ class GenerateRequest(BaseModel):
 
     @model_validator(mode='before')
     def set_default_token_ids(cls, values):
-        # These will be populated after tokenizer is loaded, but need defaults to avoid pydantic errors
         values.setdefault("pad_token_id", None)
         values.setdefault("eos_token_id", None)
         values.setdefault("sep_token_id", None)
@@ -89,7 +88,7 @@ class GCSModelLoader:
 
     def _blob_exists(self, blob_path):
         blob = self.bucket.blob(blob_path)
-        return blob.exists(client=self.bucket.client)
+        return blob.exists()
 
     def _download_content(self, blob_path):
         blob = self.bucket.blob(blob_path)
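Note: dropping client= here is safe because Blob.exists() falls back to the client of the bucket the blob was created from. A minimal sketch, assuming google-cloud-storage and a hypothetical bucket name:

from google.cloud import storage

client = storage.Client()
bucket = client.bucket("my-model-cache")  # hypothetical bucket name
blob = bucket.blob("gpt2/config.json")
print(blob.exists())  # implicitly uses the bucket's client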
@@ -106,15 +105,15 @@ class GCSModelLoader:
         config_content = self._download_content(gcs_config_path)
         if config_content:
             try:
-                return AutoConfig.from_pretrained(pretrained_model_name_or_path=None, trust_remote_code=True, config_dict=json.loads(config_content))
+                return AutoConfig.from_pretrained(pretrained_model_name_or_path=None, trust_remote_code=True, config_dict=json.loads(config_content), token=HUGGINGFACE_HUB_TOKEN)
             except Exception as e:
                 logger.error(f"Error loading config from GCS: {e}")
                 return None
         else:
             try:
-                config = AutoConfig.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN, trust_remote_code=True)
+                config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
                 gcs_model_folder = self._get_gcs_uri(model_name)
-                self._upload_content(json.dumps(config.to_dict()).encode('utf-8'), f"{gcs_model_folder}/config.json")
+                bucket.blob(f"{gcs_model_folder}/config.json").upload_from_string(json.dumps(config.to_dict()).encode('utf-8'))
                 return config
             except Exception as e:
                 logger.error(f"Error loading config from Hugging Face and saving to GCS: {e}")
@@ -127,18 +126,17 @@ class GCSModelLoader:
 
         if gcs_files_exist:
             try:
-                return AutoTokenizer.from_pretrained(gcs_tokenizer_path, trust_remote_code=True)
+                return AutoTokenizer.from_pretrained(gcs_tokenizer_path, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
             except Exception as e:
                 logger.error(f"Error loading tokenizer from GCS: {e}")
                 return None
         else:
             try:
-                tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN, trust_remote_code=True)
+                tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
                 gcs_model_folder = self._get_gcs_uri(model_name)
-                os.makedirs(gcs_model_folder, exist_ok=True)  # Ensure the folder exists in GCS
-                for file in tokenizer.save_pretrained(gcs_model_folder):
+                for file in tokenizer.save_pretrained(None):
                     with open(file, 'rb') as f:
-                        self._upload_content(f.read(), f"{gcs_model_folder}/{os.path.basename(file)}")
+                        bucket.blob(f"{gcs_model_folder}/{os.path.basename(file)}").upload_from_string(f.read())
                 return tokenizer
             except Exception as e:
                 logger.error(f"Error loading tokenizer from Hugging Face and saving to GCS: {e}")
@@ -157,12 +155,11 @@ class GCSModelLoader:
                 raise HTTPException(status_code=500, detail=f"Error loading model from GCS: {e}")
         else:
             try:
-                model = AutoModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN, trust_remote_code=True)
+                model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
                 gcs_model_folder = self._get_gcs_uri(model_name)
-                os.makedirs(gcs_model_folder, exist_ok=True)  # Ensure the folder exists in GCS
                 for filename in os.listdir(model.save_pretrained(None)):
-                    with open(os.path.join(model.save_pretrained(None), filename), 'rb') as f:
-                        self._upload_content(f.read(), f"{gcs_model_folder}/{filename}")
+                    with open(os.path.join(model.save_pretrained(None)), 'rb') as f:
+                        bucket.blob(f"{gcs_model_folder}/{filename}").upload_from_string(f.read())
                 return model
             except Exception as e:
                 logger.error(f"Error loading model from Hugging Face and saving to GCS: {e}")
@@ -200,10 +197,9 @@ async def generate(request: GenerateRequest):
 
     try:
         gcs_model_folder_uri = model_loader._get_gcs_uri(model_name)
-        if not model_loader._blob_exists(f"{gcs_model_folder_uri}/config.json"):
+        if not bucket.blob(f"{gcs_model_folder_uri}/config.json").exists():
             logger.info(f"Model '{model_name}' not found in GCS, creating placeholder.")
             bucket.blob(f"{gcs_model_folder_uri}/.placeholder").upload_from_string("")
-
         config = model_loader.load_config(model_name)
         if not config:
             raise HTTPException(status_code=400, detail="Model configuration could not be loaded.")
@@ -212,7 +208,6 @@ async def generate(request: GenerateRequest):
         if not tokenizer:
             raise HTTPException(status_code=400, detail="Tokenizer could not be loaded.")
 
-        # Update token IDs from tokenizer if not provided in request
         if request.pad_token_id is None:
             request.pad_token_id = tokenizer.pad_token_id
         if request.eos_token_id is None:
@@ -233,9 +228,7 @@
             pad_token_id=request.pad_token_id,
             eos_token_id=request.eos_token_id,
             sep_token_id=request.sep_token_id,
-            unk_token_id=request.unk_token_id,
-            return_dict_in_generate=True,
-            output_scores=True
+            unk_token_id=request.unk_token_id
         )
 
         if task_type == "text-to-text":
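Note: the dropped return_dict_in_generate/output_scores flags control whether generate returns a structured output carrying per-step scores instead of a bare tensor of token ids. A minimal sketch of their effect, assuming the surrounding call builds a transformers GenerationConfig:

from transformers import GenerationConfig

generation_config = GenerationConfig(
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    return_dict_in_generate=True,  # structured output instead of a plain tensor
    output_scores=True,            # keep per-step logits in output.scores
)
output = model.generate(input_ids, generation_config=generation_config)
print(output.sequences.shape, len(output.scores))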
 