saeedbenadeeb committed
Commit 5fc7eb1 · 1 Parent(s): 6ceb71f

Lora Model Uploaded

Files changed (5):
  1. app.py +10 -2
  2. encoders/transformer.py +27 -1
  3. lora_only_model.pth +3 -0
  4. models/__init__.py +2 -1
  5. models/lora.py +24 -0
app.py CHANGED
@@ -12,7 +12,7 @@ emotions = ["happy", "sad", "angry", "neutral", "fear", "disgust", "surprise"]
 label_mapping = {str(idx): emotion for idx, emotion in enumerate(emotions)}
 
 # Load the trained model
-model_path = "model.pth"
+model_path = "lora_only_model.pth"
 cfg = {
     "model": {
         "encoder": "Wav2Vec2Classifier",
@@ -25,9 +25,17 @@ cfg = {
     }
 }
 model = Wav2Vec2EmotionClassifier(num_classes=len(emotions), optimizer_cfg=cfg["model"]["optimizer"])
-model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+state_dict = torch.load(model_path, map_location=torch.device("cpu"))
+model.load_state_dict(state_dict, strict=False)
+
+
 model.eval()
 
+
+for name, param in model.named_parameters():
+    if param.requires_grad:
+        print(f"{name}: {param.data}")
+
 # Optional: we define a minimum number of samples to avoid Wav2Vec2 conv errors
 MIN_SAMPLES = 10  # or 16000 if you want at least 1 second
 
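Note: the `named_parameters()` loop added above prints full weight tensors, which is verbose for a model of this size. A more compact sanity check, shown here as a minimal sketch that is not part of this commit and assumes `model` is the classifier instantiated above, is to count trainable versus total parameters:

# Sketch: summarise trainable parameters instead of dumping raw tensors
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")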
 
encoders/transformer.py CHANGED
@@ -3,7 +3,7 @@ import torch
 from torchmetrics import Accuracy, Precision, Recall, F1Score
 from transformers import Wav2Vec2Model, Wav2Vec2ForSequenceClassification
 import torch.nn.functional as F
-
+from models.lora import LinearWithLoRA, LoRALayer
 
 class Wav2Vec2Classifier(pl.LightningModule):
     def __init__(self, num_classes, optimizer_cfg = "Adam", l1_lambda=0.0):
@@ -166,6 +166,32 @@ class Wav2Vec2EmotionClassifier(pl.LightningModule):
         else:
             self.optimizer = None
 
+        # Apply LoRA
+        low_rank = 8
+        lora_alpha = 16
+        self.apply_lora(low_rank, lora_alpha)
+
+    def apply_lora(self, rank, alpha):
+        # Replace specific linear layers with LinearWithLoRA
+        for layer in self.model.wav2vec2.encoder.layers:
+            layer.attention.q_proj = LinearWithLoRA(layer.attention.q_proj, rank, alpha)
+            layer.attention.k_proj = LinearWithLoRA(layer.attention.k_proj, rank, alpha)
+            layer.attention.v_proj = LinearWithLoRA(layer.attention.v_proj, rank, alpha)
+            layer.attention.out_proj = LinearWithLoRA(layer.attention.out_proj, rank, alpha)
+
+            layer.feed_forward.intermediate_dense = LinearWithLoRA(layer.feed_forward.intermediate_dense, rank, alpha)
+            layer.feed_forward.output_dense = LinearWithLoRA(layer.feed_forward.output_dense, rank, alpha)
+
+    def state_dict(self, *args, **kwargs):
+        # Save only LoRA and classifier/projector parameters
+        state = super().state_dict(*args, **kwargs)
+        return {k: v for k, v in state.items() if "lora" in k or "classifier" in k or "projector" in k}
+
+    def load_state_dict(self, state_dict, strict=True):
+        missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)
+        if missing_keys or unexpected_keys:
+            print(f"Missing keys: {missing_keys}")
+            print(f"Unexpected keys: {unexpected_keys}")
     def forward(self, x, attention_mask=None):
         return self.model(x, attention_mask=attention_mask).logits
 
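One step the hunk above does not show is freezing the Wav2Vec2 backbone: `apply_lora` wraps the attention and feed-forward projections, and the overridden `state_dict` keeps only LoRA, classifier, and projector tensors, but the base weights themselves remain trainable by default. A minimal sketch of the freezing step this setup implies (a hypothetical helper, not part of this commit, reusing the same "lora"/"classifier"/"projector" naming as the state_dict filter):

def freeze_non_lora_parameters(model):
    # Hypothetical helper: leave only LoRA, classifier, and projector weights trainable,
    # mirroring the key filter used by the overridden state_dict above.
    for name, param in model.named_parameters():
        param.requires_grad = ("lora" in name) or ("classifier" in name) or ("projector" in name)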
 
lora_only_model.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc2029a0dcf22d2b626533192bda3fa6098653df84be452b88c4db830a7c9216
+size 8185738
models/__init__.py CHANGED
@@ -1 +1,2 @@
-from . import CTCencoder
+from . import CTCencoder
+from . import lora
models/lora.py ADDED
@@ -0,0 +1,24 @@
+import torch
+import torch.nn as nn
+
+class LoRALayer(nn.Module):
+    def __init__(self, input_dim, output_dim, rank, alpha):
+        super().__init__()
+        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
+        self.A = nn.Parameter(torch.randn(input_dim, rank) * std_dev)  # Low-rank matrix A
+        self.B = nn.Parameter(torch.zeros(rank, output_dim))  # Low-rank matrix B
+        self.alpha = alpha  # Scaling factor
+    def forward(self, x):
+        # Low-rank update alpha * (x @ A @ B); the residual x is added by LinearWithLoRA
+        return self.alpha * (x @ self.A @ self.B)
+
+
+class LinearWithLoRA(nn.Module):
+    def __init__(self, linear_layer, rank, alpha):
+        super().__init__()
+        self.linear = linear_layer  # Original linear layer
+        self.lora = LoRALayer(linear_layer.in_features, linear_layer.out_features, rank, alpha)  # LoRA layer
+
+    def forward(self, x):
+        # Combine original linear layer output with LoRA adaptation
+        return self.linear(x) + self.lora(x)
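
A quick way to check the wrapper's behaviour (an illustrative sketch, not part of this commit, assuming a 768-dimensional hidden size as in wav2vec2-base): because B is initialised to zeros, a freshly wrapped layer reproduces the original linear layer exactly, and the only new parameters are the rank * (in_features + out_features) LoRA weights per wrapped projection.

import torch
import torch.nn as nn
from models.lora import LinearWithLoRA

base = nn.Linear(768, 768)                    # e.g. the hidden size of wav2vec2-base
wrapped = LinearWithLoRA(base, rank=8, alpha=16)

x = torch.randn(2, 50, 768)                   # (batch, time, hidden)
assert torch.allclose(wrapped(x), base(x))    # B == 0, so the LoRA branch adds nothing yet

lora_params = sum(p.numel() for p in wrapped.lora.parameters())
print(lora_params)                            # 8 * (768 + 768) = 12288 extra parameters per wrapped layer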