from open_lm.hf import *  # registers open_lm model classes so transformers' Auto loaders can handle DCLM checkpoints

from transformers import AutoTokenizer, AutoModelForCausalLM
|
class Chat:

    def __init__(
        self,
        path="mathewhe/DCLM-7B-Chat",
        device="cuda",
    ):
        r"""
        Construct :class:`Chat`\.

        Args:
            path (str): Model name or path.
            device (str): Model device.
        """
|
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Mark the chat-template role markers as special tokens so the
        # tokenizer handles them atomically.
        self.tokenizer.add_tokens(
            ["[ASST]", "[INST]", "[/ASST]", "[/INST]"],
            special_tokens=True,
        )
        # Load the model onto the requested device.
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map=device)

        self.messages = []
        self.device = device
        # Default sampling settings for generation.
        self.gen_kwargs = {
            "min_new_tokens": 1,
            "max_new_tokens": 2048,
            "top_p": 0.8,
            "temperature": 0.8,
            "do_sample": True,
            "repetition_penalty": 1.1,
        }
|
    def reset(self):
        r"""
        Clear the chat history.
        """
        self.messages = []
|
    def _inference(self, messages):
        # Render the message list into a single prompt string using the
        # tokenizer's bundled chat template.
        chat = self.tokenizer.apply_chat_template(messages, tokenize=False)
        inputs = {
            k: v.to(self.device)
            for k, v in self.tokenizer(chat, return_tensors="pt").items()
        }
        input_length = len(inputs["input_ids"][0])
        output = self.model.generate(**inputs, **self.gen_kwargs)
        # Decode only the newly generated tokens, skipping the prompt.
        response = self.tokenizer.decode(
            output[0].tolist()[input_length:],
            skip_special_tokens=True,
        )
        # Strip one leading space, if present (an artifact of the prompt format).
        if response.startswith(" "):
            response = response[1:]
        return response
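
    # Note: apply_chat_template wraps each turn in the [INST]/[ASST] marker
    # tokens registered in __init__. The exact layout comes from the chat
    # template bundled with the tokenizer; a rough sketch of one rendered
    # turn (an assumption, not the verbatim template) is:
    #
    #     [INST] What is an LLM? [/INST]
    #     [ASST] ...model response... [/ASST]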
|
    def message(self, message):
        r"""
        Add a user message to the chat history, then generate, save, and
        return a response.

        Args:
            message (str): The user message.
        """
        self.messages.append({"role": "user", "content": message})
        response = self._inference(self.messages)
        self.messages.append({"role": "assistant", "content": response})
        return response
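
    # Multi-turn usage sketch (the prompts are illustrative only):
    #
    #     chat = Chat()
    #     chat.message("Hi! Who are you?")
    #     chat.message("Can you expand on that?")  # reuses the saved history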
|
    def cli_chat(self):
        r"""
        For CLI-based chatting (with history). Submit an empty message to
        exit.
        """
        asst_prompt = "Assistant: "
        user_prompt = "---> User: "
        print(f"{asst_prompt}Hi! How can I help you?\n")
        message = input(user_prompt)
        while message:
            response = self.message(message)
            print(f"\n{asst_prompt}{response}\n")
            message = input(user_prompt)
|
    def instruct(self, message):
        r"""
        For single instruction-response interactions (without history).

        Args:
            message (str): An instruction or one-off user message.
        """
        messages = [{"role": "user", "content": message}]
        return self._inference(messages)
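
    # One-off usage sketch (no history is kept between calls; the prompt is
    # illustrative only):
    #
    #     chat = Chat()
    #     print(chat.instruct("List three uses for a paperclip."))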
|
if __name__ == "__main__":
    chat = Chat()
    chat.cli_chat()
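
    # Running this file directly starts the interactive CLI loop; for
    # scripted one-off prompts, call chat.instruct(...) instead.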
|
|