flamehaze1115 commited on
Commit
3b7ab74
1 Parent(s): 09e9525

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -70,9 +70,9 @@ class SAMAPI:
70
  def get_instance(sam_checkpoint=None):
71
  if SAMAPI.predictor is None:
72
  if sam_checkpoint is None:
73
- sam_checkpoint = "tmp/sam_vit_h_4b8939.pth"
74
  if not os.path.exists(sam_checkpoint):
75
- os.makedirs('tmp', exist_ok=True)
76
  urllib.request.urlretrieve(
77
  "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
78
  sam_checkpoint
@@ -353,8 +353,6 @@ with results_container:
353
  torch.manual_seed(seed)
354
  img = Image.open(pic)
355
 
356
- data = prepare_data(img)
357
-
358
  if max(img.size) > 1280:
359
  w, h = img.size
360
  w = round(1280 / max(img.size) * w)
@@ -374,6 +372,7 @@ with results_container:
374
  img = expand2square(img, (127, 127, 127, 0))
375
  # pipeline.set_progress_bar_config(disable=True)
376
  prog.progress(0.3, "Run cross-domain diffusion model")
 
377
  normals_pred, images_pred = run_pipeline(pipeline, data, cfg_scale, seed)
378
  prog.progress(0.9, "finishing")
379
  left, right = st.columns(2)
 
70
  def get_instance(sam_checkpoint=None):
71
  if SAMAPI.predictor is None:
72
  if sam_checkpoint is None:
73
+ sam_checkpoint = "./sam_pt/sam_vit_h_4b8939.pth"
74
  if not os.path.exists(sam_checkpoint):
75
+ os.makedirs('sam_pt', exist_ok=True)
76
  urllib.request.urlretrieve(
77
  "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
78
  sam_checkpoint
 
353
  torch.manual_seed(seed)
354
  img = Image.open(pic)
355
 
 
 
356
  if max(img.size) > 1280:
357
  w, h = img.size
358
  w = round(1280 / max(img.size) * w)
 
372
  img = expand2square(img, (127, 127, 127, 0))
373
  # pipeline.set_progress_bar_config(disable=True)
374
  prog.progress(0.3, "Run cross-domain diffusion model")
375
+ data = prepare_data(img)
376
  normals_pred, images_pred = run_pipeline(pipeline, data, cfg_scale, seed)
377
  prog.progress(0.9, "finishing")
378
  left, right = st.columns(2)