# NOTE: star import is required -- it registers open_lm model/config classes
# with transformers' Auto* factories so AutoModelForCausalLM can resolve them.
from open_lm.hf import *  # noqa: F401,F403
from transformers import AutoTokenizer, AutoModelForCausalLM


class Chat:
    """Minimal chat interface around a causal LM, with optional history.

    Wraps a tokenizer/model pair and exposes stateful chatting
    (:meth:`message`, :meth:`cli_chat`) and stateless one-shot
    instructions (:meth:`instruct`).
    """

    def __init__(
        self,
        path="mathewhe/DCLM-7B-Chat",
        device="cuda",
    ):
        r"""
        Construct :class:`Chat`\.

        Args:
            path (str): Model name or path.
            device (str): Model device.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Register chat-delimiter markers as special tokens so they are kept
        # atomic by the tokenizer and can be dropped on decode.
        self.tokenizer.add_tokens(
            ["[ASST]", "[INST]", "[/ASST]", "[/INST]"],
            special_tokens=True,
        )
        # BUG FIX: previously device_map was hard-coded to "cuda", silently
        # ignoring the `device` argument and mismatching the `.to(self.device)`
        # input placement in _inference. Honor the constructor argument.
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map=device)
        self.messages = list()
        self.device = device
        # Default sampling configuration passed verbatim to model.generate().
        self.gen_kwargs = {
            "min_new_tokens": 1,
            "max_new_tokens": 2048,
            "top_p": 0.8,
            "temperature": 0.8,
            "do_sample": True,
            "repetition_penalty": 1.1,
        }

    def reset(self):
        """Discard the accumulated chat history."""
        self.messages = list()

    def _inference(self, messages):
        """Generate and decode a single model response for ``messages``.

        Args:
            messages (list[dict]): Chat turns as ``{"role": ..., "content": ...}``
                dicts, formatted via the tokenizer's chat template.

        Returns:
            str: The decoded response (special tokens removed, prompt excluded).
        """
        chat = self.tokenizer.apply_chat_template(messages, tokenize=False)
        inputs = {
            k: v.to(self.device)
            for k, v in self.tokenizer(chat, return_tensors="pt").items()
        }
        input_length = len(inputs["input_ids"][0])
        output = self.model.generate(**inputs, **self.gen_kwargs)
        # Decode only the newly generated tokens, not the echoed prompt.
        response = self.tokenizer.decode(
            output[0].tolist()[input_length:],
            skip_special_tokens=True,
        )
        # The chat template leaves one leading space before the reply.
        # TODO: fix this so it's handled correctly by the tokenizer.
        if response.startswith(" "):
            response = response[1:]
        return response

    def message(self, message):
        r"""
        Add a user message to the chat history and save and return a response.

        Args:
            message (str): The user message.
        """
        self.messages.append({"role": "user", "content": message})
        response = self._inference(self.messages)
        self.messages.append({"role": "assistant", "content": response})
        return response

    def cli_chat(self):
        r"""
        For CLI-based chatting (with history).

        Loops on stdin until the user submits an empty line.
        """
        asst_prompt = "Assistant: "
        user_prompt = "---> User: "
        print(f"{asst_prompt}Hi! \nHow can I help you?\n")
        message = input(user_prompt)
        # input() always returns a str, so an empty string is the only
        # sentinel we need to stop on.
        while message:
            response = self.message(message)
            print(f"\n{asst_prompt}{response}\n")
            message = input(user_prompt)

    def instruct(self, message):
        r"""
        For single instruction-response interactions (without history).

        Args:
            message (str): An instruction or one-off user message.
        """
        # Stateless: build a throwaway one-turn history; self.messages is
        # deliberately left untouched.
        messages = [{"role": "user", "content": message}]
        return self._inference(messages)


if __name__ == "__main__":
    chat = Chat()
    chat.cli_chat()