Hjgugugjhuhjggg committed
Commit 7d42dcb · verified · 1 Parent(s): 19f95bc

Update app.py

Files changed (1)
  1. app.py +78 -16
app.py CHANGED

@@ -14,6 +14,8 @@ from dotenv import load_dotenv
 import huggingface_hub
 from threading import Thread
 from typing import AsyncIterator, List, Dict
+from transformers.stopping_criteria import StoppingCriteria, StoppingCriteriaList
+import torch
 
 load_dotenv()
 
@@ -55,7 +57,8 @@ class GenerateRequest(BaseModel):
     num_return_sequences: int = 1
     do_sample: bool = False
     chunk_delay: float = 0.0
-    max_new_tokens: int = 512
+    max_new_tokens: int = 10
+    stopping_strings: List[str] = None
 
     @field_validator("model_name")
     def model_name_cannot_be_empty(cls, v):
@@ -70,11 +73,20 @@ class GenerateRequest(BaseModel):
             raise ValueError(f"task_type must be one of: {valid_types}")
         return v
 
-    @field_validator("max_new_tokens")
-    def max_new_tokens_must_be_within_limit(cls, v):
-        if v > 10:
-            raise ValueError("max_new_tokens cannot exceed 10.")
-        return v
+class StopOnKeywords(StoppingCriteria):
+    def __init__(self, stop_words_ids: List[List[int]], encounters: int = 1):
+        super().__init__()
+        self.stop_words_ids = stop_words_ids
+        self.encounters = encounters
+        self.current_encounters = 0
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        for stop_ids in self.stop_words_ids:
+            if torch.all(input_ids[0][-len(stop_ids):] == torch.tensor(stop_ids).to(input_ids.device)):
+                self.current_encounters += 1
+                if self.current_encounters >= self.encounters:
+                    return True
+        return False
 
 class GCSModelLoader:
     def __init__(self, bucket):
@@ -125,11 +137,12 @@ async def generate(request: GenerateRequest):
     model_name = request.model_name
     input_text = request.input_text
     task_type = request.task_type
-    initial_max_new_tokens = request.max_new_tokens  # The requested max tokens (will be max 10)
+    requested_max_new_tokens = request.max_new_tokens
     generation_params = request.model_dump(
         exclude_none=True,
-        exclude={'model_name', 'input_text', 'task_type', 'stream', 'chunk_delay', 'max_new_tokens'}
+        exclude={'model_name', 'input_text', 'task_type', 'stream', 'chunk_delay', 'max_new_tokens', 'stopping_strings'}
     )
+    user_defined_stopping_strings = request.stopping_strings
 
     try:
         if not model_loader.check_model_exists_locally(model_name):
@@ -137,19 +150,40 @@ async def generate(request: GenerateRequest):
             raise HTTPException(status_code=500, detail=f"Failed to load model: {model_name}")
 
         tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
+        config = AutoConfig.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
+        stopping_criteria_list = StoppingCriteriaList()
+
+        # Add user-defined stopping strings if provided
+        if user_defined_stopping_strings:
+            stop_words_ids = [tokenizer.encode(stop_string, add_special_tokens=False) for stop_string in user_defined_stopping_strings]
+            stopping_criteria_list.append(StopOnKeywords(stop_words_ids))
+
+        # Automatically add EOS token as a stopping criterion
+        if config.eos_token_id is not None:
+            eos_token_ids = [config.eos_token_id]
+            if isinstance(config.eos_token_id, int):
+                eos_token_ids = [[config.eos_token_id]]
+            elif isinstance(config.eos_token_id, list):
+                eos_token_ids = [[id] for id in config.eos_token_id]
+            stop_words_ids_eos = [tokenizer.encode(tokenizer.decode(eos_id), add_special_tokens=False) for eos_id in eos_token_ids]
+            stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos))
+        elif tokenizer.eos_token is not None:
+            stop_words_ids_eos = [tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)]
+            stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos))
 
         async def generate_responses() -> AsyncIterator[Dict[str, List[Dict[str, str]]]]:
             all_generated_text = ""
-            remaining_tokens = 512  # Or some reasonable maximum
-            while remaining_tokens > 0:
-                current_max_new_tokens = min(initial_max_new_tokens, remaining_tokens)
+            stop_reason = None  # To track why the generation stopped
+
+            while True:  # Loop indefinitely, relying on stopping criteria
                 text_pipeline = pipeline(
                     task_type,
                     model=model_name,
                     tokenizer=tokenizer,
                     token=HUGGINGFACE_HUB_TOKEN,
+                    stopping_criteria=stopping_criteria_list,
                     **generation_params,
-                    max_new_tokens=current_max_new_tokens
+                    max_new_tokens=requested_max_new_tokens  # Generate in chunks
                 )
 
                 def generate_on_thread(pipeline, input_text, output_queue):
@@ -163,13 +197,40 @@ async def generate(request: GenerateRequest):
                thread.join()
 
                 newly_generated_text = result[0]['generated_text'][len(all_generated_text):]
-                if not newly_generated_text:  # Break if no new text is generated
-                    break
+
+                if not newly_generated_text:
+                    break  # Should ideally not happen with proper stopping criteria
 
                 all_generated_text += newly_generated_text
                 yield {"response": [{'generated_text': newly_generated_text}]}
-                remaining_tokens -= current_max_new_tokens
-                # Update input_text for the next iteration
+
+                # Check if any stopping criteria was met
+                if stopping_criteria_list:
+                    for criteria in stopping_criteria_list:
+                        if isinstance(criteria, StopOnKeywords) and criteria.current_encounters > 0:
+                            stop_reason = "stopping_string"
+                            break
+                    if stop_reason:
+                        break
+
+                # If the generated text seems to match the EOS token, stop
+                if config.eos_token_id is not None:
+                    eos_tokens = [config.eos_token_id]
+                    if isinstance(config.eos_token_id, int):
+                        eos_tokens = [config.eos_token_id]
+                    elif isinstance(config.eos_token_id, list):
+                        eos_tokens = config.eos_token_id
+                    for eos_token in eos_tokens:
+                        if tokenizer.decode([eos_token]) in newly_generated_text:
+                            stop_reason = "eos_token"
+                            break
+                    if stop_reason:
+                        break
+                elif tokenizer.eos_token is not None and tokenizer.eos_token in newly_generated_text:
+                    stop_reason = "eos_token"
+                    break
+
+                # Update input text for the next iteration
                 input_text = all_generated_text
 
         async def text_stream():
@@ -186,4 +247,5 @@ async def generate(request: GenerateRequest):
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
 
 if __name__ == "__main__":
+    import torch
     uvicorn.run(app, host="0.0.0.0", port=7860)
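
For context on how the new request fields are meant to be used, below is a minimal, hypothetical client call. The /generate route, the localhost host, and the model name are assumptions for illustration and are not part of this commit; only the port 7860 appears in app.py.

import requests  # hypothetical client snippet, not part of the commit

payload = {
    "model_name": "gpt2",               # placeholder model name
    "input_text": "Once upon a time",
    "task_type": "text-generation",     # assumed to be an accepted task_type
    "max_new_tokens": 10,               # per-chunk budget (new default in this commit)
    "stopping_strings": ["\n\n"],       # new field: stop once this string is generated
}

# The handler yields chunks, so read the response incrementally.
with requests.post("http://localhost:7860/generate", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)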
 
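
The custom StopOnKeywords criterion can also be exercised outside the FastAPI service. The sketch below is a hypothetical usage example, not part of the commit; it assumes app.py (and therefore StopOnKeywords) is importable from the working directory, and uses gpt2 plus a blank-line stop string purely as placeholders.

from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

from app import StopOnKeywords  # assumes app.py is importable (module-level setup will run)

tokenizer = AutoTokenizer.from_pretrained("gpt2")           # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Stop as soon as the token sequence for "\n\n" appears at the end of the output.
stop_words_ids = [tokenizer.encode("\n\n", add_special_tokens=False)]
criteria = StoppingCriteriaList([StopOnKeywords(stop_words_ids)])

inputs = tokenizer("Once upon a time", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=50, stopping_criteria=criteria)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))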