Spaces:
Runtime error
Runtime error
File size: 3,560 Bytes
23aa310 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
# Load the tokenizer.
# Paths are assembled with os.path.join so they resolve on both Windows and
# POSIX systems (hard-coded "\\" separators only work on Windows; on Windows
# the joined strings are identical to the old literals).
model_path = os.path.join("..", "models", "chatglm-6b-int4")
CHECKPOINT_PATH = os.path.join(".", "output", "adgen-chatglm-6b-pt-128-2e-2", "checkpoint-1000")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Loading a new-style checkpoint that only contains the PrefixEncoder
# parameters: build the base model with pre_seq_len=128, then load just the
# prefix-encoder weights from the checkpoint on top of it.
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
# map_location="cpu": load_state_dict below copies the tensors into the
# model's own parameters, so there is no need to restore them onto whatever
# device they were saved from (which may not exist on this machine).
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"), map_location="cpu")
# Keep only the prefix-encoder entries, with the "transformer.prefix_encoder."
# prefix stripped so the keys match the submodule's own state_dict.
prefix_key = "transformer.prefix_encoder."
new_prefix_state_dict = {
    k[len(prefix_key):]: v
    for k, v in prefix_state_dict.items()
    if k.startswith(prefix_key)
}
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

# Afterwards the model can be quantized as needed, or used directly.
kernel_file = os.path.join(model_path, "quantization_kernels.so")
model = model.quantize(bits=4, kernel_file=kernel_file)
model = model.half().cuda()
# Cast the prefix encoder back to float32 after the model-wide half() call
# (presumably to keep the trained prefix in full precision — confirm).
model.transformer.prefix_encoder.float()
model = model.eval()
# response, history = model.chat(tokenizer, "你好呀", history=[])
# print("response:", response)
def parse_text(text):
    """Convert a chat message into the HTML-style markup used for display.

    Empty lines are dropped. A line containing a triple-backtick fence
    alternately opens a code block, emitted as
    '<pre><code class="language-LANG">' (LANG is whatever follows the last
    backtick on that line), and closes it, emitted as '<br></code></pre>'.
    Every other line except the first is prefixed with '<br>'. Inside an open
    code block, each backtick is additionally preceded by a backslash.
    The resulting pieces are concatenated with no separator.

    Args:
        text: raw message text, possibly containing Markdown code fences.

    Returns:
        The converted string ("" for empty input).
    """
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                # Opening fence: the text after the last backtick names the language.
                language = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            if fence_count % 2 == 1:
                # NOTE(review): this branch previously also called replace() for
                # <, >, space, *, _, -, ., !, (, ) and $, but each one replaced the
                # character with itself (presumably HTML entities such as &lt;
                # that were decoded when the code was copied), so they were no-ops
                # and have been removed. Reinstate real escaping (e.g. html.escape)
                # if this output is ever rendered as HTML.
                line = line.replace("`", r"\`")
            lines[i] = "<br>" + line
    return "".join(lines)
def predict(input, chatbot, max_length, top_p, temperature, history):
    """Stream the model's reply to ``input``, updating ``chatbot`` as it grows.

    Appends ``(formatted_query, "")`` to ``chatbot``. Then, for every partial
    response yielded by ``model.stream_chat``, it replaces that last pair with
    the formatted query and the formatted partial response.

    Args:
        input: the user's message. It shadows the builtin, but the name is kept
            because callers pass it by keyword.
        chatbot: list of ``(query, response)`` display pairs; mutated in place.
        max_length: forwarded to ``model.stream_chat``.
        top_p: forwarded to ``model.stream_chat``.
        temperature: forwarded to ``model.stream_chat``.
        history: prior conversation, forwarded to ``model.stream_chat``.

    Yields:
        ``(chatbot, history)`` after each streamed chunk, where ``history`` is
        the value ``model.stream_chat`` produced alongside that chunk.
    """
    # The query does not change while the reply streams, so format it once.
    query_text = parse_text(input)
    chatbot.append((query_text, ""))
    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length,
                                               top_p=top_p, temperature=temperature):
        chatbot[-1] = (query_text, parse_text(response))
        yield chatbot, history
# Interactive terminal chat, bounded to 3000 turns.
response_new = ''
history = []
for _ in range(3000):
    # If the conversation gets too long, forget the oldest part of it: once
    # more than 5 history entries are stored (presumably (query, response)
    # pairs from model.stream_chat), drop the two oldest.
    if len(history) > 5:
        del history[:2]
    print('\033[1;31m{}\033[0m'.format('\nYou:'), end='')
    try:
        msg = input()
    except EOFError:
        # stdin reached end-of-file (closed, or no interactive terminal):
        # stop cleanly instead of exiting with a traceback.
        break
    print('\033[1;34m{}\033[0m'.format('ChatGLM:'), end='')
    # Each turn streams a brand-new response, so restart the diff from empty.
    response_new = ''
    # NOTE(review): predict() returns text formatted by parse_text, so markup
    # such as "<br>" is printed to the terminal verbatim.
    for chatbot, history in predict(input=msg, chatbot=[], max_length=10000, top_p=0.5,
                                    temperature=0.5, history=history):
        response_old = response_new
        response_new = chatbot[0][1]
        # Print only the text beyond what is already on screen. Unlike
        # str.replace(response_old, ''), this does not delete repeated
        # occurrences of the old text. It also degrades gracefully if
        # formatting rewrote earlier characters.
        already_shown = os.path.commonprefix([response_old, response_new])
        new_single = response_new[len(already_shown):]
        # end='' emits no newline, so flush or line-buffered stdout would
        # hold the streamed text back.
        print(new_single, end='', flush=True)
|