huntrezz commited on
Commit
7f6b914
·
verified ·
1 Parent(s): e41de5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -35
app.py CHANGED
@@ -1,102 +1,126 @@
1
- import cv2
2
- import torch
3
- import numpy as np
4
- from transformers import DPTImageProcessor
5
- import gradio as gr
6
- import matplotlib.pyplot as plt
7
- from mpl_toolkits.mplot3d import Axes3D
8
- import torch.nn as nn
 
9
 
10
 
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
- # Load your custom trained model
14
  class CompressedStudentModel(nn.Module):
15
  def __init__(self):
 
16
  super(CompressedStudentModel, self).__init__()
 
17
  self.encoder = nn.Sequential(
18
- nn.Conv2d(3, 64, kernel_size=3, padding=1),
 
 
19
  nn.ReLU(),
20
- nn.Conv2d(64, 64, kernel_size=3, padding=1),
 
21
  nn.ReLU(),
22
- nn.MaxPool2d(2),
23
- nn.Conv2d(64, 128, kernel_size=3, padding=1),
24
  nn.ReLU(),
25
- nn.Conv2d(128, 128, kernel_size=3, padding=1),
 
26
  nn.ReLU(),
27
- nn.MaxPool2d(2),
28
- nn.Conv2d(128, 256, kernel_size=3, padding=1),
29
- nn.ReLU(),
30
- nn.Conv2d(256, 256, kernel_size=3, padding=1),
31
  nn.ReLU(),
32
  )
 
33
  self.decoder = nn.Sequential(
34
- nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
35
  nn.ReLU(),
36
- nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
37
  nn.ReLU(),
38
- nn.Conv2d(64, 1, kernel_size=3, padding=1),
39
  )
40
 
41
  def forward(self, x):
 
42
  features = self.encoder(x)
 
43
  depth = self.decoder(features)
44
  return depth
45
 
46
- # Initialize and load weights into the student model
47
- model = CompressedStudentModel().to(device)
48
- model.load_state_dict(torch.load("huntrezz_depth_v2.pt", map_location=device))
49
- model.eval()
50
 
 
51
  processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
52
 
53
  def preprocess_image(image):
 
54
  image = cv2.resize(image, (200, 200))
 
55
  image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
 
56
  return image / 255.0
57
 
58
  def plot_depth_map(depth_map, original_image):
 
59
  fig = plt.figure(figsize=(16, 9))
 
60
  ax = fig.add_subplot(111, projection='3d')
 
61
  x, y = np.meshgrid(range(depth_map.shape[1]), range(depth_map.shape[0]))
62
 
63
- # Normalize depth map to [0, 1] range
64
  norm = plt.Normalize(depth_map.min(), depth_map.max())
65
  colors = plt.cm.viridis(norm(depth_map))
66
 
 
67
  ax.plot_surface(x, y, depth_map, facecolors=colors, shade=False)
68
- ax.set_zlim(0, 1)
69
 
70
- # Adjust the view to look down at an angle from a higher position
71
  ax.view_init(elev=70, azim=90)
72
- plt.axis('off')
73
- plt.close(fig)
74
 
 
75
  fig.canvas.draw()
76
  img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
77
  img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
78
 
79
  return img
80
 
81
- @torch.inference_mode()
82
  def process_frame(image):
 
83
  if image is None:
84
  return None
 
85
  preprocessed = preprocess_image(image)
 
86
  predicted_depth = model(preprocessed).squeeze().cpu().numpy()
87
 
 
88
  depth_map = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
89
 
 
90
  if image.shape[2] == 3:
91
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
92
 
 
93
  return plot_depth_map(depth_map, image)
94
 
 
95
  interface = gr.Interface(
96
- fn=process_frame,
97
- inputs=gr.Image(sources="webcam", streaming=True),
98
- outputs="image",
99
- live=True
100
  )
101
 
 
102
  interface.launch()
 
1
+ # Import required libraries for image processing, deep learning, and visualization
2
+ import cv2 # OpenCV for image processing
3
+ import torch # PyTorch deep learning framework
4
+ import numpy as np # NumPy for numerical operations
5
+ from transformers import DPTImageProcessor # Hugging Face image processor for depth estimation
6
+ import gradio as gr # Gradio for creating web interfaces
7
+ import matplotlib.pyplot as plt # Matplotlib for plotting
8
+ from mpl_toolkits.mplot3d import Axes3D # 3D plotting tools
9
+ import torch.nn as nn # Neural network modules from PyTorch
10
 
11
 
12
+ # Set up device - will use GPU if available, otherwise CPU
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
+ # Define my compressed student model architecture for depth estimation
16
  class CompressedStudentModel(nn.Module):
17
  def __init__(self):
18
+ # Initialize parent class
19
  super(CompressedStudentModel, self).__init__()
