ombhojane commited on
Commit
8e69506
·
verified ·
1 Parent(s): f18cb7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -34,15 +34,27 @@ class BehaviorDetector(nn.Module):
34
  return torch.sigmoid(self.base_model(x))
35
 
36
  class DogBehaviorAnalyzer:
37
- def __init__(self, model_path='best.pt'):
38
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
 
40
- # Initialize YOLO model for dog detection
41
- self.yolo_model = YOLO(model_path) if model_path else None
 
 
 
 
42
 
43
  # Initialize behavior classifier
44
  self.num_behaviors = 5
45
- self.behavior_model = BehaviorDetector(self.num_behaviors).to(self.device)
 
 
 
 
 
 
 
 
46
  self.behavior_model.eval()
47
 
48
  # Define sophisticated transforms
 
34
  return torch.sigmoid(self.base_model(x))
35
 
36
  class DogBehaviorAnalyzer:
37
+ def __init__(self, model_path=None):
38
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
 
40
+ # Initialize YOLO model for dog detection (optional)
41
+ try:
42
+ self.yolo_model = YOLO(model_path) if model_path else None
43
+ except Exception as e:
44
+ st.warning("YOLO model not found. Running without dog detection.")
45
+ self.yolo_model = None
46
 
47
  # Initialize behavior classifier
48
  self.num_behaviors = 5
49
+ try:
50
+ self.behavior_model = BehaviorDetector(self.num_behaviors).to(self.device)
51
+ except Exception as e:
52
+ st.warning("Error loading behavior model. Using default classifier.")
53
+ self.behavior_model = models.resnet18(pretrained=True)
54
+ num_features = self.behavior_model.fc.in_features
55
+ self.behavior_model.fc = nn.Linear(num_features, self.num_behaviors)
56
+ self.behavior_model = self.behavior_model.to(self.device)
57
+
58
  self.behavior_model.eval()
59
 
60
  # Define sophisticated transforms