huzey commited on
Commit
b1e189a
1 Parent(s): e9899b2

fix aligned advanced plot

Browse files
Files changed (1) hide show
  1. app.py +22 -17
app.py CHANGED
@@ -536,6 +536,9 @@ def ncut_run(
536
 
537
  if not advanced:
538
  return rgbs[0], rgbs[1], rgbs[2], logging_str
 
 
 
539
  if advanced:
540
  cluster_plots, norm_plots = [], []
541
  for i in range(3):
@@ -690,26 +693,26 @@ def ncut_run(
690
 
691
  def _ncut_run(*args, **kwargs):
692
  n_ret = kwargs.pop("n_ret", 1)
693
- try:
694
- if torch.cuda.is_available():
695
- torch.cuda.empty_cache()
696
 
697
- ret = ncut_run(*args, **kwargs)
698
 
699
- if torch.cuda.is_available():
700
- torch.cuda.empty_cache()
701
 
702
- ret = list(ret)[:n_ret] + [ret[-1]]
703
- return ret
704
- except Exception as e:
705
- gr.Error(str(e))
706
- if torch.cuda.is_available():
707
- torch.cuda.empty_cache()
708
- return *(None for _ in range(n_ret)), "Error: " + str(e)
709
-
710
- # ret = ncut_run(*args, **kwargs)
711
- # ret = list(ret)[:n_ret] + [ret[-1]]
712
- # return ret
713
 
714
  if USE_HUGGINGFACE_ZEROGPU:
715
  @spaces.GPU(duration=30)
@@ -871,6 +874,8 @@ def load_alignedthreemodel():
871
  model = ThreeAttnNodes(align_weights)
872
 
873
  return model
 
 
874
 
875
  promptable_diffusion_models = ["Diffusion(stabilityai/stable-diffusion-2)", "Diffusion(CompVis/stable-diffusion-v1-4)"]
876
  promptable_segmentation_models = ["LISA(xinlai/LISA-7B-v1)"]
 
536
 
537
  if not advanced:
538
  return rgbs[0], rgbs[1], rgbs[2], logging_str
539
+ if "AlignedThreeModelAttnNodes" == model_name:
540
+ return rgbs[0], rgbs[1], rgbs[2], logging_str
541
+
542
  if advanced:
543
  cluster_plots, norm_plots = [], []
544
  for i in range(3):
 
693
 
694
  def _ncut_run(*args, **kwargs):
695
  n_ret = kwargs.pop("n_ret", 1)
696
+ # try:
697
+ # if torch.cuda.is_available():
698
+ # torch.cuda.empty_cache()
699
 
700
+ # ret = ncut_run(*args, **kwargs)
701
 
702
+ # if torch.cuda.is_available():
703
+ # torch.cuda.empty_cache()
704
 
705
+ # ret = list(ret)[:n_ret] + [ret[-1]]
706
+ # return ret
707
+ # except Exception as e:
708
+ # gr.Error(str(e))
709
+ # if torch.cuda.is_available():
710
+ # torch.cuda.empty_cache()
711
+ # return *(None for _ in range(n_ret)), "Error: " + str(e)
712
+
713
+ ret = ncut_run(*args, **kwargs)
714
+ ret = list(ret)[:n_ret] + [ret[-1]]
715
+ return ret
716
 
717
  if USE_HUGGINGFACE_ZEROGPU:
718
  @spaces.GPU(duration=30)
 
874
  model = ThreeAttnNodes(align_weights)
875
 
876
  return model
877
+ # pre-load the alignedthree model in case it fails to load
878
+ load_alignedthreemodel()
879
 
880
  promptable_diffusion_models = ["Diffusion(stabilityai/stable-diffusion-2)", "Diffusion(CompVis/stable-diffusion-v1-4)"]
881
  promptable_segmentation_models = ["LISA(xinlai/LISA-7B-v1)"]