Hjgugugjhuhjggg committed on
Commit
399f6a8
·
verified ·
1 Parent(s): c0d98e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -29
app.py CHANGED
@@ -114,30 +114,31 @@ class GCSModelLoader:
114
  return AutoConfig.from_pretrained(pretrained_model_name_or_path="", _commit_hash=None, config_dict=json.loads(config_content), trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
115
  except Exception as e:
116
  logger.error(f"Error loading config from GCS: {e}")
117
- else:
118
  logger.info(f"Downloading config from Hugging Face for {model_name}")
119
- try:
120
- config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
121
- gcs_model_folder = self._get_gcs_uri(model_name)
122
- self._create_model_folder(model_name)
123
- self._upload_content(json.dumps(config.to_dict()).encode('utf-8'), f"{gcs_model_folder}/config.json")
124
- return config
125
- except Exception as e:
126
- logger.error(f"Error loading config from Hugging Face: {e}")
127
- return None
128
 
129
  def load_tokenizer(self, model_name):
130
  gcs_tokenizer_path = self._get_gcs_uri(model_name)
131
  tokenizer_files = ["tokenizer_config.json", "vocab.json", "merges.txt", "tokenizer.json", "special_tokens_map.json"]
132
- if all(self._blob_exists(f"{gcs_tokenizer_path}/{f}") for f in tokenizer_files):
 
 
133
  try:
134
  return AutoTokenizer.from_pretrained(gcs_tokenizer_path, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
135
  except Exception as e:
136
  logger.error(f"Error loading tokenizer from GCS: {e}")
137
  return None
138
  else:
139
- logger.info(f"Downloading tokenizer from Hugging Face for {model_name}")
140
  try:
 
141
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
142
  gcs_model_folder = self._get_gcs_uri(model_name)
143
  self._create_model_folder(model_name)
@@ -161,7 +162,8 @@ class GCSModelLoader:
161
  self._create_model_folder(model_name)
162
  for filename in os.listdir(model.config.name_or_path):
163
  if filename.endswith((".bin", ".safetensors")):
164
- self._upload_content(open(os.path.join(model.config.name_or_path, filename), 'rb').read(), f"{gcs_model_folder}/{filename}")
 
165
  logger.info(f"Model '{model_name}' downloaded from Hugging Face and saved to GCS.")
166
  return model
167
  except Exception as e:
@@ -171,30 +173,36 @@ class GCSModelLoader:
171
  logger.info(f"Found weight files in GCS for '{model_name}': {weight_files}")
172
 
173
  loaded_state_dict = {}
 
174
  for weight_file in weight_files:
175
  logger.info(f"Streaming weight file from GCS: {weight_file}")
176
  blob = self.bucket.blob(weight_file)
177
  try:
178
- weight_bytes = blob.download_as_bytes()
179
  if weight_file.endswith(".safetensors"):
180
- loaded_state_dict.update(safe_load(weight_bytes))
181
  else:
182
- loaded_state_dict.update(torch.load(weight_bytes))
183
  except Exception as e:
184
  logger.error(f"Error streaming and loading weights from GCS {weight_file}: {e}")
185
- logger.info(f"Attempting to reload model '{model_name}' from Hugging Face due to loading error.")
186
- try:
187
- model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
188
- gcs_model_folder = self._get_gcs_uri(model_name)
189
- self._create_model_folder(model_name)
190
- for filename in os.listdir(model.config.name_or_path):
191
- if filename.endswith((".bin", ".safetensors")):
192
- self._upload_content(open(os.path.join(model.config.name_or_path, filename), 'rb').read(), f"{gcs_model_folder}/{filename}")
193
- logger.info(f"Model '{model_name}' reloaded from Hugging Face and saved to GCS.")
194
- return model
195
- except Exception as redownload_error:
196
- logger.error(f"Error redownloading model from Hugging Face: {redownload_error}")
197
- raise HTTPException(status_code=500, detail=f"Failed to load or redownload model: {redownload_error}")
 
 
 
 
 
198
 
199
  try:
200
  model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
 
114
  return AutoConfig.from_pretrained(pretrained_model_name_or_path="", _commit_hash=None, config_dict=json.loads(config_content), trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
115
  except Exception as e:
116
  logger.error(f"Error loading config from GCS: {e}")
117
+ try:
118
  logger.info(f"Downloading config from Hugging Face for {model_name}")
119
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
120
+ gcs_model_folder = self._get_gcs_uri(model_name)
121
+ self._create_model_folder(model_name)
122
+ self._upload_content(json.dumps(config.to_dict()).encode('utf-8'), f"{gcs_model_folder}/config.json")
123
+ return config
124
+ except Exception as e:
125
+ logger.error(f"Error loading config from Hugging Face: {e}")
126
+ return None
 
127
 
128
  def load_tokenizer(self, model_name):
129
  gcs_tokenizer_path = self._get_gcs_uri(model_name)
130
  tokenizer_files = ["tokenizer_config.json", "vocab.json", "merges.txt", "tokenizer.json", "special_tokens_map.json"]
131
+ gcs_files_exist = all(self._blob_exists(f"{gcs_tokenizer_path}/{f}") for f in tokenizer_files)
132
+
133
+ if gcs_files_exist:
134
  try:
135
  return AutoTokenizer.from_pretrained(gcs_tokenizer_path, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
136
  except Exception as e:
137
  logger.error(f"Error loading tokenizer from GCS: {e}")
138
  return None
139
  else:
 
140
  try:
141
+ logger.info(f"Downloading tokenizer from Hugging Face for {model_name}")
142
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
143
  gcs_model_folder = self._get_gcs_uri(model_name)
144
  self._create_model_folder(model_name)
 
162
  self._create_model_folder(model_name)
163
  for filename in os.listdir(model.config.name_or_path):
164
  if filename.endswith((".bin", ".safetensors")):
165
+ blob = self.bucket.blob(f"{gcs_model_folder}/{filename}")
166
+ blob.upload_from_filename(os.path.join(model.config.name_or_path, filename))
167
  logger.info(f"Model '{model_name}' downloaded from Hugging Face and saved to GCS.")
168
  return model
169
  except Exception as e:
 
173
  logger.info(f"Found weight files in GCS for '{model_name}': {weight_files}")
174
 
175
  loaded_state_dict = {}
176
+ error_occurred = False
177
  for weight_file in weight_files:
178
  logger.info(f"Streaming weight file from GCS: {weight_file}")
179
  blob = self.bucket.blob(weight_file)
180
  try:
181
+ blob_content = blob.download_as_bytes()
182
  if weight_file.endswith(".safetensors"):
183
+ loaded_state_dict.update(safe_load(blob_content))
184
  else:
185
+ loaded_state_dict.update(torch.load(io.BytesIO(blob_content), map_location="cpu"))
186
  except Exception as e:
187
  logger.error(f"Error streaming and loading weights from GCS {weight_file}: {e}")
188
+ error_occurred = True
189
+ break
190
+
191
+ if error_occurred:
192
+ logger.info(f"Attempting to reload model '{model_name}' from Hugging Face due to loading error.")
193
+ try:
194
+ model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True, token=HUGGINGFACE_HUB_TOKEN)
195
+ gcs_model_folder = self._get_gcs_uri(model_name)
196
+ self._create_model_folder(model_name)
197
+ for filename in os.listdir(model.config.name_or_path):
198
+ if filename.endswith((".bin", ".safetensors")):
199
+ upload_blob = self.bucket.blob(f"{gcs_model_folder}/{filename}")
200
+ upload_blob.upload_from_filename(os.path.join(model.config.name_or_path, filename))
201
+ logger.info(f"Model '{model_name}' reloaded from Hugging Face and saved to GCS.")
202
+ return model
203
+ except Exception as redownload_error:
204
+ logger.error(f"Error redownloading model from Hugging Face: {redownload_error}")
205
+ raise HTTPException(status_code=500, detail=f"Failed to load or redownload model: {redownload_error}")
206
 
207
  try:
208
  model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)