Safetensors
openlm
text
DCLM-7B-Chat / chat_class.py
mathewhe's picture
Fix typo in argument list
3bf2eac
raw
history blame
2.91 kB
from open_lm.hf import *
from transformers import AutoTokenizer, AutoModelForCausalLM
class Chat:
    r"""
    Minimal chat wrapper around a causal language model and its tokenizer.

    Holds the running conversation in ``self.messages`` as a list of
    ``{"role": ..., "content": ...}`` dicts and the sampling settings passed
    to ``generate`` in ``self.gen_kwargs``.
    """

    def __init__(
        self,
        path: str = "mathewhe/DCLM-7B-Chat",
        device: str = "cuda",
    ):
        r"""
        Construct :class:`Chat`\.
        Args:
            path (str): Model name or path.
            device (str): Model device. Used both to place the model and to
                place the tokenized inputs, so it must be a concrete device
                string that tensors can be moved to (e.g. ``"cuda"``,
                ``"cuda:0"`` or ``"cpu"``).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.tokenizer.add_tokens(
            ["[ASST]", "[INST]", "[/ASST]", "[/INST]"],
            special_tokens=True,
        )
        # Load the model onto the requested device. This used to be
        # hard-coded to "cuda", which left the model and the inputs on
        # different devices whenever a non-default `device` was given.
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map=device)
        self.messages = []
        self.device = device
        self.gen_kwargs = {
            "min_new_tokens": 1,
            "max_new_tokens": 2048,
            "top_p": 0.8,
            "temperature": 0.8,
            "do_sample": True,
            "repetition_penalty": 1.1,
        }

    def reset(self) -> None:
        r"""
        Forget the chat history.
        """
        self.messages = []

    def _inference(self, messages) -> str:
        r"""
        Generate one response for a list of chat messages.
        Args:
            messages (list): Chat turns as ``{"role", "content"}`` dicts.
        Returns:
            The decoded text of the newly generated tokens only.
        """
        # Render the conversation to a prompt string, then tokenize it.
        chat = self.tokenizer.apply_chat_template(messages, tokenize=False)
        inputs = {
            k: v.to(self.device)
            for k, v in self.tokenizer(chat, return_tensors="pt").items()
        }
        # Remember the prompt length so the echoed prompt can be cut off below.
        input_length = len(inputs["input_ids"][0])
        output = self.model.generate(**inputs, **self.gen_kwargs)
        response = self.tokenizer.decode(
            output[0].tolist()[input_length:],
            skip_special_tokens=True,
        )
        # TODO: have the tokenizer/chat template handle this instead of
        # trimming a single leading space here.
        if response.startswith(" "):
            response = response[1:]
        return response

    def message(self, message: str) -> str:
        r"""
        Add a user message to the chat history and save and return a response.

        The history is only updated once a response has been produced, so a
        failed generation does not leave an unanswered user turn behind.
        Args:
            message (str): The user message.
        Returns:
            The assistant response.
        """
        user_turn = {"role": "user", "content": message}
        response = self._inference(self.messages + [user_turn])
        self.messages.extend(
            (user_turn, {"role": "assistant", "content": response})
        )
        return response

    def cli_chat(self) -> None:
        r"""
        For CLI-based chatting (with history).

        The session ends when the user submits an empty line or when the
        input stream is closed.
        """
        asst_prompt = "Assistant: "
        user_prompt = "---> User: "
        print(f"{asst_prompt}Hi! How can I help you?\n")
        while True:
            try:
                message = input(user_prompt)
            except EOFError:
                # Closed stdin (Ctrl-D / end of piped input): finish the
                # prompt line and stop instead of crashing with a traceback.
                print()
                break
            if not message:
                break
            response = self.message(message)
            print(f"\n{asst_prompt}{response}\n")

    def instruct(self, message: str) -> str:
        r"""
        For single instruction-response interactions (without history).
        Args:
            message (str): An instruction or one-off user message.
        Returns:
            The assistant response; the stored chat history is not touched.
        """
        messages = [{"role": "user", "content": message}]
        response = self._inference(messages)
        return response
if __name__ == "__main__":
    # Run as a script: build a Chat with its default arguments and start the
    # interactive command-line session.
    Chat().cli_chat()