hr16 committed (verified)
Commit bdba3d8 · 1 parent: 1eaeec3

Create modeling_whisper.py

Files changed (1)
  1. modeling_whisper.py +161 -0
modeling_whisper.py ADDED
@@ -0,0 +1,161 @@
+ import math
+ from typing import Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+
+ from transformers.modeling_outputs import SequenceClassifierOutput
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import (
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     logging,
+     replace_return_docstrings,
+ )
+ from transformers.models.whisper.configuration_whisper import WhisperConfig
+ from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
+ from transformers.models.whisper.modeling_whisper import (
+     WhisperPreTrainedModel,
+     WhisperEncoder,
+     WHISPER_ENCODER_INPUTS_DOCSTRING,
+     _CONFIG_FOR_DOC,
+     _HIDDEN_STATES_START_POSITION,  # used by the weighted-layer-sum branch in forward()
+ )
+
+ class ViSpeechClassification(WhisperPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.encoder = WhisperEncoder(config)
+         # Transformer layers + input embeddings; only relevant for weighted
+         # layer summing, which is forced off below.
+         num_layers = config.num_hidden_layers + 1
+         # Three-stage MLP that projects encoder hidden states down to the
+         # classifier projection size.
+         self.projector = nn.Sequential(
+             nn.Linear(self.encoder.config.hidden_size, 1024),
+             nn.ReLU(),
+             nn.Dropout(config.dropout),
+             nn.Linear(1024, 512),
+             nn.ReLU(),
+             nn.Dropout(config.dropout),
+             nn.Linear(512, config.classifier_proj_size),
+             nn.ReLU(),
+             nn.Dropout(config.dropout),
+         )
+
+         self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+         # This head never defines layer_weights, so weighted layer summing stays disabled.
+         self.config.use_weighted_layer_sum = False
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def freeze_encoder(self):
+         """
+         Calling this function will disable the gradient computation for the Whisper encoder so that its
+         parameters will not be updated during training. Only the projection layers and classification
+         head will be updated.
+         """
+         self.encoder._freeze_parameters()
+
+     def get_input_embeddings(self) -> nn.Module:
+         return self.encoder.get_input_embeddings()
+
+     def set_input_embeddings(self, value: nn.Module):
+         self.encoder.set_input_embeddings(value)
+
+     @add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+     def forward(
+         self,
+         input_features: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss); if
+             `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+
+         Returns:
+
+         Example:
+
+         ```python
+         >>> import torch
+         >>> from transformers import AutoFeatureExtractor, WhisperForAudioClassification
+         >>> from datasets import load_dataset
+
+         >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
+         >>> model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
+
+         >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
+         >>> sample = next(iter(ds))
+
+         >>> inputs = feature_extractor(
+         ...     sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt"
+         ... )
+         >>> input_features = inputs.input_features
+
+         >>> with torch.no_grad():
+         ...     logits = model(input_features).logits
+
+         >>> predicted_class_ids = torch.argmax(logits).item()
+         >>> predicted_label = model.config.id2label[predicted_class_ids]
+         >>> predicted_label
+         'Afrikaans'
+         ```"""
+
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         # Weighted layer summing needs every layer's hidden states, so force them on when enabled.
+         if self.config.use_weighted_layer_sum:
+             output_hidden_states = True
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if encoder_outputs is None:
+             encoder_outputs = self.encoder(
+                 input_features,
+                 head_mask=head_mask,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+
+         if self.config.use_weighted_layer_sum:
+             # Unreachable in practice: __init__ forces use_weighted_layer_sum to False,
+             # and self.layer_weights is never defined on this class.
+             hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
+             hidden_states = torch.stack(hidden_states, dim=1)
+             norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+             hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+         else:
+             hidden_states = encoder_outputs[0]
+
+         # Project encoder states, then mean-pool over the time dimension.
+         hidden_states = self.projector(hidden_states)
+         pooled_output = hidden_states.mean(dim=1)
+
+         logits = self.classifier(pooled_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             # move labels to correct device to enable PP (pipeline parallelism)
+             labels = labels.to(logits.device)
+             loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + encoder_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
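
For reference, a minimal usage sketch of this head follows. It is not part of the commit: the backbone checkpoint (`openai/whisper-base`), the label count, and the dummy audio are illustrative assumptions, and it assumes `modeling_whisper.py` is importable from the working directory. `classifier_proj_size`, `dropout`, and `num_labels` are existing `WhisperConfig`/`PretrainedConfig` fields.

```python
import torch
from transformers import AutoFeatureExtractor, WhisperConfig
from modeling_whisper import ViSpeechClassification

# Hypothetical backbone and label count, chosen only for illustration.
config = WhisperConfig.from_pretrained("openai/whisper-base")
config.num_labels = 4              # e.g. four target speech classes
config.classifier_proj_size = 256  # width consumed by the classifier head

model = ViSpeechClassification(config)
model.freeze_encoder()             # train only the projector and classifier
model.eval()

feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
waveform = torch.randn(16000).numpy()  # one second of dummy 16 kHz audio
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    logits = model(inputs.input_features).logits  # shape: (1, config.num_labels)
predicted_class_id = logits.argmax(-1).item()
```

In real use the weights would come from a fine-tuned checkpoint, e.g. via `ViSpeechClassification.from_pretrained(...)`, rather than the random initialization shown here.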