juxuan27 commited on
Commit
5ea98b8
·
verified ·
1 Parent(s): 39429bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -17,7 +17,7 @@ import random
17
  import spaces
18
  import gradio as gr
19
 
20
- mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to("cuda")
21
  mobile_sam.eval()
22
  mobile_predictor = SamPredictor(mobile_sam)
23
  colors = [(255, 0, 0), (0, 255, 0)]
@@ -269,6 +269,7 @@ with block:
269
  for p, l in sel_pix:
270
  points.append(p)
271
  labels.append(l)
 
272
  mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
273
  with torch.no_grad():
274
  masks, _, _ = mobile_predictor.predict(point_coords=np.array(points), point_labels=np.array(labels), multimask_output=False)
 
17
  import spaces
18
  import gradio as gr
19
 
20
+ mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth')
21
  mobile_sam.eval()
22
  mobile_predictor = SamPredictor(mobile_sam)
23
  colors = [(255, 0, 0), (0, 255, 0)]
 
269
  for p, l in sel_pix:
270
  points.append(p)
271
  labels.append(l)
272
+ mobile_predictor=mobile_predictor.to("cuda")
273
  mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
274
  with torch.no_grad():
275
  masks, _, _ = mobile_predictor.predict(point_coords=np.array(points), point_labels=np.array(labels), multimask_output=False)