satyanayak commited on
Commit
de70513
·
1 Parent(s): 94625eb

debugging the new model

Browse files
Files changed (1) hide show
  1. app.py +91 -6
app.py CHANGED
@@ -2,7 +2,9 @@ import torch
2
  import torchvision.transforms as transforms
3
  from PIL import Image
4
  import gradio as gr
 
5
  import torchvision.models as models
 
6
 
7
  # Load ImageNet class labels
8
  with open('imagenet_classes.txt', 'r') as f:
@@ -19,30 +21,113 @@ def preprocess_image(image):
19
  ])
20
  return transform(image).unsqueeze(0)
21
 
22
- # Load model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def load_model():
24
- # Load pretrained ResNet50
25
- model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
26
- model.eval()
27
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # Prediction function
30
  def predict(input_image):
31
  try:
32
  # Convert from BGR to RGB
33
  input_image = Image.fromarray(input_image)
 
34
 
35
  # Preprocess the image
36
  input_tensor = preprocess_image(input_image)
 
37
 
38
  # Make prediction
39
  with torch.no_grad():
40
  output = model(input_tensor)
 
 
 
41
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
 
42
 
43
  # Get top 5 predictions
44
  top5_prob, top5_catid = torch.topk(probabilities, 5)
45
 
 
 
 
 
 
46
  # Create result dictionary
47
  results = {}
48
  for i in range(5):
 
2
  import torchvision.transforms as transforms
3
  from PIL import Image
4
  import gradio as gr
5
+ from huggingface_hub import hf_hub_download
6
  import torchvision.models as models
7
+ import os
8
 
9
  # Load ImageNet class labels
10
  with open('imagenet_classes.txt', 'r') as f:
 
21
  ])
22
  return transform(image).unsqueeze(0)
23
 
24
+ def convert_state_dict(state_dict):
25
+ """Convert Composer state dict to standard ResNet state dict."""
26
+ print("Original state dict keys:", list(state_dict.keys())[:5], "...")
27
+
28
+ new_state_dict = {}
29
+ for key, value in state_dict.items():
30
+ # Remove 'module.' prefix if it exists
31
+ if key.startswith('module.'):
32
+ key = key[7:] # Remove first 7 characters ('module.')
33
+
34
+ # Handle blur filter layers
35
+ if 'blur_filter' in key or 'filt2d' in key:
36
+ continue
37
+
38
+ # Convert conv layers with blur
39
+ if '.conv.weight' in key:
40
+ key = key.replace('.conv.weight', '.weight')
41
+
42
+ new_state_dict[key] = value
43
+
44
+ # Print shape information for debugging
45
+ print(f"Layer: {key}, Shape: {value.shape}")
46
+
47
+ print("\nConverted state dict keys:", list(new_state_dict.keys())[:5], "...")
48
+ return new_state_dict
49
+
50
+ # Load model from Hugging Face Hub
51
  def load_model():
52
+ try:
53
+ repo_id = "satyanayak/imagenet-resnet50-composer-model"
54
+ filename = "pytorch_model_latest.bin"
55
+
56
+ print(f"Attempting to load model from {repo_id}/{filename}")
57
+
58
+ # Download the model file
59
+ model_path = hf_hub_download(repo_id=repo_id, filename=filename)
60
+ print(f"Model downloaded to: {model_path}")
61
+
62
+ # Initialize standard ResNet50
63
+ print("Initializing ResNet50 model...")
64
+ model = models.resnet50(weights=None)
65
+
66
+ # Print model structure
67
+ print("\nModel structure:")
68
+ for name, module in model.named_children():
69
+ print(f"{name}: {module.__class__.__name__}")
70
+
71
+ # Load and convert the state dict
72
+ print("\nLoading state dict...")
73
+ state_dict = torch.load(
74
+ model_path,
75
+ map_location=torch.device('cpu'),
76
+ weights_only=True
77
+ )
78
+
79
+ print("\nConverting state dict...")
80
+ converted_state_dict = convert_state_dict(state_dict)
81
+
82
+ # Load the converted state dict
83
+ print("\nLoading weights into model...")
84
+ missing_keys, unexpected_keys = model.load_state_dict(converted_state_dict, strict=False)
85
+
86
+ if missing_keys:
87
+ print("\nMissing keys:", missing_keys)
88
+ if unexpected_keys:
89
+ print("\nUnexpected keys:", unexpected_keys)
90
+
91
+ model.eval()
92
+ print("\nModel loaded successfully!")
93
+ return model
94
+
95
+ except Exception as e:
96
+ print(f"\nError loading custom model: {str(e)}")
97
+ print("Stack trace:", e.__traceback__)
98
+ print("Falling back to pretrained ResNet50")
99
+ model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
100
+ model.eval()
101
+ return model
102
 
103
+ # Prediction function with debugging
104
  def predict(input_image):
105
  try:
106
  # Convert from BGR to RGB
107
  input_image = Image.fromarray(input_image)
108
+ print(f"Input image size: {input_image.size}")
109
 
110
  # Preprocess the image
111
  input_tensor = preprocess_image(input_image)
112
+ print(f"Preprocessed tensor shape: {input_tensor.shape}")
113
 
114
  # Make prediction
115
  with torch.no_grad():
116
  output = model(input_tensor)
117
+ print(f"Raw output shape: {output.shape}")
118
+ print(f"Raw output values (first 5): {output[0][:5]}")
119
+
120
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
121
+ print(f"Probability values (first 5): {probabilities[:5]}")
122
 
123
  # Get top 5 predictions
124
  top5_prob, top5_catid = torch.topk(probabilities, 5)
125
 
126
+ # Print debugging info
127
+ print("\nTop 5 predictions:")
128
+ for i in range(5):
129
+ print(f"{categories[top5_catid[i]]}: {float(top5_prob[i]):.4f}")
130
+
131
  # Create result dictionary
132
  results = {}
133
  for i in range(5):