Update app.py
Browse files
app.py
CHANGED
@@ -48,7 +48,7 @@ class GenerateRequest(BaseModel):
|
|
48 |
input_text: str
|
49 |
task_type: str
|
50 |
temperature: float = 1.0
|
51 |
-
stream: bool = True
|
52 |
top_p: float = 1.0
|
53 |
top_k: int = 50
|
54 |
repetition_penalty: float = 1.0
|
@@ -149,21 +149,20 @@ async def generate(request: GenerateRequest):
|
|
149 |
if not model_loader.download_model_from_huggingface(model_name):
|
150 |
raise HTTPException(status_code=500, detail=f"Failed to load model: {model_name}")
|
151 |
|
152 |
-
|
153 |
token_streamer = TokenIteratorStreamer()
|
154 |
|
155 |
-
def generate_on_thread(
|
156 |
try:
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
pass
|
163 |
finally:
|
164 |
-
|
165 |
|
166 |
-
thread = Thread(target=generate_on_thread, args=(
|
167 |
thread.start()
|
168 |
|
169 |
async def event_stream() -> AsyncIterator[str]:
|
@@ -177,7 +176,7 @@ async def generate(request: GenerateRequest):
|
|
177 |
await asyncio.sleep(request.chunk_delay)
|
178 |
if tokens_buffer:
|
179 |
yield f"data: {json.dumps({'tokens': tokens_buffer})}\n\n"
|
180 |
-
yield "\n\n"
|
181 |
|
182 |
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
183 |
|
|
|
48 |
input_text: str
|
49 |
task_type: str
|
50 |
temperature: float = 1.0
|
51 |
+
stream: bool = True
|
52 |
top_p: float = 1.0
|
53 |
top_k: int = 50
|
54 |
repetition_penalty: float = 1.0
|
|
|
149 |
if not model_loader.download_model_from_huggingface(model_name):
|
150 |
raise HTTPException(status_code=500, detail=f"Failed to load model: {model_name}")
|
151 |
|
152 |
+
text_pipeline = pipeline(task_type, model=model_name, token=HUGGINGFACE_HUB_TOKEN, device_map="auto")
|
153 |
token_streamer = TokenIteratorStreamer()
|
154 |
|
155 |
+
def generate_on_thread(pipeline, input_text, streamer, generation_params):
|
156 |
try:
|
157 |
+
pipeline(input_text,
|
158 |
+
max_new_tokens=int(1e9), # Effectively infinite
|
159 |
+
return_full_text=False,
|
160 |
+
streamer=streamer,
|
161 |
+
**generation_params)
|
|
|
162 |
finally:
|
163 |
+
streamer.end()
|
164 |
|
165 |
+
thread = Thread(target=generate_on_thread, args=(text_pipeline, input_text, token_streamer, generation_params))
|
166 |
thread.start()
|
167 |
|
168 |
async def event_stream() -> AsyncIterator[str]:
|
|
|
176 |
await asyncio.sleep(request.chunk_delay)
|
177 |
if tokens_buffer:
|
178 |
yield f"data: {json.dumps({'tokens': tokens_buffer})}\n\n"
|
179 |
+
yield "\n\n"
|
180 |
|
181 |
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
182 |
|