Spaces:
Runtime error
Runtime error
File size: 3,932 Bytes
4d82421 07a6749 4d82421 07a6749 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
from transformers import AutoTokenizer
from fastchat.conversation import get_conv_template
import os
from utils import sanitize_jinja2
# Llama-2 chat template under test. Resolved relative to this script rather
# than the current working directory, so the tests find it regardless of where
# they are launched from (from the script's own directory this is the same
# file that "../templates/llama-2.jinja2" pointed at).
_TEMPLATE_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "..", "templates", "llama-2.jinja2"
)

# Conversation rendered by both tests; the first test prepends a system turn.
_SYSTEM_TURN = {"role": "system", "content": "You are a helpful assistant."}
_CHAT_TURNS = (
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
    {"role": "user", "content": "I'd like to show off how chat templating works!"},
)


def _load_template_lines(path=_TEMPLATE_PATH):
    """Return the raw lines of the Jinja2 chat template stored at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return f.readlines()


def _build_fastchat_prompt(chat):
    """Render *chat* with FastChat's built-in "llama-2" conversation template.

    A leading ``system`` message becomes the conversation's system message.
    The remaining messages are appended in order, with ``user`` mapped to
    ``conv.roles[0]`` and ``assistant`` to ``conv.roles[1]``. A final empty
    assistant turn is appended, mirroring the ``add_generation_prompt=True``
    rendering this prompt is compared against.

    Args:
        chat: list of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The prompt string produced by ``conv.get_prompt()``.

    Raises:
        KeyError: if a non-leading message has a role other than
            ``user``/``assistant``.
    """
    conv = get_conv_template("llama-2")
    messages = chat
    if messages and messages[0]["role"] == "system":
        conv.set_system_message(messages[0]["content"])
        messages = messages[1:]
    role_map = {"user": conv.roles[0], "assistant": conv.roles[1]}
    for message in messages:
        conv.append_message(role_map[message["role"]], message["content"])
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


def _compare_templates(chat):
    """Print how *chat* renders under each chat template being compared.

    Prints, in order:
      1. the raw and sanitized template lines;
      2. the tokenizer's stock chat template;
      3. the sanitized Llama-2 template without and with
         ``add_generation_prompt``;
      4. FastChat's "llama-2" template.

    Nothing is asserted yet.

    Returns:
        ``(transformer_prompt, fastchat_prompt)`` — the HF rendering with
        ``add_generation_prompt=True`` and the FastChat rendering.
    """
    jinja_lines = _load_template_lines()
    # Sanitize once and reuse the result for both printing and the tokenizer.
    sanitized = sanitize_jinja2(jinja_lines)
    print("jinja_lines: ", jinja_lines)
    print("sanitized: ", sanitized)

    # NOTE(review): trust_remote_code=True allows the Hub repo to execute
    # arbitrary Python while loading. Confirm Orca-2 actually needs it; if its
    # tokenizer is a stock Llama tokenizer this flag can be dropped.
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path="microsoft/Orca-2-7b", trust_remote_code=True
    )
    # Baseline: the chat template shipped with the Orca-2 tokenizer
    # (presumably ChatML-style "<|im_start|>role\n...<|im_end|>" — confirm).
    print("default template")
    print(tokenizer.apply_chat_template(chat, tokenize=False))

    # Pin BOS/EOS to the Llama-2 spellings before swapping the template in,
    # presumably because the template emits bos_token/eos_token — confirm
    # against templates/llama-2.jinja2.
    tokenizer.bos_token = "<s>"
    tokenizer.eos_token = "</s>"
    tokenizer.chat_template = sanitized

    print()
    print("add_generation_prompt False:")
    print(tokenizer.apply_chat_template(chat, tokenize=False))

    transformer_prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    print()
    print("add_generation_prompt True:")
    print(transformer_prompt)

    print("Fastchat template: ")
    fastchat_prompt = _build_fastchat_prompt(chat)
    print(fastchat_prompt)
    # TODO: once the two renderings are confirmed to agree, assert
    # transformer_prompt == fastchat_prompt here so this becomes a real test.
    return transformer_prompt, fastchat_prompt


def test_llama2_template():
    """Compare template renderings for a chat that opens with a system message."""
    _compare_templates([_SYSTEM_TURN, *_CHAT_TURNS])


def test_llama2_no_sys_prompt_template():
    """Compare template renderings for the same chat without a system message."""
    _compare_templates(list(_CHAT_TURNS))


if __name__ == "__main__":
    test_llama2_template()
    test_llama2_no_sys_prompt_template()