Hjgugugjhuhjggg committed on
Commit
63f92cf
·
verified ·
1 Parent(s): fcc4055

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -48,7 +48,7 @@ class GenerateRequest(BaseModel):
48
  input_text: str
49
  task_type: str
50
  temperature: float = 1.0
51
- stream: bool = True # Enforce stream for this functionality
52
  top_p: float = 1.0
53
  top_k: int = 50
54
  repetition_penalty: float = 1.0
@@ -149,21 +149,20 @@ async def generate(request: GenerateRequest):
149
  if not model_loader.download_model_from_huggingface(model_name):
150
  raise HTTPException(status_code=500, detail=f"Failed to load model: {model_name}")
151
 
152
- pipe = pipeline(task_type, model=model_name, token=HUGGINGFACE_HUB_TOKEN, device_map="auto")
153
  token_streamer = TokenIteratorStreamer()
154
 
155
- def generate_on_thread(pipe, input_text, token_streamer, generation_params):
156
  try:
157
- for output in pipe(input_text,
158
- max_new_tokens=int(1e9), # Effectively infinite
159
- return_full_text=False,
160
- streamer=token_streamer,
161
- **generation_params):
162
- pass
163
  finally:
164
- token_streamer.end()
165
 
166
- thread = Thread(target=generate_on_thread, args=(pipe, input_text, token_streamer, generation_params))
167
  thread.start()
168
 
169
  async def event_stream() -> AsyncIterator[str]:
@@ -177,7 +176,7 @@ async def generate(request: GenerateRequest):
177
  await asyncio.sleep(request.chunk_delay)
178
  if tokens_buffer:
179
  yield f"data: {json.dumps({'tokens': tokens_buffer})}\n\n"
180
- yield "\n\n" # Ensure final newline
181
 
182
  return StreamingResponse(event_stream(), media_type="text/event-stream")
183
 
 
48
  input_text: str
49
  task_type: str
50
  temperature: float = 1.0
51
+ stream: bool = True
52
  top_p: float = 1.0
53
  top_k: int = 50
54
  repetition_penalty: float = 1.0
 
149
  if not model_loader.download_model_from_huggingface(model_name):
150
  raise HTTPException(status_code=500, detail=f"Failed to load model: {model_name}")
151
 
152
+ text_pipeline = pipeline(task_type, model=model_name, token=HUGGINGFACE_HUB_TOKEN, device_map="auto")
153
  token_streamer = TokenIteratorStreamer()
154
 
155
+ def generate_on_thread(pipeline, input_text, streamer, generation_params):
156
  try:
157
+ pipeline(input_text,
158
+ max_new_tokens=int(1e9), # Effectively infinite
159
+ return_full_text=False,
160
+ streamer=streamer,
161
+ **generation_params)
 
162
  finally:
163
+ streamer.end()
164
 
165
+ thread = Thread(target=generate_on_thread, args=(text_pipeline, input_text, token_streamer, generation_params))
166
  thread.start()
167
 
168
  async def event_stream() -> AsyncIterator[str]:
 
176
  await asyncio.sleep(request.chunk_delay)
177
  if tokens_buffer:
178
  yield f"data: {json.dumps({'tokens': tokens_buffer})}\n\n"
179
+ yield "\n\n"
180
 
181
  return StreamingResponse(event_stream(), media_type="text/event-stream")
182