Spaces:
Sleeping
Sleeping
File size: 1,204 Bytes
8518918 cf5b520 8518918 cf5b520 8518918 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
import glob
import os
from transformers import BertTokenizerFast as BertTokenizer, BertForSequenceClassification
# Point HTTPS traffic from this process at a proxy on localhost port 1081.
os.environ['https_proxy'] = "127.0.0.1:1081"

# Class names of the classifier; a label's position in the list is its index.
LABEL_COLUMNS = [
    "Assertive Tone",
    "Conversational Tone",
    "Emotional Tone",
    "Informative Tone",
    "None",
]

# Base tokenizer plus a sequence-classification model with one output per label.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=len(LABEL_COLUMNS),
)

# Lookup tables between class index and label name, in list order.
id2label = dict(enumerate(LABEL_COLUMNS))
label2id = {name: index for index, name in id2label.items()}
# Convert every checkpoint matching checkpoints/*.ckpt into a directory
# models/<checkpoint name>/ containing pytorch_model.bin, config.json and the
# tokenizer vocabulary.
for ckpt in glob.glob('checkpoints/*.ckpt'):
    # Strip the directory part and the file extension to get the model name.
    model_name = os.path.splitext(os.path.basename(ckpt))[0]

    # NOTE(security): torch.load may unpickle arbitrary objects from the file;
    # only run this script on checkpoints from a trusted source.
    params = torch.load(ckpt, map_location="cpu")['state_dict']
    # strict=True: the checkpoint keys must match the model's keys exactly,
    # otherwise load_state_dict raises instead of loading a partial model.
    model.load_state_dict(params, strict=True)

    path = f'models/{model_name}'
    os.makedirs(path, exist_ok=True)
    torch.save(model.state_dict(), f'{path}/pytorch_model.bin')

    # Record the architecture and label mappings in the exported config.
    config = model.config
    config.architectures = ['BertForSequenceClassification']
    config.label2id = label2id
    config.id2label = id2label
    config.to_json_file(f'{path}/config.json')

    tokenizer.save_vocabulary(path)