huzey commited on
Commit
4a168d1
1 Parent(s): 5dac5bc

add directed ncut (test)

Browse files
Files changed (1) hide show
  1. app.py +20 -20
app.py CHANGED
@@ -204,9 +204,9 @@ def compute_ncut_directed(
204
  make_symmetric=False,
205
  progess_start=0.4,
206
  ):
207
- print("Using directed_ncut")
208
- print("features_1.shape", features_1.shape)
209
- print("features_2.shape", features_2.shape)
210
  from directed_ncut import nystrom_ncut
211
  progress = gr.Progress()
212
  logging_str = ""
@@ -977,26 +977,26 @@ def ncut_run(
977
 
978
  def _ncut_run(*args, **kwargs):
979
  n_ret = kwargs.pop("n_ret", 1)
980
- # try:
981
- # if torch.cuda.is_available():
982
- # torch.cuda.empty_cache()
983
 
984
- # ret = ncut_run(*args, **kwargs)
985
 
986
- # if torch.cuda.is_available():
987
- # torch.cuda.empty_cache()
988
 
989
- # ret = list(ret)[:n_ret] + [ret[-1]]
990
- # return ret
991
- # except Exception as e:
992
- # gr.Error(str(e))
993
- # if torch.cuda.is_available():
994
- # torch.cuda.empty_cache()
995
- # return *(None for _ in range(n_ret)), "Error: " + str(e)
996
 
997
- ret = ncut_run(*args, **kwargs)
998
- ret = list(ret)[:n_ret] + [ret[-1]]
999
- return ret
1000
 
1001
  if USE_HUGGINGFACE_ZEROGPU:
1002
  @spaces.GPU(duration=30)
@@ -1213,7 +1213,7 @@ def run_fn(
1213
  advanced=False,
1214
  directed=False,
1215
  ):
1216
- print(node_type2, head_index_text, make_symmetric)
1217
  progress=gr.Progress()
1218
  progress(0, desc="Starting")
1219
 
 
204
  make_symmetric=False,
205
  progess_start=0.4,
206
  ):
207
+ # print("Using directed_ncut")
208
+ # print("features_1.shape", features_1.shape)
209
+ # print("features_2.shape", features_2.shape)
210
  from directed_ncut import nystrom_ncut
211
  progress = gr.Progress()
212
  logging_str = ""
 
977
 
978
  def _ncut_run(*args, **kwargs):
979
  n_ret = kwargs.pop("n_ret", 1)
980
+ try:
981
+ if torch.cuda.is_available():
982
+ torch.cuda.empty_cache()
983
 
984
+ ret = ncut_run(*args, **kwargs)
985
 
986
+ if torch.cuda.is_available():
987
+ torch.cuda.empty_cache()
988
 
989
+ ret = list(ret)[:n_ret] + [ret[-1]]
990
+ return ret
991
+ except Exception as e:
992
+ gr.Error(str(e))
993
+ if torch.cuda.is_available():
994
+ torch.cuda.empty_cache()
995
+ return *(None for _ in range(n_ret)), "Error: " + str(e)
996
 
997
+ # ret = ncut_run(*args, **kwargs)
998
+ # ret = list(ret)[:n_ret] + [ret[-1]]
999
+ # return ret
1000
 
1001
  if USE_HUGGINGFACE_ZEROGPU:
1002
  @spaces.GPU(duration=30)
 
1213
  advanced=False,
1214
  directed=False,
1215
  ):
1216
+ # print(node_type2, head_index_text, make_symmetric)
1217
  progress=gr.Progress()
1218
  progress(0, desc="Starting")
1219