A19grey commited on
Commit
84a6150
·
1 Parent(s): 602ae0d

Added trimesh install and allow user file download

Browse files
Files changed (2) hide show
  1. app.py +18 -11
  2. requirements.txt +3 -1
app.py CHANGED
@@ -9,6 +9,7 @@ import torch
9
  import tempfile
10
  import os
11
  import trimesh
 
12
 
13
  # Run the script to download pretrained models
14
  subprocess.run(["bash", "get_pretrained_models.sh"])
@@ -56,7 +57,7 @@ def generate_3d_model(depth, image_path, focallength_px):
56
  focallength_px (float): Focal length in pixels.
57
 
58
  Returns:
59
- str: Path to the exported 3D model file in OBJ format.
60
  """
61
  # Load the RGB image and convert to a NumPy array
62
  image = np.array(Image.open(image_path))
@@ -96,10 +97,13 @@ def generate_3d_model(depth, image_path, focallength_px):
96
  # Create the mesh using Trimesh with vertex colors
97
  mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)
98
 
99
- # Export the mesh to an OBJ file .
100
- model_path = 'output_model.obj'
101
- mesh.export(model_path)
102
- return model_path
 
 
 
103
 
104
  @spaces.GPU(duration=20)
105
  def predict_depth(input_image):
@@ -114,7 +118,8 @@ def predict_depth(input_image):
114
  - str: Path to the depth map image.
115
  - str: Focal length in pixels or an error message.
116
  - str: Path to the raw depth data CSV file.
117
- - str: Path to the generated 3D model file.
 
118
  """
119
  temp_file = None
120
  try:
@@ -176,12 +181,12 @@ def predict_depth(input_image):
176
  np.savetxt(raw_depth_path, depth, delimiter=',')
177
 
178
  # Generate the 3D model from the depth map and resized image
179
- model_path = generate_3d_model(depth, temp_file, focallength_px)
180
 
181
- return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, model_path
182
  except Exception as e:
183
  # Return error messages in case of failures
184
- return None, f"An error occurred: {str(e)}", None, None
185
  finally:
186
  # Clean up by removing the temporary resized image file
187
  if temp_file and os.path.exists(temp_file):
@@ -195,7 +200,8 @@ iface = gr.Interface(
195
  gr.Image(type="filepath", label="Depth Map"), # Displays the depth map image
196
  gr.Textbox(label="Focal Length or Error Message"), # Shows focal length or error messages
197
  gr.File(label="Download Raw Depth Map (CSV)"), # Allows downloading the raw depth data
198
- gr.Model3D(label="3D Model") # Displays the generated 3D model
 
199
  ],
200
  title="DepthPro Demo with 3D Visualization",
201
  description=(
@@ -204,7 +210,8 @@ iface = gr.Interface(
204
  "1. Upload an image.\n"
205
  "2. The app will predict the depth map, display it, and provide the focal length.\n"
206
  "3. Download the raw depth data as a CSV file.\n"
207
- "4. View the generated 3D model textured with the original image."
 
208
  ),
209
  )
210
 
 
9
  import tempfile
10
  import os
11
  import trimesh
12
+ import time # Add this import at the top of the file
13
 
14
  # Run the script to download pretrained models
15
  subprocess.run(["bash", "get_pretrained_models.sh"])
 
57
  focallength_px (float): Focal length in pixels.
58
 
59
  Returns:
60
+ tuple: Paths to the exported 3D model files for viewing and downloading.
61
  """
62
  # Load the RGB image and convert to a NumPy array
63
  image = np.array(Image.open(image_path))
 
97
  # Create the mesh using Trimesh with vertex colors
98
  mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)
99
 
100
+ # Export the mesh to OBJ files with unique filenames
101
+ timestamp = int(time.time())
102
+ view_model_path = f'view_model_{timestamp}.obj'
103
+ download_model_path = f'download_model_{timestamp}.obj'
104
+ mesh.export(view_model_path)
105
+ mesh.export(download_model_path)
106
+ return view_model_path, download_model_path
107
 
108
  @spaces.GPU(duration=20)
109
  def predict_depth(input_image):
 
118
  - str: Path to the depth map image.
119
  - str: Focal length in pixels or an error message.
120
  - str: Path to the raw depth data CSV file.
121
+ - str: Path to the generated 3D model file for viewing.
122
+ - str: Path to the downloadable 3D model file.
123
  """
124
  temp_file = None
125
  try:
 
181
  np.savetxt(raw_depth_path, depth, delimiter=',')
182
 
183
  # Generate the 3D model from the depth map and resized image
184
+ view_model_path, download_model_path = generate_3d_model(depth, temp_file, focallength_px)
185
 
186
+ return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, view_model_path, download_model_path
187
  except Exception as e:
188
  # Return error messages in case of failures
189
+ return None, f"An error occurred: {str(e)}", None, None, None
190
  finally:
191
  # Clean up by removing the temporary resized image file
192
  if temp_file and os.path.exists(temp_file):
 
200
  gr.Image(type="filepath", label="Depth Map"), # Displays the depth map image
201
  gr.Textbox(label="Focal Length or Error Message"), # Shows focal length or error messages
202
  gr.File(label="Download Raw Depth Map (CSV)"), # Allows downloading the raw depth data
203
+ gr.Model3D(label="View 3D Model"), # For viewing the 3D model
204
+ gr.File(label="Download 3D Model (OBJ)") # For downloading the 3D model
205
  ],
206
  title="DepthPro Demo with 3D Visualization",
207
  description=(
 
210
  "1. Upload an image.\n"
211
  "2. The app will predict the depth map, display it, and provide the focal length.\n"
212
  "3. Download the raw depth data as a CSV file.\n"
213
+ "4. View the generated 3D model textured with the original image.\n"
214
+ "5. Download the 3D model as an OBJ file if desired."
215
  ),
216
  )
217
 
requirements.txt CHANGED
@@ -5,4 +5,6 @@ torch
5
  torchvision
6
  numpy
7
  pillow_heif
8
- timm
 
 
 
5
  torchvision
6
  numpy
7
  pillow_heif
8
+ timm
9
+ trimesh
10
+ time