Tennineee commited on
Commit
fe1ae81
·
verified ·
1 Parent(s): 6c0487e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -13
app.py CHANGED
@@ -92,7 +92,6 @@ def predict(image):
92
  image = transforms(image).unsqueeze(0)
93
  DIS_map = model.inference(image.to(device),depth.to(device))[0][0][0].cpu()
94
  DIS_map = cv2.resize(np.array(DIS_map), (W,H))
95
- # return cv2.resize(np.array(depth[0][0]), (W,H))
96
  return DIS_map
97
 
98
  with gr.Blocks(css=css) as demo:
@@ -102,29 +101,24 @@ with gr.Blocks(css=css) as demo:
102
 
103
  with gr.Row():
104
  input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
105
- dis_image_slider = gr.Image(label="Pedict View", type='numpy', elem_id='img-display-input')
106
- # dis_image_slider = ImageSlider(label="Pedict View", elem_id='img-display-output', position=0.5)
107
  submit = gr.Button(value="Compute")
108
- raw_file = gr.File(label="16-bit raw output", elem_id="download",)
109
 
110
  def on_submit(image):
111
  original_image = image.copy()
112
 
113
- DIS_map = predict(image)
114
  DIS_map = (DIS_map - DIS_map.min()) / (DIS_map.max() - DIS_map.min()) * 255.0
115
- raw_DIS_map = Image.fromarray(DIS_map.astype('uint16'))
116
- tmp_raw_DIS_map = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
117
- raw_DIS_map.save(tmp_raw_DIS_map.name)
118
 
119
- # return [[original_image, DIS_map.astype(np.uint8)], tmp_raw_DIS_map.name]
120
- return [DIS_map.astype(np.uint8), tmp_raw_DIS_map.name]
121
-
122
- submit.click(on_submit, inputs=[input_image], outputs=[dis_image_slider, raw_file])
123
 
124
  example_files = os.listdir('assets/examples')
125
  example_files.sort()
126
  example_files = [os.path.join('assets/examples', filename) for filename in example_files]
127
- examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[dis_image_slider, raw_file], fn=on_submit)
128
 
129
  if __name__ == '__main__':
130
  demo.queue().launch(share=True)
 
92
  image = transforms(image).unsqueeze(0)
93
  DIS_map = model.inference(image.to(device),depth.to(device))[0][0][0].cpu()
94
  DIS_map = cv2.resize(np.array(DIS_map), (W,H))
 
95
  return DIS_map
96
 
97
  with gr.Blocks(css=css) as demo:
 
101
 
102
  with gr.Row():
103
  input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
104
+ dis_image_slider = gr.Image(label="Pedict View",type='numpy', elem_id='img-display-output')
105
+ # dis_image_slider = ImageSlider(label="Pedict View", type="pil", elem_id='img-display-output')
106
  submit = gr.Button(value="Compute")
 
107
 
108
  def on_submit(image):
109
  original_image = image.copy()
110
 
111
+ DIS_map = predict(np.array(image))
112
  DIS_map = (DIS_map - DIS_map.min()) / (DIS_map.max() - DIS_map.min()) * 255.0
113
+ matting = (DIS_map[...,None] / 255.0 * original_image)
114
+ return matting.astype('uint8')
 
115
 
116
+ submit.click(on_submit, inputs=[input_image], outputs=dis_image_slider)
 
 
 
117
 
118
  example_files = os.listdir('assets/examples')
119
  example_files.sort()
120
  example_files = [os.path.join('assets/examples', filename) for filename in example_files]
121
+ examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=dis_image_slider, fn=on_submit)
122
 
123
  if __name__ == '__main__':
124
  demo.queue().launch(share=True)