Update services/prompt_builder.py
Browse files
services/prompt_builder.py
CHANGED
@@ -2,13 +2,13 @@
|
|
2 |
from typing import Protocol, List, Tuple
|
3 |
from transformers import AutoTokenizer
|
4 |
|
5 |
-
|
6 |
class PromptTemplate(Protocol):
    """Structural interface for prompt templates.

    Any object that exposes a compatible ``format`` method satisfies this
    protocol; no inheritance is required (PEP 544 duck typing).
    """

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
        """Render the full prompt string from retrieval context, the current
        user message, and prior (user, assistant) chat turns."""
        ...
|
10 |
|
11 |
-
|
12 |
class LlamaPromptTemplate:
|
13 |
def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], max_history_turns: int = 1) -> str:
|
14 |
system_message = f"Please assist based on the following context: {context}"
|
@@ -22,7 +22,7 @@ class LlamaPromptTemplate:
|
|
22 |
prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
23 |
return prompt
|
24 |
|
25 |
-
|
26 |
class TransformersPromptTemplate:
|
27 |
def __init__(self, model_path: str):
|
28 |
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
|
|
2 |
from typing import Protocol, List, Tuple
|
3 |
from transformers import AutoTokenizer
|
4 |
|
5 |
+
|
6 |
class PromptTemplate(Protocol):
    """Duck-typed contract for prompt-building templates."""

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
        # Implementations return the fully rendered prompt text; the stub
        # body here only declares the expected signature.
        ...
|
10 |
|
11 |
+
|
12 |
class LlamaPromptTemplate:
|
13 |
def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], max_history_turns: int = 1) -> str:
|
14 |
system_message = f"Please assist based on the following context: {context}"
|
|
|
22 |
prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
23 |
return prompt
|
24 |
|
25 |
+
|
26 |
class TransformersPromptTemplate:
|
27 |
def __init__(self, model_path: str):
|
28 |
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|