Update app.py
app.py CHANGED
@@ -70,6 +70,12 @@ class GenerateRequest(BaseModel):
             raise ValueError(f"task_type must be one of: {valid_types}")
         return v
 
+    @field_validator("max_new_tokens")
+    def max_new_tokens_must_be_within_limit(cls, v):
+        if v > 10:
+            raise ValueError("max_new_tokens cannot exceed 10.")
+        return v
+
 class GCSModelLoader:
     def __init__(self, bucket):
         self.bucket = bucket
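The new validator caps every request at 10 new tokens per pipeline call, rejected at request-parsing time before the endpoint body runs. A minimal sketch of the behavior, assuming Pydantic v2 and a stripped-down stand-in for GenerateRequest (the real model declares more fields; only input_text and max_new_tokens are kept here):

from pydantic import BaseModel, ValidationError, field_validator

class GenerateRequest(BaseModel):
    input_text: str
    max_new_tokens: int = 10

    @field_validator("max_new_tokens")
    def max_new_tokens_must_be_within_limit(cls, v):
        if v > 10:
            raise ValueError("max_new_tokens cannot exceed 10.")
        return v

# Oversized values fail validation before generate() is ever called.
try:
    GenerateRequest(input_text="hi", max_new_tokens=64)
except ValidationError as err:
    print(err)  # ... max_new_tokens cannot exceed 10.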
@@ -119,7 +125,7 @@ async def generate(request: GenerateRequest):
     model_name = request.model_name
     input_text = request.input_text
     task_type = request.task_type
-    initial_max_new_tokens = request.max_new_tokens
+    initial_max_new_tokens = request.max_new_tokens  # The requested max tokens (will be max 10)
     generation_params = request.model_dump(
         exclude_none=True,
         exclude={'model_name', 'input_text', 'task_type', 'stream', 'chunk_delay', 'max_new_tokens'}
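For context, generation_params gathers the remaining request fields into pipeline keyword arguments: model_dump(exclude_none=True, ...) drops unset fields, and the exclude set strips the routing and streaming fields so only sampling parameters survive. A hypothetical illustration (temperature and top_p stand in for whatever optional sampling fields the real model declares):

from typing import Optional
from pydantic import BaseModel

class GenerateRequest(BaseModel):
    model_name: str
    input_text: str
    task_type: str
    max_new_tokens: int = 10
    stream: bool = False
    chunk_delay: Optional[float] = None
    temperature: Optional[float] = None  # hypothetical sampling field
    top_p: Optional[float] = None        # hypothetical sampling field

req = GenerateRequest(model_name="gpt2", input_text="hi",
                      task_type="text-generation", temperature=0.7)
params = req.model_dump(
    exclude_none=True,
    exclude={'model_name', 'input_text', 'task_type', 'stream', 'chunk_delay', 'max_new_tokens'},
)
print(params)  # {'temperature': 0.7}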
@@ -133,18 +139,38 @@ async def generate(request: GenerateRequest):
     tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
 
     async def generate_responses() -> AsyncIterator[Dict[str, List[Dict[str, str]]]]:
-
-
-
-
-
-
-
-
-
-
-
-
+        all_generated_text = ""
+        remaining_tokens = 512  # Or some reasonable maximum
+        while remaining_tokens > 0:
+            current_max_new_tokens = min(initial_max_new_tokens, remaining_tokens)
+            text_pipeline = pipeline(
+                task_type,
+                model=model_name,
+                tokenizer=tokenizer,
+                token=HUGGINGFACE_HUB_TOKEN,
+                **generation_params,
+                max_new_tokens=current_max_new_tokens
+            )
+
+            def generate_on_thread(pipeline, input_text, output_queue):
+                result = pipeline(input_text)
+                output_queue.put_nowait(result)
+
+            output_queue = asyncio.Queue()
+            thread = Thread(target=generate_on_thread, args=(text_pipeline, input_text, output_queue))
+            thread.start()
+            result = await output_queue.get()
+            thread.join()
+
+            newly_generated_text = result[0]['generated_text'][len(all_generated_text):]
+            if not newly_generated_text:  # Break if no new text is generated
+                break
+
+            all_generated_text += newly_generated_text
+            yield {"response": [{'generated_text': newly_generated_text}]}
+            remaining_tokens -= current_max_new_tokens
+            # Update input_text for the next iteration
+            input_text = all_generated_text
 
     async def text_stream():
         async for data in generate_responses():
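The rewritten generate_responses turns the hard 10-token cap into chunked streaming: each iteration runs the blocking pipeline on a worker thread for at most min(initial_max_new_tokens, remaining_tokens) new tokens, diffs the result against the text accumulated so far, and yields only the new tail until the 512-token budget is spent. A self-contained sketch of that loop, with fake_generate standing in for the transformers pipeline; unlike the diff, it seeds the accumulator with the prompt (so yielded chunks never echo it) and hands results back via loop.call_soon_threadsafe, since asyncio.Queue.put_nowait is not safe to call directly from another thread:

import asyncio
from threading import Thread
from typing import AsyncIterator

def fake_generate(prompt: str, max_new_tokens: int) -> str:
    # Stand-in for text_pipeline(input_text): a text-generation pipeline
    # returns the prompt followed by the newly generated text.
    return prompt + " tok" * max_new_tokens

async def generate_chunks(prompt: str, per_call: int = 10,
                          budget: int = 512) -> AsyncIterator[str]:
    loop = asyncio.get_running_loop()
    all_text = prompt      # accumulated prompt + generations so far
    remaining = budget     # overall token budget, as in the diff
    while remaining > 0:
        current = min(per_call, remaining)
        queue: asyncio.Queue = asyncio.Queue()

        def worker(text: str = all_text, n: int = current) -> None:
            result = fake_generate(text, n)
            # Hand the result back to the event loop thread safely.
            loop.call_soon_threadsafe(queue.put_nowait, result)

        thread = Thread(target=worker)
        thread.start()
        result = await queue.get()
        thread.join()

        new_text = result[len(all_text):]  # keep only the new tail
        if not new_text:                   # model stopped early
            break
        all_text = result                  # feed everything back in
        remaining -= current
        yield new_text

async def main() -> None:
    async for chunk in generate_chunks("Hello", per_call=10, budget=30):
        print(repr(chunk))

asyncio.run(main())

One design note on the diff itself: rebuilding the pipeline on every iteration is expensive. Since transformers pipelines also accept max_new_tokens at call time, constructing text_pipeline once before the loop and passing max_new_tokens=current_max_new_tokens per call should preserve the behavior at a fraction of the cost.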