Update app.py
app.py CHANGED
@@ -14,6 +14,8 @@ from dotenv import load_dotenv
 import huggingface_hub
 from threading import Thread
 from typing import AsyncIterator, List, Dict
+from transformers.stopping_criteria import StoppingCriteria, StoppingCriteriaList
+import torch

 load_dotenv()

@@ -55,7 +57,8 @@ class GenerateRequest(BaseModel):
     num_return_sequences: int = 1
     do_sample: bool = False
     chunk_delay: float = 0.0
-    max_new_tokens: int =
+    max_new_tokens: int = 10
+    stopping_strings: List[str] = None

     @field_validator("model_name")
     def model_name_cannot_be_empty(cls, v):
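The two new request fields give clients direct control over chunking and early termination: max_new_tokens now defaults to 10 tokens per generation pass, and stopping_strings is optional and arrives as None when omitted. A minimal sketch of how these defaults behave, with the field set reduced to just the two new ones (the class name is illustrative):

from typing import List, Optional
from pydantic import BaseModel

class ChunkSettings(BaseModel):
    # Stand-in for the two fields added to GenerateRequest above.
    max_new_tokens: int = 10
    stopping_strings: Optional[List[str]] = None

print(ChunkSettings())                            # max_new_tokens=10 stopping_strings=None
print(ChunkSettings(stopping_strings=["\n\n"]))   # client-supplied stop sequences

Annotating the field as Optional[List[str]] keeps the None default consistent with the type; List[str] = None works at runtime in Pydantic v2 but is flagged by type checkers.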
@@ -70,11 +73,20 @@ class GenerateRequest(BaseModel):
             raise ValueError(f"task_type must be one of: {valid_types}")
         return v

-
-def
-
-
-
+class StopOnKeywords(StoppingCriteria):
+    def __init__(self, stop_words_ids: List[List[int]], encounters: int = 1):
+        super().__init__()
+        self.stop_words_ids = stop_words_ids
+        self.encounters = encounters
+        self.current_encounters = 0
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        for stop_ids in self.stop_words_ids:
+            if torch.all(input_ids[0][-len(stop_ids):] == torch.tensor(stop_ids).to(input_ids.device)):
+                self.current_encounters += 1
+                if self.current_encounters >= self.encounters:
+                    return True
+        return False

 class GCSModelLoader:
     def __init__(self, bucket):
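The StopOnKeywords criterion is not tied to the pipeline; it plugs into transformers generate() the same way. A minimal sketch, assuming the StopOnKeywords class above is importable, and using gpt2 and a double-newline stop word purely as illustrations:

from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Token-id sequences that should end generation when they appear at the tail.
stop_words_ids = [tokenizer.encode("\n\n", add_special_tokens=False)]
criteria = StoppingCriteriaList([StopOnKeywords(stop_words_ids)])

inputs = tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=50, stopping_criteria=criteria)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

One caveat of id-level matching: a stop string can tokenize differently when it occurs in the middle of running text than when it is encoded on its own, so exact tail comparison may occasionally miss an occurrence.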
@@ -125,11 +137,12 @@ async def generate(request: GenerateRequest):
     model_name = request.model_name
     input_text = request.input_text
     task_type = request.task_type
-
+    requested_max_new_tokens = request.max_new_tokens
     generation_params = request.model_dump(
         exclude_none=True,
-        exclude={'model_name', 'input_text', 'task_type', 'stream', 'chunk_delay', 'max_new_tokens'}
+        exclude={'model_name', 'input_text', 'task_type', 'stream', 'chunk_delay', 'max_new_tokens', 'stopping_strings'}
     )
+    user_defined_stopping_strings = request.stopping_strings

     try:
         if not model_loader.check_model_exists_locally(model_name):
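model_dump here is the Pydantic v2 API: exclude_none drops any field still set to None (such as an omitted stopping_strings), and exclude removes the fields the handler consumes itself rather than forwarding to the pipeline. The same filtering in isolation (field names are illustrative):

from typing import Optional
from pydantic import BaseModel

class Req(BaseModel):
    model_name: str
    input_text: str
    do_sample: bool = False
    temperature: Optional[float] = None

req = Req(model_name="gpt2", input_text="hi")
params = req.model_dump(exclude_none=True, exclude={"model_name", "input_text"})
print(params)  # {'do_sample': False} -- temperature dropped as None, the rest excluded

The leftover dictionary is what gets splatted into pipeline(...) as **generation_params.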
@@ -137,19 +150,40 @@ async def generate(request: GenerateRequest):
             raise HTTPException(status_code=500, detail=f"Failed to load model: {model_name}")

         tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
+        config = AutoConfig.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
+        stopping_criteria_list = StoppingCriteriaList()
+
+        # Add user-defined stopping strings if provided
+        if user_defined_stopping_strings:
+            stop_words_ids = [tokenizer.encode(stop_string, add_special_tokens=False) for stop_string in user_defined_stopping_strings]
+            stopping_criteria_list.append(StopOnKeywords(stop_words_ids))
+
+        # Automatically add EOS token as a stopping criterion
+        if config.eos_token_id is not None:
+            eos_token_ids = [config.eos_token_id]
+            if isinstance(config.eos_token_id, int):
+                eos_token_ids = [[config.eos_token_id]]
+            elif isinstance(config.eos_token_id, list):
+                eos_token_ids = [[id] for id in config.eos_token_id]
+            stop_words_ids_eos = [tokenizer.encode(tokenizer.decode(eos_id), add_special_tokens=False) for eos_id in eos_token_ids]
+            stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos))
+        elif tokenizer.eos_token is not None:
+            stop_words_ids_eos = [tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)]
+            stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos))

         async def generate_responses() -> AsyncIterator[Dict[str, List[Dict[str, str]]]]:
             all_generated_text = ""
-
-
-
+            stop_reason = None  # To track why the generation stopped
+
+            while True:  # Loop indefinitely, relying on stopping criteria
                text_pipeline = pipeline(
                    task_type,
                    model=model_name,
                    tokenizer=tokenizer,
                    token=HUGGINGFACE_HUB_TOKEN,
+                   stopping_criteria=stopping_criteria_list,
                    **generation_params,
-                   max_new_tokens=
+                   max_new_tokens=requested_max_new_tokens  # Generate in chunks
                )

                def generate_on_thread(pipeline, input_text, output_queue):
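config.eos_token_id varies by model: it can be a single int, a list of ints for models with several end tokens, or None, which is why the block above normalizes it into a list of one-token sequences before wrapping it in StopOnKeywords. The normalization on its own, as a sketch (the helper name is hypothetical):

def eos_ids_as_stop_sequences(eos_token_id):
    # Accepts an int, a list of ints, or None and returns a list of
    # single-token stop sequences in the shape StopOnKeywords expects.
    if eos_token_id is None:
        return []
    if isinstance(eos_token_id, int):
        return [[eos_token_id]]
    return [[token_id] for token_id in eos_token_id]

print(eos_ids_as_stop_sequences(2))            # [[2]]
print(eos_ids_as_stop_sequences([2, 32000]))   # [[2], [32000]]

Matching on the raw ids like this also sidesteps the encode(decode(...)) round trip used above, which may not map an EOS id back onto a single token for every tokenizer.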
@@ -163,13 +197,40 @@ async def generate(request: GenerateRequest):
                thread.join()

                newly_generated_text = result[0]['generated_text'][len(all_generated_text):]
-
-
+
+                if not newly_generated_text:
+                    break  # Should ideally not happen with proper stopping criteria

                all_generated_text += newly_generated_text
                yield {"response": [{'generated_text': newly_generated_text}]}
-
-                #
+
+                # Check if any stopping criteria was met
+                if stopping_criteria_list:
+                    for criteria in stopping_criteria_list:
+                        if isinstance(criteria, StopOnKeywords) and criteria.current_encounters > 0:
+                            stop_reason = "stopping_string"
+                            break
+                    if stop_reason:
+                        break
+
+                # If the generated text seems to match the EOS token, stop
+                if config.eos_token_id is not None:
+                    eos_tokens = [config.eos_token_id]
+                    if isinstance(config.eos_token_id, int):
+                        eos_tokens = [config.eos_token_id]
+                    elif isinstance(config.eos_token_id, list):
+                        eos_tokens = config.eos_token_id
+                    for eos_token in eos_tokens:
+                        if tokenizer.decode([eos_token]) in newly_generated_text:
+                            stop_reason = "eos_token"
+                            break
+                    if stop_reason:
+                        break
+                elif tokenizer.eos_token is not None and tokenizer.eos_token in newly_generated_text:
+                    stop_reason = "eos_token"
+                    break
+
+                # Update input text for the next iteration
                input_text = all_generated_text

        async def text_stream():
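Each pass through the loop above asks the pipeline for at most max_new_tokens, streams whatever text is new, then feeds the accumulated text back in as the next prompt until a stop string or EOS turns up in a chunk. The same chunked-continuation pattern reduced to bare generate() calls, as a sketch (model, chunk size, and stop strings are illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

text = "Once upon a time"
stop_strings = ["\n\n"]
for _ in range(20):                                  # hard cap on the number of chunks
    inputs = tokenizer(text, return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=10, do_sample=False)
    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    new_text = decoded[len(text):]                   # only the newly generated tail
    if not new_text:
        break                                        # nothing new was produced
    text += new_text
    if any(stop in new_text for stop in stop_strings):
        break                                        # a stop string appeared in this chunk
print(text)

Re-encoding the full text on every pass is simple but repeats work that grows with the prompt; reusing the model's cached past key values (or the pipeline's own streaming support) would avoid that.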
@@ -186,4 +247,5 @@ async def generate(request: GenerateRequest):
         raise HTTPException(status_code=500, detail=f"Internal server error: {e}")

 if __name__ == "__main__":
+    import torch
     uvicorn.run(app, host="0.0.0.0", port=7860)
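Put together, a client can exercise the new fields end to end. A sketch of such a call, with the route, host, and streamed response format assumed rather than taken from the parts of app.py shown here:

import requests

# Assumes the FastAPI app is running locally on port 7860 and that the
# generate handler above is mounted at /generate with a streaming response.
resp = requests.post(
    "http://localhost:7860/generate",
    json={
        "model_name": "gpt2",
        "input_text": "Q: Name three colors.\nA:",
        "task_type": "text-generation",
        "max_new_tokens": 10,
        "stopping_strings": ["\nQ:"],
    },
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)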