Emaad commited on
Commit
b10d9e5
1 Parent(s): eb7610c

Update celle/celle.py

Browse files
Files changed (1) hide show
  1. celle/celle.py +4 -2
celle/celle.py CHANGED
@@ -262,8 +262,10 @@ class ModelExtender(nn.Module):
262
 
263
  # Set the number of output features and initialize the scaling layer
264
  self.out_features = out_features
265
- self.scale_layer = nn.Linear(self.in_features, self.out_features)
266
-
 
 
267
  # Determine whether to freeze the model's parameters
268
  self.fixed_embedding = fixed_embedding
269
  if self.fixed_embedding:
 
262
 
263
  # Set the number of output features and initialize the scaling layer
264
  self.out_features = out_features
265
+ if self.in_features != self.out_features:
266
+ self.scale_layer = nn.Linear(self.in_features, self.out_features)
267
+ else:
268
+ self.scale_layer = nn.Identity()
269
  # Determine whether to freeze the model's parameters
270
  self.fixed_embedding = fixed_embedding
271
  if self.fixed_embedding: