Hjgugugjhuhjggg committed on
Commit
19f95bc
·
verified ·
1 Parent(s): 8b558e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -13
app.py CHANGED
@@ -70,6 +70,12 @@ class GenerateRequest(BaseModel):
70
  raise ValueError(f"task_type must be one of: {valid_types}")
71
  return v
72
 
 
 
 
 
 
 
73
  class GCSModelLoader:
74
  def __init__(self, bucket):
75
  self.bucket = bucket
@@ -119,7 +125,7 @@ async def generate(request: GenerateRequest):
119
  model_name = request.model_name
120
  input_text = request.input_text
121
  task_type = request.task_type
122
- initial_max_new_tokens = request.max_new_tokens
123
  generation_params = request.model_dump(
124
  exclude_none=True,
125
  exclude={'model_name', 'input_text', 'task_type', 'stream', 'chunk_delay', 'max_new_tokens'}
@@ -133,18 +139,38 @@ async def generate(request: GenerateRequest):
133
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
134
 
135
    async def generate_responses() -> AsyncIterator[Dict[str, List[Dict[str, str]]]]:
        """Run the text pipeline once off the event loop and yield its result.

        Yields a single ``{"response": <pipeline output>}`` mapping, where the
        pipeline output is whatever list the Hugging Face pipeline returns.
        Reads ``task_type``, ``model_name``, ``tokenizer``, ``generation_params``,
        ``initial_max_new_tokens`` and ``input_text`` from the enclosing scope.
        """
        # The requested max_new_tokens (validated upstream) caps this single call.
        text_pipeline = pipeline(task_type, model=model_name, tokenizer=tokenizer, token=HUGGINGFACE_HUB_TOKEN, **generation_params, max_new_tokens=initial_max_new_tokens)

        # Runs the blocking pipeline call on a worker thread and hands the
        # result back through the queue.
        def generate_on_thread(pipeline, input_text, output_queue):
            result = pipeline(input_text)
            output_queue.put_nowait(result)

        # NOTE(review): asyncio.Queue is not thread-safe; calling put_nowait
        # from a foreign thread may not reliably wake the awaiting event
        # loop — consider loop.call_soon_threadsafe or asyncio.to_thread.
        output_queue = asyncio.Queue()
        thread = Thread(target=generate_on_thread, args=(text_pipeline, input_text, output_queue))
        thread.start()
        result = await output_queue.get()
        thread.join()  # the worker has already produced its result by this point
        yield {"response": result}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  async def text_stream():
150
  async for data in generate_responses():
 
70
  raise ValueError(f"task_type must be one of: {valid_types}")
71
  return v
72
 
73
+ @field_validator("max_new_tokens")
74
+ def max_new_tokens_must_be_within_limit(cls, v):
75
+ if v > 10:
76
+ raise ValueError("max_new_tokens cannot exceed 10.")
77
+ return v
78
+
79
  class GCSModelLoader:
80
  def __init__(self, bucket):
81
  self.bucket = bucket
 
125
  model_name = request.model_name
126
  input_text = request.input_text
127
  task_type = request.task_type
128
+ initial_max_new_tokens = request.max_new_tokens # The requested max tokens (will be max 10)
129
  generation_params = request.model_dump(
130
  exclude_none=True,
131
  exclude={'model_name', 'input_text', 'task_type', 'stream', 'chunk_delay', 'max_new_tokens'}
 
139
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
140
 
141
  async def generate_responses() -> AsyncIterator[Dict[str, List[Dict[str, str]]]]:
142
+ all_generated_text = ""
143
+ remaining_tokens = 512 # Or some reasonable maximum
144
+ while remaining_tokens > 0:
145
+ current_max_new_tokens = min(initial_max_new_tokens, remaining_tokens)
146
+ text_pipeline = pipeline(
147
+ task_type,
148
+ model=model_name,
149
+ tokenizer=tokenizer,
150
+ token=HUGGINGFACE_HUB_TOKEN,
151
+ **generation_params,
152
+ max_new_tokens=current_max_new_tokens
153
+ )
154
+
155
+ def generate_on_thread(pipeline, input_text, output_queue):
156
+ result = pipeline(input_text)
157
+ output_queue.put_nowait(result)
158
+
159
+ output_queue = asyncio.Queue()
160
+ thread = Thread(target=generate_on_thread, args=(text_pipeline, input_text, output_queue))
161
+ thread.start()
162
+ result = await output_queue.get()
163
+ thread.join()
164
+
165
+ newly_generated_text = result[0]['generated_text'][len(all_generated_text):]
166
+ if not newly_generated_text: # Break if no new text is generated
167
+ break
168
+
169
+ all_generated_text += newly_generated_text
170
+ yield {"response": [{'generated_text': newly_generated_text}]}
171
+ remaining_tokens -= current_max_new_tokens
172
+ # Update input_text for the next iteration
173
+ input_text = all_generated_text
174
 
175
  async def text_stream():
176
  async for data in generate_responses():