Hjgugugjhuhjggg committed
Commit 4a8c11d · verified · 1 parent: 14051c4

Update app.py

Files changed (1):
  app.py (+40 -64)
app.py CHANGED
@@ -49,9 +49,6 @@ class GenerateRequest(BaseModel):
     do_sample: bool = False
     chunk_delay: float = 0.1
     stop_sequences: list = []
-    min_length: int = 0
-    no_repeat_ngram_size: int = 0
-    length_penalty: float = 1.0
 
     @field_validator("model_name")
     def model_name_cannot_be_empty(cls, v):
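Note: this hunk drops three sampling knobs from the request schema, presumably because nothing downstream consumed them. For orientation, a minimal sketch of the trimmed model after the change; every field and the validator body not visible in the hunk is an assumption, not the file's actual code:

    from pydantic import BaseModel, field_validator

    class GenerateRequest(BaseModel):
        model_name: str            # validated below, as in the hunk
        input_text: str = ""       # assumed field; the endpoint reads it
        stream: bool = False       # assumed field; the endpoint branches on it
        do_sample: bool = False
        chunk_delay: float = 0.1
        stop_sequences: list = []

        @field_validator("model_name")
        def model_name_cannot_be_empty(cls, v):
            if not v:              # assumed body; only the signature appears above
                raise ValueError("model_name cannot be empty")
            return v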
@@ -73,68 +70,56 @@ class GCSModelLoader:
     def _get_gcs_uri(self, model_name):
         return f"{model_name}"
 
-    async def _create_gcs_folder(self, model_name):
-        blob_name = f"{model_name}/.keep"
-        blob = self.bucket.blob(blob_name)
-        if not await blob.exists(client=self.bucket.client):
-            try:
-                await blob.upload_from_string('')
-            except Exception as e:
-                logger.error(f"Error creating folder for {model_name}: {e}")
-
-    async def _download_from_gcs(self, gcs_path):
-        try:
-            blob = self.bucket.blob(gcs_path)
-            if await blob.exists(client=self.bucket.client):
-                return await blob.download_as_string()
-            return None
-        except Exception as e:
-            logger.error(f"Error accessing {gcs_path}: {e}")
-            return None
-
-    async def _upload_to_gcs(self, content, gcs_path):
-        try:
-            blob = self.bucket.blob(gcs_path)
-            await blob.upload_from_string(content)
-            return True
-        except Exception as e:
-            logger.error(f"Error uploading to {gcs_path}: {e}")
-            return False
+    async def _blob_exists(self, blob_path):
+        blob = self.bucket.blob(blob_path)
+        return await blob.exists(client=self.bucket.client)
+
+    async def _download_string(self, blob_path):
+        blob = self.bucket.blob(blob_path)
+        if await self._blob_exists(blob_path):
+            return await blob.download_as_string()
+        return None
+
+    async def _upload_string(self, content, blob_path):
+        blob = self.bucket.blob(blob_path)
+        await blob.upload_from_string(content)
 
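Note: the refactor collapses three try/except-heavy helpers into three single-purpose primitives and drops the `.keep` folder marker, which GCS does not need (object names are flat; "folders" are only a prefix convention). One further tightening is possible, sketched here under the same assumption the hunk makes (an async-capable storage client): `_download_string` builds a `Blob` and then `_blob_exists` builds the same `Blob` again, so reusing the first object avoids a redundant construction.

    async def _download_string(self, blob_path):
        # Hedged sketch, not the committed code: reuse the Blob rather than
        # re-resolving blob_path inside _blob_exists.
        blob = self.bucket.blob(blob_path)
        if await blob.exists(client=self.bucket.client):
            return await blob.download_as_string()
        return None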
     async def load_config(self, model_name):
-        gcs_path = f"{self._get_gcs_uri(model_name)}/config.json"
-        data = await self._download_from_gcs(gcs_path)
-        if data:
+        gcs_config_path = f"{self._get_gcs_uri(model_name)}/config.json"
+        config_str = await self._download_string(gcs_config_path)
+        if config_str:
             try:
-                return AutoConfig.from_pretrained(pretrained_model_name_or_path=None, trust_remote_code=True, _commit_hash=None, **json.loads(data))
+                return AutoConfig.from_pretrained(pretrained_model_name_or_path=None, trust_remote_code=True, **json.loads(config_str))
             except Exception as e:
                 logger.error(f"Error loading config from GCS: {e}")
+                return None
         else:
             try:
                 config = AutoConfig.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN, trust_remote_code=True)
-                await self._create_gcs_folder(model_name)
-                await self._upload_to_gcs(json.dumps(config.to_dict()), gcs_path)
+                await self._upload_string(json.dumps(config.to_dict()), gcs_config_path)
                 return config
-            except Exception as e_hf:
-                logger.error(f"Error loading config from Hugging Face and saving to GCS: {e_hf}")
-                return None
+            except Exception as e:
+                logger.error(f"Error loading config from Hugging Face and saving to GCS: {e}")
+                return None
 
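Note: both sides rebuild the cached config by splatting the parsed JSON into `AutoConfig.from_pretrained(pretrained_model_name_or_path=None, ...)`; in the transformers releases I am aware of, that first argument is still treated as a location to resolve, so a dict-based constructor is the more reliable round-trip. A hedged sketch of that alternative, with "gpt2" as an assumed example model:

    import json
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("gpt2")       # assumed example model
    payload = json.dumps(cfg.to_dict())            # what load_config uploads to GCS

    data = json.loads(payload)
    # for_model() picks the concrete config class from the model_type key.
    cfg2 = AutoConfig.for_model(data.pop("model_type"), **data)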
     async def load_tokenizer(self, model_name):
-        gcs_path = f"{self._get_gcs_uri(model_name)}/tokenizer.json"
-        data = await self._download_from_gcs(gcs_path)
-        if data:
+        gcs_tokenizer_path = self._get_gcs_uri(model_name)
+        if await self._blob_exists(f"{gcs_tokenizer_path}/tokenizer_config.json") and \
+           await self._blob_exists(f"{gcs_tokenizer_path}/vocab.json") and \
+           await self._blob_exists(f"{gcs_tokenizer_path}/merges.txt"):
             try:
-                return AutoTokenizer.from_pretrained(pretrained_model_name_or_path=None, trust_remote_code=True, _commit_hash=None, **json.loads(data))
+                return AutoTokenizer.from_pretrained(gcs_tokenizer_path, trust_remote_code=True)
             except Exception as e:
                 logger.error(f"Error loading tokenizer from GCS: {e}")
+                return None
         else:
             try:
                 tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN, trust_remote_code=True)
-                await self._upload_to_gcs(json.dumps(tokenizer.to_dict()), gcs_path)
+                tokenizer.save_pretrained(gcs_tokenizer_path)
                 return tokenizer
-            except Exception as e_hf:
-                logger.error(f"Error loading tokenizer from Hugging Face and saving to GCS: {e_hf}")
-                return None
+            except Exception as e:
+                logger.error(f"Error loading tokenizer from Hugging Face and saving to GCS: {e}")
+                return None
 
     async def load_model(self, model_name, config):
         gcs_model_path = self._get_gcs_uri(model_name)
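Note: the removed branch uploaded `tokenizer.to_dict()`, a method tokenizers do not expose as far as I know, so the old GCS cache path likely never populated; `save_pretrained`/`from_pretrained` is the supported round-trip and writes exactly the files the new existence check probes (for BPE tokenizers; sentencepiece-based models ship different files, so the `vocab.json`/`merges.txt` probe will miss them). Also, since `_get_gcs_uri` returns the bare model name, `save_pretrained(gcs_tokenizer_path)` writes to a local directory of that name rather than to the bucket unless the working directory is GCS-mounted (e.g. via FUSE). A local sketch of the round-trip, with "gpt2" as an assumed example:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.save_pretrained("/tmp/tok-cache")   # writes tokenizer_config.json, vocab.json, merges.txt, ...
    tok2 = AutoTokenizer.from_pretrained("/tmp/tok-cache")
    assert tok2("hello")["input_ids"] == tok("hello")["input_ids"]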
@@ -152,8 +137,8 @@ class GCSModelLoader:
             model = AutoModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN, trust_remote_code=True)
             model.save_pretrained(gcs_model_path)
             return model
-        except Exception as e_hf:
-            logger.error(f"Error loading model from Hugging Face and saving to GCS: {e_hf}")
+        except Exception as e:
+            logger.error(f"Error loading model from Hugging Face and saving to GCS: {e}")
             raise HTTPException(status_code=500, detail="Failed to load model")
 
 model_loader = GCSModelLoader(bucket)
@@ -170,13 +155,8 @@ async def generate_stream(model, tokenizer, input_text, generation_config, stop_
             await asyncio.sleep(chunk_delay)
             if any(stop in token for stop in stop_sequences):
                 break
-        yield {"finish": True}
-
-    async def generate_events():
-        async for event_data in event_stream():
-            yield json.dumps(event_data) + "\n"
 
-    return generate_events()
+    return event_stream()
 
 async def generate_non_stream(model, tokenizer, input_text, generation_config):
     inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
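Note: `generate_stream` now hands back the raw `event_stream()` async generator and drops the `{"finish": True}` sentinel; serialization moves to the endpoint in the next hunk. A runnable sketch of the new contract, using a stand-in generator (the real event shape is not shown in this diff):

    import asyncio
    import json

    async def fake_event_stream():
        # Stand-in for the generator generate_stream() now returns;
        # {"token": ...} is an assumed event shape.
        for piece in ["Hel", "lo", "!"]:
            yield {"token": piece}

    async def main():
        async for event in fake_event_stream():
            print(json.dumps(event))  # the endpoint applies this same encoding

    asyncio.run(main())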
@@ -217,16 +197,12 @@ async def generate(request: GenerateRequest):
 
     if task_type == "text-to-text":
         if stream:
-            return StreamingResponse(
-                generate_stream(
-                    model, tokenizer, input_text, generation_config, request.stop_sequences, request.chunk_delay
-                ),
-                media_type="text/event-stream"
-            )
+            async def generate_events():
+                async for event in generate_stream(model, tokenizer, input_text, generation_config, request.stop_sequences, request.chunk_delay):
+                    yield json.dumps(event).encode('utf-8') + b"\n"
+            return StreamingResponse(generate_events(), media_type="text/event-stream")
         else:
-            text_result = await generate_non_stream(
-                model, tokenizer, input_text, generation_config
-            )
+            text_result = await generate_non_stream(model, tokenizer, input_text, generation_config)
             return {"text": text_result}
     else:
         raise HTTPException(status_code=400, detail=f"Task type not supported: {task_type}")
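Note: the inlined `generate_events` wrapper emits newline-delimited JSON bytes; despite the `text/event-stream` media type, the payload is NDJSON rather than SSE `data:` frames, which strict SSE clients may reject. A hypothetical client, assuming the route is `POST /generate` on localhost and that `input_text`/`stream` are real request fields:

    import asyncio
    import json

    import httpx

    async def stream_generate(payload: dict):
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("POST", "http://localhost:8000/generate", json=payload) as resp:
                async for line in resp.aiter_lines():
                    if line:  # skip blank lines between events
                        yield json.loads(line)

    async def main():
        req = {"model_name": "gpt2", "input_text": "Hi", "stream": True}  # field names partly assumed
        async for event in stream_generate(req):
            print(event)

    asyncio.run(main())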
 