import os
import platform
import signal
import time

from transformers import AutoTokenizer, AutoModel

from multi_input import MultiInputInCmd

if __name__ == "__main__":
    # Local path to the int4-quantized ChatGLM-6B checkpoint (a raw string,
    # so the backslashes in the Windows path are taken literally).
    MODEL_PATH = r"E:\ProjectEX\LLM\ChatGLM-6B\chatglm-6b-int4"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    # .float() keeps the model on CPU; loading the precompiled parallel
    # kernels makes 4-bit inference usable without a GPU.
    model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float()
    model = model.quantize(
        bits=4,
        kernel_file=os.path.join(MODEL_PATH, "quantization_kernels_parallel.so"),
    )
    model = model.eval()

    os_name = platform.system()
    clear_command = 'cls' if os_name == 'Windows' else 'clear'
    stop_stream = False

    # Banner (Chinese): "Welcome to the ChatGLM-6B model. Type to chat,
    # 'clear' to reset the conversation history, 'stop' to quit."
    WELCOME = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"

    def build_prompt(history):
        # Render the full conversation as a transcript. Defined but not
        # called in this streaming variant.
        prompt = WELCOME
        for query, response in history:
            prompt += f"\n\n用户:{query}"  # 用户 = user
            prompt += f"\n\nChatGLM-6B:{response}"
        return prompt

    def signal_handler(sig, frame):
        # Ctrl+C only interrupts the current streaming response;
        # it does not terminate the program.
        global stop_stream
        stop_stream = True

    def main():
        history = []
        global stop_stream
        print(WELCOME)
        while True:
            # Multi-line input replaces the original query = input("\n用户:");
            # MultiInputInCmd.run() returns the typed lines as a list.
            input_fun = MultiInputInCmd("\n用户:")
            all_input_lines = input_fun.run()
            query = '\n'.join(all_input_lines)
            if query.strip() == "stop":
                break
            if query.strip() == "clear":
                history = []
                os.system(clear_command)
                print(WELCOME)
                continue
            last_index = 0
            start = time.time()
            # Register the handler once per query instead of once per
            # streamed chunk, so even the first chunk is interruptible.
            signal.signal(signal.SIGINT, signal_handler)
            for response, history in model.stream_chat(tokenizer, query, history=history):
                if stop_stream:
                    stop_stream = False
                    break
                # Print only the newly generated tail of the response.
                print(response[last_index:], end='', flush=True)
                last_index = len(response)
            if last_index > 0:
                # Average generation time in seconds per character, guarded
                # so an interrupted or empty response cannot divide by zero.
                print((time.time() - start) / last_index)
            print('')

    main()
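# ----------------------------------------------------------------------
# multi_input.py is a local helper module, not shown above. Based only on
# how it is used here (the constructor takes a prompt string and run()
# returns the typed lines as a list), a minimal sketch could look like the
# following -- the real implementation and its end-of-input convention may
# differ:
#
#     class MultiInputInCmd:
#         def __init__(self, prompt):
#             self.prompt = prompt
#
#         def run(self):
#             # Print the prompt once, then read lines until an empty
#             # line (assumed submit convention here) or EOF.
#             print(self.prompt, end='', flush=True)
#             lines = []
#             while True:
#                 try:
#                     line = input()
#                 except EOFError:
#                     break
#                 if line == '':
#                     break
#                 lines.append(line)
#             return lines
# ----------------------------------------------------------------------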