Evander1 committed
Commit 7d92072 · verified · 1 Parent(s): 0fbbefc

Upload 2 files


Here is the prediction source
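For context, a minimal sketch of how the uploaded prediction entry point might be driven (illustrative only: the file name `my_clip.wav` is an assumption, and all `HfArgumentParser` arguments fall back to the defaults defined in predict.py, including loading the checkpoint and processor from a `local_model` directory):

```python
# Hypothetical driver for the uploaded predict.py; not part of the commit.
# Assumes the fine-tuned checkpoint and processor files live in "local_model"
# and that the audio file below exists.
from predict import exeute_angry_predict

# Loads the clip, runs the joint CTC + classification model, and prints a dict like
# results {'my_clip.wav': '愤怒'}  (i.e. "angry")
exeute_angry_predict("my_clip.wav")
```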

Files changed (2)
  1. model.py +111 -0
  2. predict.py +412 -0
model.py ADDED
@@ -0,0 +1,111 @@
+ import warnings
+ from typing import Optional, Tuple
+ 
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from transformers import Wav2Vec2Model, Wav2Vec2PreTrainedModel
+ from transformers.modeling_outputs import CausalLMOutput
+ from torch import nn
+ 
+ 
+ class Wav2Vec2ForCTCnCLS(Wav2Vec2PreTrainedModel):
+ 
+     def __init__(self, config, cls_len=2, alpha=0.01):
+         super().__init__(config)
+         self.wav2vec2 = Wav2Vec2Model(config)
+         self.dropout = nn.Dropout(config.final_dropout)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
+         self.cls_head = nn.Linear(config.hidden_size, cls_len)
+         self.init_weights()
+         self.alpha = alpha
+ 
+     def freeze_feature_extractor(self):
+         self.wav2vec2.feature_extractor._freeze_parameters()
+ 
+     def _ctc_loss(self, logits, labels, input_values, attention_mask=None):
+         loss = None
+         if labels is not None:
+             # retrieve loss input_lengths from attention_mask
+             attention_mask = (
+                 attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+             )
+             input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
+ 
+             # assuming that padded tokens are filled with -100
+             # when not being attended to
+             labels_mask = labels >= 0
+             target_lengths = labels_mask.sum(-1)
+             flattened_targets = labels.masked_select(labels_mask)
+ 
+             log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)
+ 
+             with torch.backends.cudnn.flags(enabled=False):
+                 loss = F.ctc_loss(
+                     log_probs,
+                     flattened_targets,
+                     input_lengths,
+                     target_lengths,
+                     blank=self.config.pad_token_id,
+                     reduction=self.config.ctc_loss_reduction,
+                     zero_infinity=self.config.ctc_zero_infinity,
+                 )
+ 
+         return loss
+ 
+     def _cls_loss(self, logits, cls_labels):  # logits come from mean-pooling hidden_states over dim 1 (the sequence length) and feeding them through self.cls_head
+         loss = None
+         if cls_labels is not None:
+             loss = F.cross_entropy(logits, cls_labels.to(logits.device))
+         return loss
+ 
+     def forward(
+         self,
+         input_values,
+         attention_mask=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         labels=None,  # tuple: (ctc_labels, cls_labels), shape=(batch_size, target_length)
+         if_ctc=True,
+         if_cls=True,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         outputs = self.wav2vec2(
+             input_values,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+ 
+         hidden_states = outputs[0]  # the last layer's hidden states
+         hidden_states = self.dropout(hidden_states)
+ 
+         logits_ctc = self.lm_head(hidden_states)
+         logits_cls = self.cls_head(torch.mean(hidden_states, dim=1))
+ 
+         loss = None
+         if labels is not None:
+             # guard against undefined terms when only one of the two heads is used
+             loss_ctc = self._ctc_loss(logits_ctc, labels[0], input_values, attention_mask) if if_ctc else 0.0
+             loss_cls = self._cls_loss(logits_cls, labels[1]) if if_cls else 0.0
+             loss = loss_cls + self.alpha * loss_ctc
+ 
+         # if not return_dict:
+         #     output = (logits,) + outputs[1:]
+         #     return ((loss,) + output) if loss is not None else output
+ 
+         return CausalLMOutput(
+             loss=loss, logits=(logits_ctc, logits_cls), hidden_states=outputs.hidden_states, attentions=outputs.attentions
+         )
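As a quick orientation to the class above, here is a minimal, hypothetical usage sketch (not part of the commit; it assumes a compatible fine-tuned checkpoint exists at `local_model` and feeds random samples in place of real audio). The model returns a pair of logits, `(logits_ctc, logits_cls)`, and when a `(ctc_labels, cls_labels)` tuple is supplied it combines the two losses as `loss_cls + alpha * loss_ctc`.

```python
# Illustrative only: "local_model" is an assumed checkpoint path, and the input
# below is one second of random 16 kHz "audio" rather than a real waveform.
import torch
from model import Wav2Vec2ForCTCnCLS

model = Wav2Vec2ForCTCnCLS.from_pretrained("local_model", cls_len=2, alpha=0.01)
model.eval()

waveform = torch.randn(1, 16000)
with torch.no_grad():
    out = model(waveform)               # no labels passed, so out.loss is None
logits_ctc, logits_cls = out.logits     # (batch, frames, vocab_size), (batch, cls_len)
pred_cls = logits_cls.argmax(dim=-1)    # 0 = neutral, 1 = angry in predict.py's label map
```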
predict.py ADDED
@@ -0,0 +1,412 @@
+ #!/usr/bin/env python3
+ import logging
+ import pathlib
+ import re
+ import sys
+ import time
+ import csv
+ from dataclasses import dataclass, field
+ from typing import Any, Callable, Dict, List, Optional, Set, Union
+ 
+ import datasets
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from packaging import version
+ from torch.cuda.amp import GradScaler, autocast
+ 
+ import librosa
+ from lang_trans import arabic
+ from datasets import Dataset
+ 
+ import soundfile as sf
+ from model import Wav2Vec2ForCTCnCLS
+ from transformers.trainer_utils import get_last_checkpoint
+ 
+ from transformers import (
+     HfArgumentParser,
+     Trainer,
+     TrainingArguments,
+     Wav2Vec2CTCTokenizer,
+     Wav2Vec2FeatureExtractor,
+     Wav2Vec2Processor,
+     is_apex_available,
+     trainer_utils,
+ )
+ 
+ 
+ local_model_path = "local_model"
+ 
+ if is_apex_available():
+     from apex import amp
+ 
+ if version.parse(torch.__version__) >= version.parse("1.6"):
+     _is_native_amp_available = True
+     from torch.cuda.amp import autocast
+ 
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ @dataclass
+ class ModelArguments:
+     """
+     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+     """
+ 
+     model_name_or_path: str = field(
+         default="local_model",
+         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
+     )
+     cache_dir: Optional[str] = field(
+         default=None,
+         metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
+     )
+     freeze_feature_extractor: Optional[bool] = field(
+         default=False, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
+     )
+     verbose_logging: Optional[bool] = field(
+         default=False,
+         metadata={"help": "Whether to log verbose messages or not."},
+     )
+     tokenizer: Optional[str] = field(
+         default="checkpoint-33000",
+         metadata={"help": "Path to pretrained tokenizer"},
+     )
+ 
+ 
+ def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         handlers=[logging.StreamHandler(sys.stdout)],
+     )
+     logging_level = logging.WARNING
+     if model_args.verbose_logging:
+         logging_level = logging.DEBUG
+     elif trainer_utils.is_main_process(training_args.local_rank):
+         logging_level = logging.INFO
+     logger.setLevel(logging_level)
+ 
+ 
+ @dataclass
+ class DataTrainingArguments:
+     """
+     Arguments pertaining to what data we are going to input our model for training and eval.
+ 
+     Using `HfArgumentParser` we can turn this class
+     into argparse arguments to be able to specify them on
+     the command line.
+     """
+ 
+     dataset_name: str = field(
+         default="emotion", metadata={"help": "The name of the dataset to use (via the datasets library)."}
+     )
+     dataset_config_name: Optional[str] = field(
+         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+     )
+     train_split_name: Optional[str] = field(
+         default="train",
+         metadata={
+             "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
+         },
+     )
+     validation_split_name: Optional[str] = field(
+         default="validation",
+         metadata={
+             "help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
+         },
+     )
+     target_text_column: Optional[str] = field(
+         default="text",
+         metadata={"help": "Column in the dataset that contains label (target text). Defaults to 'text'"},
+     )
+     speech_file_column: Optional[str] = field(
+         default="file",
+         metadata={"help": "Column in the dataset that contains speech file path. Defaults to 'file'"},
+     )
+     target_feature_extractor_sampling_rate: Optional[bool] = field(
+         default=False,
+         metadata={"help": "Resample loaded audio to target feature extractor's sampling rate or not."},
+     )
+     max_duration_in_seconds: Optional[float] = field(
+         default=None,
+         metadata={"help": "Filters out examples longer than specified. Defaults to no filtering."},
+     )
+     orthography: Optional[str] = field(
+         default="librispeech",
+         metadata={
+             "help": "Orthography used for normalization and tokenization: 'librispeech' (default), 'timit', or 'buckwalter'."
+         },
+     )
+     overwrite_cache: bool = field(
+         default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
+     )
+     preprocessing_num_workers: Optional[int] = field(
+         default=8,
+         metadata={"help": "The number of processes to use for the preprocessing."},
+     )
+     output_file: Optional[str] = field(
+         default=None,
+         metadata={"help": "Output file."},
+     )
+ 
+ 
+ @dataclass
+ class Orthography:
+     """
+     Orthography scheme used for text normalization and tokenization.
+ 
+     Args:
+         do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
+             Whether or not to accept lowercase input and lowercase the output when decoding.
+         vocab_file (:obj:`str`, `optional`, defaults to :obj:`None`):
+             File containing the vocabulary.
+         word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`"|"`):
+             The token used for delimiting words; it needs to be in the vocabulary.
+         translation_table (:obj:`Dict[str, str]`, `optional`, defaults to :obj:`{}`):
+             Table to use with `str.translate()` when preprocessing text (e.g., "-" -> " ").
+         words_to_remove (:obj:`Set[str]`, `optional`, defaults to :obj:`set()`):
+             Words to remove when preprocessing text (e.g., "sil").
+         untransliterator (:obj:`Callable[[str], str]`, `optional`, defaults to :obj:`None`):
+             Function that untransliterates text back into native writing system.
+         tokenizer (:obj:`str`, `optional`, defaults to :obj:`None`):
+             Tokenizer type, e.g., 'jieba' for Chinese.
+     """
+ 
+     do_lower_case: bool = False
+     vocab_file: Optional[str] = None
+     word_delimiter_token: Optional[str] = "|"
+     translation_table: Optional[Dict[str, str]] = field(default_factory=dict)
+     words_to_remove: Optional[Set[str]] = field(default_factory=set)
+     tokenizer: Optional[str] = None
+     untransliterator: Optional[Callable[[str], str]] = None
+ 
+     @classmethod
+     def from_name(cls, name: str):
+         if name == "librispeech":
+             return cls()
+         else:
+             raise ValueError(f"Unsupported orthography: '{name}'.")
+ 
+     def create_processor(self, model_args: ModelArguments) -> Wav2Vec2Processor:
+         feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+             local_model_path, cache_dir=model_args.cache_dir
+         )
+         if self.vocab_file:
+             tokenizer = Wav2Vec2CTCTokenizer(
+                 self.vocab_file,
+                 cache_dir=model_args.cache_dir,
+                 do_lower_case=self.do_lower_case,
+                 word_delimiter_token=self.word_delimiter_token,
+             )
+         else:
+             tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
+                 local_model_path,
+                 # self.tokenizer,
+                 cache_dir=model_args.cache_dir,
+                 do_lower_case=self.do_lower_case,
+                 word_delimiter_token=self.word_delimiter_token,
+             )
+         return Wav2Vec2Processor(feature_extractor, tokenizer)
+ 
+ 
+ @dataclass
+ class TrainingArguments(TrainingArguments):
+     output_dir: str = field(
+         default="output/angry_tmp", metadata={"help": "Where to store the output."}
+     )
+     do_predict: bool = field(
+         default=True, metadata={"help": "Whether to run prediction."}
+     )
+     do_eval: bool = field(
+         default=False, metadata={"help": "Whether to run evaluation."}
+     )
+     overwrite_output_dir: bool = field(
+         default=True, metadata={"help": "Overwrite the content of the output directory."}
+     )
+     per_device_eval_batch_size: int = field(
+         default=2, metadata={"help": "Batch size per device for evaluation."}
+     )
+     warmup_ratio: float = field(
+         default=0.1, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
+     )
+ 
+ 
+ @dataclass
+ class DataCollatorCTCWithPadding:
+     """
+     Data collator that will dynamically pad the inputs received.
+     Args:
+         processor (:class:`~transformers.Wav2Vec2Processor`)
+             The processor used for processing the data.
+         padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
+             Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+             among:
+             * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+               sequence is provided).
+             * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+               maximum acceptable input length for the model if that argument is not provided.
+             * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
+               different lengths).
+         max_length (:obj:`int`, `optional`):
+             Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
+         max_length_labels (:obj:`int`, `optional`):
+             Maximum length of the ``labels`` returned list and optionally padding length (see above).
+         pad_to_multiple_of (:obj:`int`, `optional`):
+             If set will pad the sequence to a multiple of the provided value.
+             This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+             7.5 (Volta).
+     """
+ 
+     processor: Wav2Vec2Processor
+     padding: Union[bool, str] = True
+     max_length: Optional[int] = None
+     max_length_labels: Optional[int] = None
+     pad_to_multiple_of: Optional[int] = None
+     pad_to_multiple_of_labels: Optional[int] = None
+     audio_only: bool = False
+     duration: int = 6
+     sample_rate: int = 16000
+ 
+     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+         # split inputs and labels since they have to be of different lengths and need
+         # different padding methods
+         input_features = [{"input_values": feature["input_values"]} for feature in features]
+ 
+         batch = self.processor.pad(
+             input_features,
+             padding=self.padding,
+             # max_length=self.max_length,
+             max_length=self.duration * self.sample_rate,
+             pad_to_multiple_of=self.pad_to_multiple_of,
+             return_tensors="pt",
+         )
+ 
+         return batch
+ 
+ 
+ class CTCTrainer(Trainer):
+     def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
+         self.use_amp = False
+         self.use_apex = False
+         self.deepspeed = False
+         self.scaler = GradScaler()
+         for k, v in inputs.items():
+             if isinstance(v, torch.Tensor):
+                 kwargs = dict(device=self.args.device)
+                 if self.deepspeed and inputs[k].dtype != torch.int64:
+                     kwargs.update(dict(dtype=self.args.hf_deepspeed_config.dtype()))
+                 inputs[k] = v.to(**kwargs)
+ 
+         if self.args.past_index >= 0 and self._past is not None:
+             inputs["mems"] = self._past
+ 
+         return inputs
+ 
+ 
+ def create_dataset(audio_path):
+     data = {
+         "file": [audio_path]
+     }
+     dataset = Dataset.from_dict(data)
+     return dataset
+ 
+ 
+ def exeute_angry_predict(audio_path):
+     # See all possible arguments in src/transformers/training_args.py
+     # or by passing the --help flag to this script.
+     # We now keep distinct sets of args, for a cleaner separation of concerns.
+ 
+     target_sr = 16000
+ 
+     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+     configure_logger(model_args, training_args)
+ 
+     orthography = Orthography.from_name(data_args.orthography.lower())
+     orthography.tokenizer = model_args.tokenizer
+     processor = orthography.create_processor(model_args)
+ 
+     if data_args.dataset_name == "emotion":
+         val_dataset = create_dataset(audio_path)
+         cls_label_map = {"neutral": 0, "angry": 1}
+ 
+     model = Wav2Vec2ForCTCnCLS.from_pretrained(
+         local_model_path,
+         gradient_checkpointing=True,  # training_args.gradient_checkpointing,
+         cls_len=len(cls_label_map),
+     )
+ 
+     def prepare_example(example, audio_only=False):  # TODO(elgeish) make use of multiprocessing?
+         example["speech"], example["sampling_rate"] = librosa.load(example[data_args.speech_file_column], sr=target_sr)
+         orig_sample_rate = example["sampling_rate"]
+         target_sample_rate = target_sr
+         if orig_sample_rate != target_sample_rate:
+             example["speech"] = librosa.resample(example["speech"], orig_sr=orig_sample_rate, target_sr=target_sample_rate)
+         if data_args.max_duration_in_seconds is not None:
+             example["duration_in_seconds"] = len(example["speech"]) / example["sampling_rate"]
+         return example
+ 
+     if training_args.do_predict:
+         val_dataset = val_dataset.map(prepare_example, fn_kwargs={"audio_only": True})
+ 
+     def prepare_dataset(batch, audio_only=False):
+         # check that all files have the correct sampling rate
+         assert (
+             len(set(batch["sampling_rate"])) == 1
+         ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
+ 
+         batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
+         return batch
+ 
+     if training_args.do_predict:
+         val_dataset = val_dataset.map(
+             prepare_dataset,
+             fn_kwargs={"audio_only": True},
+             batch_size=training_args.per_device_eval_batch_size,
+             batched=True,
+             num_proc=data_args.preprocessing_num_workers,
+         )
+ 
+     data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
+ 
+     if model_args.freeze_feature_extractor:
+         model.freeze_feature_extractor()
+ 
+     trainer = CTCTrainer(
+         model=model,
+         args=training_args,
+         data_collator=data_collator,  # built above; without this the Trainer would fall back to its default collator
+         eval_dataset=val_dataset,
+         tokenizer=processor.feature_extractor,
+     )
+ 
+     if training_args.do_predict:
+         logger.info("******* Predict ********")
+         data_collator.audio_only = True
+         results = {}
+         result = ""
+         predictions, labels, metrics = trainer.predict(val_dataset, metric_key_prefix="predict")
+         logits_ctc, logits_cls = predictions
+         pred_ids = np.argmax(logits_cls, axis=-1)
+         pred_id = int(pred_ids[0])  # single-clip dataset, so take the first prediction
+         if pred_id == 0:
+             result = "非愤怒"  # "not angry"
+         elif pred_id == 1:
+             result = "愤怒"  # "angry"
+         results[audio_path] = result
+         print("results", results)
+ 
+ 
+ if __name__ == "__main__":
+     audio_path = "audio.mp3"
+     exeute_angry_predict(audio_path)