alfraser committed on
Commit
53169ab
·
1 Parent(s): 3991f6c

Configured more architectures, each with a different prompt style, to try to debug the fine-tuning issue

Browse files
Files changed (1) hide show
  1. src/architectures.py +14 -3
src/architectures.py CHANGED
@@ -436,8 +436,10 @@ class HFInferenceEndpoint(ArchitectureComponent):
436
  A concrete pipeline component which sends the user text to a given llama chat based
437
  inference endpoint on HuggingFace
438
  """
439
- def __init__(self, endpoint_url: str, model_name: str, system_prompt: str, max_new_tokens: int, temperature: float = 1.0):
 
440
  self.endpoint_url: str = endpoint_url
 
441
  self.model_name: str = model_name
442
  self.system_prompt: str = system_prompt
443
  self.max_new_tokens = max_new_tokens
@@ -466,8 +468,17 @@ class HFInferenceEndpoint(ArchitectureComponent):
466
  "Authorization": f"Bearer {self.api_token}",
467
  "Content-Type": "application/json"
468
  }
469
- #return f"<s>[INST] <<SYS>>\n{sys_prompt}\n<</SYS>>\n\n{q}[/INST]{a}"
470
- query_input = f"<s>[INST] <<SYS>>\n{self.system_prompt}\n<</SYS>>\n\n{request.request}[/INST] "
 
 
 
 
 
 
 
 
 
471
  payload = {
472
  "inputs": query_input,
473
  "parameters": {
 
436
  A concrete pipeline component which sends the user text to a given llama chat based
437
  inference endpoint on HuggingFace
438
  """
439
+ def __init__(self, endpoint_url: str, model_name: str, system_prompt: str, max_new_tokens: int,
440
+ temperature: float = 1.0, prompt_style: str = "full"):
441
  self.endpoint_url: str = endpoint_url
442
+ self.prompt_style = prompt_style
443
  self.model_name: str = model_name
444
  self.system_prompt: str = system_prompt
445
  self.max_new_tokens = max_new_tokens
 
468
  "Authorization": f"Bearer {self.api_token}",
469
  "Content-Type": "application/json"
470
  }
471
+
472
+ if self.prompt_style == "multi_line":
473
+ query_input = f"<s>[INST] <<SYS>>\n{self.system_prompt}\n<</SYS>>\n\n{request.request} [/INST] "
474
+ elif self.prompt_style == "multi_line_no_sys":
475
+ query_input = f"<s>[INST]\n{request.request} [/INST] "
476
+ elif self.prompt_style == "single_line_no_sys":
477
+ query_input = f"<s>[INST] {request.request} [/INST] "
478
+ elif self.prompt_style == "single_line":
479
+ query_input = f"<s>[INST] <<SYS>>\n{self.system_prompt}\n<</SYS>> {request.request} [/INST] "
480
+ else:
481
+ raise ValueError(f"Config error - Unknown prompt style: {self.prompt_style}")
482
  payload = {
483
  "inputs": query_input,
484
  "parameters": {