import pickle
import os
from shared import CustomTokens
from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        default='google/t5-v1_1-small',  # t5-small
        metadata={
            'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
    )
    # config_name: Optional[str] = field( # TODO remove?
    #     default=None, metadata={'help': 'Pretrained config name or path if not the same as model_name'}
    # )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={'help': 'Pretrained tokenizer name or path if not the same as model_name'}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            'help': 'Where to store the pretrained models downloaded from huggingface.co'},
    )
    use_fast_tokenizer: bool = field(  # TODO remove?
        default=True,
        metadata={
            'help': 'Whether to use one of the fast tokenizers (backed by the tokenizers library) or not.'},
    )
    model_revision: str = field(  # TODO remove?
        default='main',
        metadata={
            'help': 'The specific model version to use (can be a branch name, tag name or commit id).'},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            'help': 'Will use the token generated when running `transformers-cli login` (necessary to use this script '
            'with private models).'
        },
    )
    resize_position_embeddings: Optional[bool] = field(
        default=None,
        metadata={
            'help': "Whether to automatically resize the position embeddings if `max_source_length` exceeds the model's position embeddings."
        },
    )
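
# Typical usage (a sketch; not part of the original file): these arguments are
# normally populated from the command line via transformers.HfArgumentParser,
# e.g. `model_args, = HfArgumentParser(ModelArguments).parse_args_into_dataclasses()`.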


def get_model(model_args, use_cache=True):
    """Load the seq2seq model, preferring a locally saved checkpoint under models/."""
    name = model_args.model_name_or_path
    cached_path = f'models/{name}'

    # Prefer the local checkpoint (saved after the tokenizer was extended) if one exists:
    if use_cache and os.path.exists(os.path.join(cached_path, 'pytorch_model.bin')):
        name = cached_path

    config = AutoConfig.from_pretrained(
        name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    model = AutoModelForSeq2SeqLM.from_pretrained(
        name,
        from_tf='.ckpt' in name,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    return model


def get_tokenizer(model_args, use_cache=True):
    """Load the tokenizer, preferring a locally saved copy, and register the custom tokens."""
    name = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path

    cached_path = f'models/{name}'

    if use_cache and os.path.exists(os.path.join(cached_path, 'tokenizer.json')):
        name = cached_path

    tokenizer = AutoTokenizer.from_pretrained(
        name,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Register the project's special tokens (defined in shared.CustomTokens).
    CustomTokens.add_custom_tokens(tokenizer)

    return tokenizer


# Module-level cache so repeated calls do not re-unpickle the same files.
CLASSIFIER_CACHE = {}


def get_classifier_vectorizer(classifier_args, use_cache=True):
    """Load the pickled classifier and vectorizer, memoizing them by file path."""
    classifier_path = os.path.join(classifier_args.classifier_dir, classifier_args.classifier_file)
    if use_cache and classifier_path in CLASSIFIER_CACHE:
        classifier = CLASSIFIER_CACHE[classifier_path]
    else:
        with open(classifier_path, 'rb') as fp:
            classifier = CLASSIFIER_CACHE[classifier_path] = pickle.load(fp)

    vectorizer_path = os.path.join(classifier_args.classifier_dir, classifier_args.vectorizer_file)
    if use_cache and vectorizer_path in CLASSIFIER_CACHE:
        vectorizer = CLASSIFIER_CACHE[vectorizer_path]
    else:
        with open(vectorizer_path, 'rb') as fp:
            vectorizer = CLASSIFIER_CACHE[vectorizer_path] = pickle.load(fp)

    return classifier, vectorizer
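

if __name__ == '__main__':
    # Minimal smoke test (a sketch; assumes network access to huggingface.co
    # and that `shared.CustomTokens` is importable). It loads the default
    # google/t5-v1_1-small checkpoint, registers the custom tokens, and
    # resizes the model's embedding matrix to match the extended vocabulary.
    args = ModelArguments()
    tokenizer = get_tokenizer(args)
    model = get_model(args)
    model.resize_token_embeddings(len(tokenizer))
    print(type(model).__name__, 'vocab size:', len(tokenizer))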