File size: 1,150 Bytes
8518918
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import glob
import os
from transformers import BertTokenizerFast as BertTokenizer, BertForSequenceClassification

# Tone label names; a label's position in this list is used as its class index.
LABEL_COLUMNS = [
    "Assertive Tone",
    "Conversational Tone",
    "Emotional Tone",
    "Informative Tone",
]

# Tokenizer and 4-label sequence classifier, both built from bert-base-uncased.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=4,
)

# Two-way lookup tables between class index and label name.
id2label = dict(enumerate(LABEL_COLUMNS))
label2id = {name: idx for idx, name in id2label.items()}

# Convert every checkpoint under checkpoints/ into a models/<name>/ directory
# holding pytorch_model.bin, config.json and the tokenizer vocabulary.

# The architecture name and label mappings are identical for every checkpoint,
# so attach them to the shared config once instead of on each iteration.
config = model.config
config.architectures = ['BertForSequenceClassification']
config.label2id = label2id
config.id2label = id2label

# Sorted so checkpoints are processed in a reproducible order.
for ckpt in sorted(glob.glob('checkpoints/*.ckpt')):
    # The checkpoint file name without its extension names the output directory.
    model_name = os.path.splitext(os.path.basename(ckpt))[0]
    # NOTE(security): torch.load deserializes pickle data, which can execute
    # arbitrary code -- only run this on checkpoints from a trusted source.
    params = torch.load(ckpt, map_location="cpu")['state_dict']
    # strict=True raises on missing/unexpected keys, so the returned
    # key-matching report does not need to be kept.
    model.load_state_dict(params, strict=True)
    path = f'models/{model_name}'
    os.makedirs(path, exist_ok=True)

    torch.save(model.state_dict(), f'{path}/pytorch_model.bin')
    config.to_json_file(f'{path}/config.json')
    tokenizer.save_vocabulary(path)