20
+ # Define encoder network that extracts features from input image
21
  self.encoder = nn.Sequential(
22
+ nn.Conv2d(3, 64, kernel_size=3, padding=1), # First conv layer: RGB -> 64 channels
23
+ nn.ReLU(), # Activation function
24
+ nn.Conv2d(64, 64, kernel_size=3, padding=1), # Second conv: 64 -> 64 channels
25
  nn.ReLU(),
26
+ nn.MaxPool2d(2), # Reduce spatial dimensions by 2
27
+ nn.Conv2d(64, 128, kernel_size=3, padding=1), # Third conv: 64 -> 128 channels
28
  nn.ReLU(),
29
+ nn.Conv2d(128, 128, kernel_size=3, padding=1), # Fourth conv: 128 -> 128 channels
 
30
  nn.ReLU(),
31
+ nn.MaxPool2d(2), # Further reduce spatial dimensions
32
+ nn.Conv2d(128, 256, kernel_size=3, padding=1), # Fifth conv: 128 -> 256 channels
33
  nn.ReLU(),
34
+ nn.Conv2d(256, 256, kernel_size=3, padding=1), # Sixth conv: 256 -> 256 channels
 
 
 
35
  nn.ReLU(),
36
  )
37
+ # Define decoder network that upsamples features back to original resolution
38
  self.decoder = nn.Sequential(
39
+ nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1), # First upsample: 256 -> 128
40
  nn.ReLU(),
41
+ nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), # Second upsample: 128 -> 64
42
  nn.ReLU(),
43
+ nn.Conv2d(64, 1, kernel_size=3, padding=1), # Final conv: 64 -> 1 channel depth map
44
  )
45
 
46
  def forward(self, x):
47
+ # Pass input through encoder to get features
48
  features = self.encoder(x)
49
+ # Pass features through decoder to get depth map
50
  depth = self.decoder(features)
51
  return depth
52
 
53
+ # Load my trained model and prepare it for inference
54
+ model = CompressedStudentModel().to(device) # Create model instance and move to device
55
+ model.load_state_dict(torch.load("huntrezz_depth_v2.pt", map_location=device)) # Load trained weights
56
+ model.eval() # Set model to evaluation mode
57
 
58
+ # Initialize the image processor from Hugging Face
59
  processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
60
 
61
  def preprocess_image(image):
62
+ # Resize image to 200x200 for consistent processing
63
  image = cv2.resize(image, (200, 200))
64
+ # Convert image to PyTorch tensor and move to device
65
  image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
66
+ # Normalize pixel values to [0,1] range
67
  return image / 255.0
68
 
69
  def plot_depth_map(depth_map, original_image):
70
+ # Create new figure with specific size
71
  fig = plt.figure(figsize=(16, 9))
72
+ # Add 3D subplot
73
  ax = fig.add_subplot(111, projection='3d')
74
+ # Create coordinate grids for 3D plot
75
  x, y = np.meshgrid(range(depth_map.shape[1]), range(depth_map.shape[0]))
76
 
77
+ # Normalize depth values for coloring
78
  norm = plt.Normalize(depth_map.min(), depth_map.max())
79
  colors = plt.cm.viridis(norm(depth_map))
80
 
81
+ # Create 3D surface plot
82
  ax.plot_surface(x, y, depth_map, facecolors=colors, shade=False)
83
+ ax.set_zlim(0, 1) # Set z-axis limits
84
 
85
+ # Set viewing angle for better visualization
86
  ax.view_init(elev=70, azim=90)
87
+ plt.axis('off') # Hide axes
88
+ plt.close(fig) # Close the figure to free memory
89
 
90
+ # Convert matplotlib figure to numpy array
91
  fig.canvas.draw()
92
  img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
93
  img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
94
 
95
  return img
96
 
97
+ @torch.inference_mode() # Disable gradient computation for inference
98
  def process_frame(image):
99
+ # Check if image is valid
100
  if image is None:
101
  return None
102
+ # Preprocess input image
103
  preprocessed = preprocess_image(image)
104
+ # Get depth prediction from model
105
  predicted_depth = model(preprocessed).squeeze().cpu().numpy()
106
 
107
+ # Normalize depth values to [0,1] range
108
  depth_map = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
109
 
110
+ # Convert BGR to RGB if needed
111
  if image.shape[2] == 3:
112
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
113
 
114
+ # Create and return 3D visualization
115
  return plot_depth_map(depth_map, image)
116
 
117
+ # Create Gradio interface for webcam input
118
  interface = gr.Interface(
119
+ fn=process_frame, # Processing function
120
+ inputs=gr.Image(sources="webcam", streaming=True), # Webcam input
121
+ outputs="image", # Image output
122
+ live=True # Enable live updates
123
  )
124
 
125
+ # Launch the interface
126
  interface.launch()