Spaces:
Runtime error
Runtime error
Configured more architectures to try to debug the fine-tuning issue, each with a different prompt style
Browse files
- src/architectures.py +14 -3
src/architectures.py
CHANGED
@@ -436,8 +436,10 @@ class HFInferenceEndpoint(ArchitectureComponent):
|
|
436 |
A concrete pipeline component which sends the user text to a given llama chat based
|
437 |
inference endpoint on HuggingFace
|
438 |
"""
|
439 |
-
def __init__(self, endpoint_url: str, model_name: str, system_prompt: str, max_new_tokens: int,
|
|
|
440 |
self.endpoint_url: str = endpoint_url
|
|
|
441 |
self.model_name: str = model_name
|
442 |
self.system_prompt: str = system_prompt
|
443 |
self.max_new_tokens = max_new_tokens
|
@@ -466,8 +468,17 @@ class HFInferenceEndpoint(ArchitectureComponent):
|
|
466 |
"Authorization": f"Bearer {self.api_token}",
|
467 |
"Content-Type": "application/json"
|
468 |
}
|
469 |
-
|
470 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
471 |
payload = {
|
472 |
"inputs": query_input,
|
473 |
"parameters": {
|
|
|
436 |
A concrete pipeline component which sends the user text to a given llama chat based
|
437 |
inference endpoint on HuggingFace
|
438 |
"""
|
439 |
+
def __init__(self, endpoint_url: str, model_name: str, system_prompt: str, max_new_tokens: int,
|
440 |
+
temperature: float = 1.0, prompt_style: str = "full"):
|
441 |
self.endpoint_url: str = endpoint_url
|
442 |
+
self.prompt_style = prompt_style
|
443 |
self.model_name: str = model_name
|
444 |
self.system_prompt: str = system_prompt
|
445 |
self.max_new_tokens = max_new_tokens
|
|
|
468 |
"Authorization": f"Bearer {self.api_token}",
|
469 |
"Content-Type": "application/json"
|
470 |
}
|
471 |
+
|
472 |
+
if self.prompt_style == "multi_line":
|
473 |
+
query_input = f"<s>[INST] <<SYS>>\n{self.system_prompt}\n<</SYS>>\n\n{request.request} [/INST] "
|
474 |
+
elif self.prompt_style == "multi_line_no_sys":
|
475 |
+
query_input = f"<s>[INST]\n{request.request} [/INST] "
|
476 |
+
elif self.prompt_style == "single_line_no_sys":
|
477 |
+
query_input = f"<s>[INST] {request.request} [/INST] "
|
478 |
+
elif self.prompt_style == "single_line":
|
479 |
+
query_input = f"<s>[INST] <<SYS>>\n{self.system_prompt}\n<</SYS>> {request.request} [/INST] "
|
480 |
+
else:
|
481 |
+
raise ValueError(f"Config error - Unknown prompt style: {self.prompt_style}")
|
482 |
payload = {
|
483 |
"inputs": query_input,
|
484 |
"parameters": {
|