import os
import pickle
from dataclasses import dataclass, field
from typing import Optional

from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer

from shared import CustomTokens


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        default='google/t5-v1_1-small',  # t5-small
        metadata={
            'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
    )
    # config_name: Optional[str] = field( # TODO remove?
    #     default=None, metadata={'help': 'Pretrained config name or path if not the same as model_name'}
    # )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={'help': 'Pretrained tokenizer name or path if not the same as model_name'}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            'help': 'Where to store the pretrained models downloaded from huggingface.co'},
    )
    use_fast_tokenizer: bool = field(  # TODO remove?
        default=True,
        metadata={
            'help': 'Whether to use one of the fast tokenizers (backed by the tokenizers library) or not.'},
    )
    model_revision: str = field(  # TODO remove?
        default='main',
        metadata={
            'help': 'The specific model version to use (can be a branch name, tag name or commit id).'},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            'help': 'Will use the token generated when running `transformers-cli login` (necessary to use this script '
            'with private models).'
        },
    )
    resize_position_embeddings: Optional[bool] = field(
        default=None,
        metadata={
            'help': "Whether to automatically resize the position embeddings if `max_source_length` exceeds the model's position embeddings."
        },
    )
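
# Usage sketch (an assumption, not spelled out in this file): these fields are
# typically populated from the command line via transformers.HfArgumentParser,
# e.g.:
#     parser = HfArgumentParser(ModelArguments)
#     (model_args,) = parser.parse_args_into_dataclasses()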


def get_model(model_args, use_cache=True):
    name = model_args.model_name_or_path
    cached_path = f'models/{name}'

    # The model is saved after the tokenizer, so the presence of the weights
    # file means the cached copy is complete and can be loaded instead:
    if use_cache and os.path.exists(os.path.join(cached_path, 'pytorch_model.bin')):
        name = cached_path

    config = AutoConfig.from_pretrained(
        name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    model = AutoModelForSeq2SeqLM.from_pretrained(
        name,
        from_tf='.ckpt' in name,  # load from a TensorFlow checkpoint if the path names one
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    return model


def get_tokenizer(model_args, use_cache=True):
    name = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path

    cached_path = f'models/{name}'

    # Prefer a locally saved copy of the tokenizer if one exists:
    if use_cache and os.path.exists(os.path.join(cached_path, 'tokenizer.json')):
        name = cached_path

    tokenizer = AutoTokenizer.from_pretrained(
        name,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Register the project's custom tokens (defined in shared) with the tokenizer.
    CustomTokens.add_custom_tokens(tokenizer)

    return tokenizer


def get_classifier_vectorizer(classifier_args):
    """Load the pickled classifier and its vectorizer from disk."""
    with open(os.path.join(classifier_args.classifier_dir, classifier_args.classifier_file), 'rb') as fp:
        classifier = pickle.load(fp)

    with open(os.path.join(classifier_args.classifier_dir, classifier_args.vectorizer_file), 'rb') as fp:
        vectorizer = pickle.load(fp)

    return classifier, vectorizer
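

if __name__ == '__main__':
    # Illustrative usage sketch (an assumption; not part of the original module).
    # Load the tokenizer first, then the model, and resize the embedding matrix
    # so it covers the custom tokens added in get_tokenizer(). Uses the default
    # ModelArguments values ('google/t5-v1_1-small').
    model_args = ModelArguments()
    tokenizer = get_tokenizer(model_args)
    model = get_model(model_args)
    model.resize_token_embeddings(len(tokenizer))
    print(f'Loaded {model_args.model_name_or_path} with a vocabulary of {len(tokenizer)} tokens')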