ascarlettvfx commited on
Commit
0bbefc3
·
verified ·
1 Parent(s): a49153f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -14
app.py CHANGED
@@ -1,22 +1,26 @@
1
- from transformers import pipeline
2
  import gradio as gr
 
 
3
 
4
- # Initialize the pipeline with the Marigold model hosted on Hugging Face
5
- model = pipeline("image-to-image", model="prs-eth/marigold-depth-v1-0")
 
 
 
6
 
7
- def predict_depth(image):
8
- # Generate a depth map from the input image using the model
9
- output = model(image)
10
- return output['output_image'] # Ensure this key matches the output of your model
11
 
12
- # Set up the Gradio interface
13
- interface = gr.Interface(
14
  fn=predict_depth,
15
- inputs=gr.inputs.Image(shape=(512, 512), label="Upload Image"),
16
- outputs=gr.outputs.Image(label="Depth Map"),
17
- title="Marigold Depth Map Estimation",
18
- description="Upload an image and the model will estimate and display its depth map."
 
19
  )
20
 
21
  if __name__ == "__main__":
22
- interface.launch()
 
 
1
  import gradio as gr
2
+ from PIL import Image
3
+ from marigold_depth_estimation import MarigoldPipeline, UNet2DConditionModel, AutoencoderKL, DDIMScheduler
4
 
5
+ # Instantiate the model components and the pipeline
6
+ unet_model = UNet2DConditionModel()
7
+ vae_model = AutoencoderKL()
8
+ scheduler = DDIMScheduler()
9
+ pipeline = MarigoldPipeline(unet=unet_model, vae=vae_model, scheduler=scheduler)
10
 
11
+ def predict_depth(input_image):
12
+ # Process the image and predict the depth map
13
+ output = pipeline(input_image)
14
+ return output.depth_image
15
 
16
+ iface = gr.Interface(
 
17
  fn=predict_depth,
18
+ inputs=gr.inputs.Image(type="pil", label="Upload an Image"),
19
+ outputs=gr.outputs.Image(type="pil", label="Depth Map"),
20
+ title="Depth Map Generation",
21
+ description="Upload an image to generate its depth map using the Marigold Depth Estimation Model.",
22
+ examples=["sample1.jpg", "sample2.jpg"] # Optional: include example images in your repository
23
  )
24
 
25
  if __name__ == "__main__":
26
+ iface.launch()