Xalphinions committed
Commit 13b45d3 · verified · 1 Parent(s): fbfa85f

Upload folder using huggingface_hub

Files changed (2)
  1. app.py +36 -11
  2. app_moe.py +23 -8
app.py CHANGED
@@ -34,10 +34,11 @@ class WatermelonMoEModel(torch.nn.Module):
             weights: Optional list of weights for each model (None for equal weighting)
         """
         super(WatermelonMoEModel, self).__init__()
-        self.models = []
+        self.models = torch.nn.ModuleList()  # Use ModuleList instead of regular list
         self.model_configs = model_configs
 
         # Load each model
+        loaded_count = 0
         for config in model_configs:
             img_backbone = config["image_backbone"]
             audio_backbone = config["audio_backbone"]
@@ -49,22 +50,31 @@ class WatermelonMoEModel(torch.nn.Module):
             model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
             if os.path.exists(model_path):
                 print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
-                model.load_state_dict(torch.load(model_path, map_location='cpu'))
+                try:
+                    model.load_state_dict(torch.load(model_path, map_location='cpu'))
+                    model.eval()  # Set to evaluation mode
+                    self.models.append(model)
+                    loaded_count += 1
+                except Exception as e:
+                    print(f"\033[91mERR!\033[0m: Failed to load model from {model_path}: {e}")
+                    continue
             else:
                 print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
                 continue
-
-            model.eval()  # Set to evaluation mode
-            self.models.append(model)
 
+        # Add a dummy parameter if no models were loaded to prevent StopIteration
+        if loaded_count == 0:
+            print(f"\033[91mERR!\033[0m: No models were successfully loaded!")
+            self.dummy_param = torch.nn.Parameter(torch.zeros(1))
+
         # Set model weights (uniform by default)
-        if weights:
+        if weights and loaded_count > 0:
             assert len(weights) == len(self.models), "Number of weights must match number of models"
             self.weights = weights
         else:
-            self.weights = [1.0 / len(self.models)] * len(self.models)
+            self.weights = [1.0 / max(loaded_count, 1)] * max(loaded_count, 1)
 
-        print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
+        print(f"\033[92mINFO\033[0m: Loaded {loaded_count} models for MoE ensemble")
         print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")
 
     def to(self, device):
@@ -80,6 +90,11 @@ class WatermelonMoEModel(torch.nn.Module):
         Forward pass through the MoE model.
         Returns the weighted average of all model outputs.
         """
+        # Check if we have models loaded
+        if not self.models:
+            print(f"\033[91mERR!\033[0m: No models available for inference!")
+            return torch.tensor([0.0], device=mfcc.device)  # Return a default value
+
         outputs = []
 
         # Get outputs from each model
@@ -258,9 +273,19 @@ def predict_sugar_content(audio, image, model_dir="models", weights=None):
     print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}, device: {processed_image.device}")
 
     # Double-check model is on the correct device
-    print(f"\033[92mDEBUG\033[0m: MoE model device: {next(moe_model.parameters()).device}")
-    for i, model in enumerate(moe_model.models):
-        print(f"\033[92mDEBUG\033[0m: Model {i} device: {next(model.parameters()).device}")
+    try:
+        param = next(moe_model.parameters())
+        print(f"\033[92mDEBUG\033[0m: MoE model device: {param.device}")
+
+        # Check individual models
+        for i, model in enumerate(moe_model.models):
+            try:
+                model_param = next(model.parameters())
+                print(f"\033[92mDEBUG\033[0m: Model {i} device: {model_param.device}")
+            except StopIteration:
+                print(f"\033[91mERR!\033[0m: Model {i} has no parameters!")
+    except StopIteration:
+        print(f"\033[91mERR!\033[0m: MoE model has no parameters!")
 
     # Run inference with MoE model
    print(f"\033[92mDEBUG\033[0m: Running inference with MoE model on device: {device}")
app_moe.py CHANGED
@@ -34,10 +34,11 @@ class WatermelonMoEModel(torch.nn.Module):
             weights: Optional list of weights for each model (None for equal weighting)
         """
         super(WatermelonMoEModel, self).__init__()
-        self.models = []
+        self.models = torch.nn.ModuleList()  # Use ModuleList instead of regular list
         self.model_configs = model_configs
 
         # Load each model
+        loaded_count = 0
         for config in model_configs:
             img_backbone = config["image_backbone"]
             audio_backbone = config["audio_backbone"]
@@ -49,22 +50,31 @@ class WatermelonMoEModel(torch.nn.Module):
             model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
             if os.path.exists(model_path):
                 print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
-                model.load_state_dict(torch.load(model_path, map_location='cpu'))
+                try:
+                    model.load_state_dict(torch.load(model_path, map_location='cpu'))
+                    model.eval()  # Set to evaluation mode
+                    self.models.append(model)
+                    loaded_count += 1
+                except Exception as e:
+                    print(f"\033[91mERR!\033[0m: Failed to load model from {model_path}: {e}")
+                    continue
             else:
                 print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
                 continue
-
-            model.eval()  # Set to evaluation mode
-            self.models.append(model)
 
+        # Add a dummy parameter if no models were loaded to prevent StopIteration
+        if loaded_count == 0:
+            print(f"\033[91mERR!\033[0m: No models were successfully loaded!")
+            self.dummy_param = torch.nn.Parameter(torch.zeros(1))
+
         # Set model weights (uniform by default)
-        if weights:
+        if weights and loaded_count > 0:
             assert len(weights) == len(self.models), "Number of weights must match number of models"
             self.weights = weights
         else:
-            self.weights = [1.0 / len(self.models)] * len(self.models)
+            self.weights = [1.0 / max(loaded_count, 1)] * max(loaded_count, 1)
 
-        print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
+        print(f"\033[92mINFO\033[0m: Loaded {loaded_count} models for MoE ensemble")
         print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")
 
     def to(self, device):
@@ -80,6 +90,11 @@ class WatermelonMoEModel(torch.nn.Module):
         Forward pass through the MoE model.
         Returns the weighted average of all model outputs.
         """
+        # Check if we have models loaded
+        if not self.models:
+            print(f"\033[91mERR!\033[0m: No models available for inference!")
+            return torch.tensor([0.0], device=mfcc.device)  # Return a default value
+
        outputs = []
 
        # Get outputs from each model
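For reference, the weighted average that forward() promises in its docstring reduces to scaling each expert's output by its weight and summing. A toy sketch of that reduction, with illustrative values that are not real model outputs:

# Toy sketch of the ensemble reduction described in forward()'s docstring:
# each model's output is scaled by its weight and the results are summed.
import torch

outputs = [torch.tensor([9.8]), torch.tensor([10.4]), torch.tensor([10.1])]  # made-up predictions
weights = [1.0 / len(outputs)] * len(outputs)  # uniform, as in the default branch

prediction = sum(w * o for w, o in zip(weights, outputs))
print(prediction)  # tensor([10.1000])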