File size: 2,595 Bytes
5fbdd3c
 
 
 
 
 
 
 
 
 
90d1f68
 
5fbdd3c
 
 
 
 
90d1f68
 
 
5fbdd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gc
from time import time_ns
import random
import numpy as np
import torch
from typing import Optional
from dataclasses import dataclass, field
from enum import Enum

class CustomTokens(Enum):
    """Special token strings that can be registered with a tokenizer."""

    EXTRACT_SEGMENTS_PREFIX = 'EXTRACT_SEGMENTS: '

    URL = 'URL_TOKEN'
    HYPHENATED_URL = 'HYPHENATED_URL_TOKEN'
    NUMBER_PERCENTAGE = 'NUMBER_PERCENTAGE_TOKEN'
    NUMBER = 'NUMBER_TOKEN'

    START_SEGMENT = 'START_SEGMENT_TOKEN'
    END_SEGMENT = 'END_SEGMENT_TOKEN'
    NO_SEGMENT = 'NO_SEGMENT_FOUND'

    SHORT_HYPHENATED = 'SHORT_HYPHENATED_TOKEN'
    LONG_WORD = 'LONG_WORD_TOKEN'

    # Custom YouTube tokens
    MUSIC = '[Music]'
    APPLAUSE = '[Applause]'
    LAUGHTER = '[Laughter]'

    PROFANITY = 'PROFANITY_TOKEN'

    @classmethod
    def custom_tokens(cls):
        """Return the string value of every member, in definition order."""
        values = []
        for member in cls:
            values.append(member.value)
        return values

    @classmethod
    def add_custom_tokens(cls, tokenizer):
        """Pass all custom token strings to `tokenizer.add_tokens`."""
        all_tokens = cls.custom_tokens()
        tokenizer.add_tokens(all_tokens)


@dataclass
class OutputArguments:
    """Locations used for outputs, checkpoints and models.

    Attributes:
        output_dir: Directory for model predictions and checkpoints.
        checkpoint: Checkpoint/model to train from or test with; None by
            default.
        models_dir: Directory holding models (default 'models').
    """

    output_dir: str = field(
        default='out',
        metadata={
            'help': 'The output directory where the model predictions and checkpoints will be written to and read from.'
        },
    )
    checkpoint: Optional[str] = field(
        default=None,
        metadata={
            'help': 'Choose the checkpoint/model to train from or test with. Defaults to the latest checkpoint found in `output_dir`.'
        },
    )
    # The help text here used to be a verbatim copy of the `output_dir` help,
    # which made the two options indistinguishable in generated usage output.
    models_dir: str = field(
        default='models',
        metadata={
            'help': 'The directory where models are stored.'
        },
    )
    # classifier_dir: str = field(
    #     default='out',
    #     metadata={
    #         'help': 'The output directory where the model predictions and checkpoints will be written to and read from.'
    #     },
    # )


def seed_factory():
    """Return the current time in nanoseconds reduced modulo 2**32 - 1."""
    modulus = (1 << 32) - 1
    now_ns = time_ns()
    return now_ns % modulus


@dataclass
class GeneralArguments:
    """General run options; seeds the random number generators on creation.

    Attributes:
        seed: Seed handed to `random`, NumPy and PyTorch. When not given,
            `seed_factory` supplies a time-based value. An explicit None is
            replaced by a freshly generated seed as well.
    """
    seed: Optional[int] = field(default_factory=seed_factory, metadata={
        'help': 'Set seed for deterministic training and testing. By default, it uses the current time (results in essentially random results).'
    })

    def __post_init__(self):
        """Seed Python's `random`, NumPy and PyTorch (CPU and CUDA)."""
        # The annotation allows None, but `torch.manual_seed` needs an
        # integer. Fall back to a time-based seed and store it on the
        # instance so the value actually used can be inspected afterwards.
        if self.seed is None:
            self.seed = seed_factory()

        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)


def device():
    """Return the 'cuda' torch device when CUDA is available, else 'cpu'."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


def reset():
    """Release cached memory and report CUDA memory usage when available.

    Clears the autocast cache and runs the garbage collector. When CUDA is
    available, it also empties the CUDA caching allocator and prints the CUDA
    memory summary. Without CUDA the CUDA-specific steps are skipped, so the
    function can be called safely on CPU-only machines.
    """
    torch.clear_autocast_cache()
    # Collect garbage first: memory owned by objects that are only freed by
    # the collector can then be released by `empty_cache` below.
    gc.collect()

    # `torch.cuda.memory_summary` requires a usable CUDA runtime, so the
    # CUDA-specific steps run only when one is present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(torch.cuda.memory_summary(device=None, abbreviated=False))