Upload folder using huggingface_hub
- app.py +36 -11
- app_moe.py +23 -8
app.py
CHANGED
@@ -34,10 +34,11 @@ class WatermelonMoEModel(torch.nn.Module):
             weights: Optional list of weights for each model (None for equal weighting)
         """
         super(WatermelonMoEModel, self).__init__()
-        self.models = []
+        self.models = torch.nn.ModuleList()  # Use ModuleList instead of regular list
         self.model_configs = model_configs
 
         # Load each model
+        loaded_count = 0
         for config in model_configs:
             img_backbone = config["image_backbone"]
             audio_backbone = config["audio_backbone"]
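The substantive fix in this hunk is the container type: sub-modules held in a plain Python list are invisible to nn.Module bookkeeping, so .parameters(), .to(device), and .state_dict() all skip them, while torch.nn.ModuleList registers each entry. A minimal standalone sketch of the difference (illustrative, not code from this Space):

import torch

class PlainListHolder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A plain list hides the sub-module from nn.Module bookkeeping.
        self.models = [torch.nn.Linear(4, 1)]

class ModuleListHolder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # ModuleList registers the sub-module, so it is tracked.
        self.models = torch.nn.ModuleList([torch.nn.Linear(4, 1)])

print(len(list(PlainListHolder().parameters())))   # 0 -- next(...) would raise StopIteration
print(len(list(ModuleListHolder().parameters())))  # 2 -- weight and bias are registered

This is presumably also why the class overrides to(self, device) (visible as context below): with a plain list, the default nn.Module.to would never reach the experts.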
@@ -49,22 +50,31 @@ class WatermelonMoEModel(torch.nn.Module):
             model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
             if os.path.exists(model_path):
                 print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
-                model.load_state_dict(torch.load(model_path, map_location='cpu'))
+                try:
+                    model.load_state_dict(torch.load(model_path, map_location='cpu'))
+                    model.eval()  # Set to evaluation mode
+                    self.models.append(model)
+                    loaded_count += 1
+                except Exception as e:
+                    print(f"\033[91mERR!\033[0m: Failed to load model from {model_path}: {e}")
+                    continue
             else:
                 print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
                 continue
-
-            model.eval()  # Set to evaluation mode
-            self.models.append(model)
 
+        # Add a dummy parameter if no models were loaded to prevent StopIteration
+        if loaded_count == 0:
+            print(f"\033[91mERR!\033[0m: No models were successfully loaded!")
+            self.dummy_param = torch.nn.Parameter(torch.zeros(1))
+
         # Set model weights (uniform by default)
-        if weights:
+        if weights and loaded_count > 0:
             assert len(weights) == len(self.models), "Number of weights must match number of models"
             self.weights = weights
         else:
-            self.weights = [1.0 / len(self.models)] * len(self.models)
+            self.weights = [1.0 / max(loaded_count, 1)] * max(loaded_count, 1)
 
-        print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
+        print(f"\033[92mINFO\033[0m: Loaded {loaded_count} models for MoE ensemble")
         print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")
 
     def to(self, device):
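The dummy parameter added here matters because next(module.parameters()) raises StopIteration when nothing is registered, which is exactly what the device checks later in the file run into when every checkpoint fails to load. A small sketch of the failure mode and the guard:

import torch

empty_ensemble = torch.nn.ModuleList()  # ensemble where every checkpoint failed to load

try:
    next(empty_ensemble.parameters())
except StopIteration:
    print("an empty ModuleList exposes no parameters")  # this branch runs

# Registering a single dummy parameter keeps device queries from raising:
holder = torch.nn.Module()
holder.dummy_param = torch.nn.Parameter(torch.zeros(1))
print(next(holder.parameters()).device)  # cpu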
@@ -80,6 +90,11 @@ class WatermelonMoEModel(torch.nn.Module):
         Forward pass through the MoE model.
         Returns the weighted average of all model outputs.
         """
+        # Check if we have models loaded
+        if not self.models:
+            print(f"\033[91mERR!\033[0m: No models available for inference!")
+            return torch.tensor([0.0], device=mfcc.device)  # Return a default value
+
         outputs = []
 
         # Get outputs from each model
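The new guard returns a zero tensor on mfcc's device instead of crashing on an empty ensemble; the combination step itself sits outside this hunk. As a sketch of what "weighted average of all model outputs" amounts to, assuming each expert maps (mfcc, image) to a prediction tensor (the call signature is inferred, not shown in the diff):

import torch

def weighted_average(models, weights, mfcc, image):
    # Scale each expert's output by its fixed weight and sum. With
    # static weights this is a soft ensemble, not a learned gating MoE.
    outputs = [w * model(mfcc, image) for model, w in zip(models, weights)]
    return torch.stack(outputs, dim=0).sum(dim=0)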
@@ -258,9 +273,19 @@ def predict_sugar_content(audio, image, model_dir="models", weights=None):
     print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}, device: {processed_image.device}")
 
     # Double-check model is on the correct device
-    param = next(moe_model.parameters())
-
-    print(f"\033[92mDEBUG\033[0m: MoE model device: {param.device}")
+    try:
+        param = next(moe_model.parameters())
+        print(f"\033[92mDEBUG\033[0m: MoE model device: {param.device}")
+
+        # Check individual models
+        for i, model in enumerate(moe_model.models):
+            try:
+                model_param = next(model.parameters())
+                print(f"\033[92mDEBUG\033[0m: Model {i} device: {model_param.device}")
+            except StopIteration:
+                print(f"\033[91mERR!\033[0m: Model {i} has no parameters!")
+    except StopIteration:
+        print(f"\033[91mERR!\033[0m: MoE model has no parameters!")
 
     # Run inference with MoE model
     print(f"\033[92mDEBUG\033[0m: Running inference with MoE model on device: {device}")
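The nested try/except above guards both the ensemble and each expert against parameter-less modules. The same pattern could be factored into a helper; a hypothetical model_device sketch (not part of the Space):

import torch

def model_device(module: torch.nn.Module):
    # Device of the first registered parameter, or None when the module
    # has no parameters (the StopIteration case handled above).
    try:
        return next(module.parameters()).device
    except StopIteration:
        return None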
app_moe.py
CHANGED
@@ -34,10 +34,11 @@ class WatermelonMoEModel(torch.nn.Module):
             weights: Optional list of weights for each model (None for equal weighting)
         """
         super(WatermelonMoEModel, self).__init__()
-        self.models = []
+        self.models = torch.nn.ModuleList()  # Use ModuleList instead of regular list
         self.model_configs = model_configs
 
         # Load each model
+        loaded_count = 0
         for config in model_configs:
             img_backbone = config["image_backbone"]
             audio_backbone = config["audio_backbone"]
@@ -49,22 +50,31 @@ class WatermelonMoEModel(torch.nn.Module):
             model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
             if os.path.exists(model_path):
                 print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
-                model.load_state_dict(torch.load(model_path, map_location='cpu'))
+                try:
+                    model.load_state_dict(torch.load(model_path, map_location='cpu'))
+                    model.eval()  # Set to evaluation mode
+                    self.models.append(model)
+                    loaded_count += 1
+                except Exception as e:
+                    print(f"\033[91mERR!\033[0m: Failed to load model from {model_path}: {e}")
+                    continue
             else:
                 print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
                 continue
-
-            model.eval()  # Set to evaluation mode
-            self.models.append(model)
 
+        # Add a dummy parameter if no models were loaded to prevent StopIteration
+        if loaded_count == 0:
+            print(f"\033[91mERR!\033[0m: No models were successfully loaded!")
+            self.dummy_param = torch.nn.Parameter(torch.zeros(1))
+
         # Set model weights (uniform by default)
-        if weights:
+        if weights and loaded_count > 0:
             assert len(weights) == len(self.models), "Number of weights must match number of models"
             self.weights = weights
         else:
-            self.weights = [1.0 / len(self.models)] * len(self.models)
+            self.weights = [1.0 / max(loaded_count, 1)] * max(loaded_count, 1)
 
-        print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
+        print(f"\033[92mINFO\033[0m: Loaded {loaded_count} models for MoE ensemble")
         print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")
 
     def to(self, device):
@@ -80,6 +90,11 @@ class WatermelonMoEModel(torch.nn.Module):
         Forward pass through the MoE model.
         Returns the weighted average of all model outputs.
         """
+        # Check if we have models loaded
+        if not self.models:
+            print(f"\033[91mERR!\033[0m: No models available for inference!")
+            return torch.tensor([0.0], device=mfcc.device)  # Return a default value
+
         outputs = []
 
         # Get outputs from each model
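For context, a sketch of how the revised constructor might be driven. The backbone names below and the exact constructor signature are assumptions inferred from the dict keys and parameters referenced in the hunks, not taken from the Space:

# Hypothetical configs; only the two dict keys are attested in the diff.
model_configs = [
    {"image_backbone": "resnet50", "audio_backbone": "transformer"},
    {"image_backbone": "efficientnet_b0", "audio_backbone": "lstm"},
]

# weights=None falls back to the uniform weighting set up in __init__.
moe_model = WatermelonMoEModel(model_configs, model_dir="models", weights=None)
moe_model.eval()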