huzey commited on
Commit
6706a30
1 Parent(s): 560d63b

fix recursion gamma

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -265,6 +265,7 @@ def ncut_run(
265
 
266
  if recursion:
267
  rgbs = []
 
268
  inp = features
269
  for i, n_eigs in enumerate([num_eig, recursion_l2_n_eigs, recursion_l3_n_eigs]):
270
  logging_str += f"Recursion #{i+1}\n"
@@ -272,7 +273,7 @@ def ncut_run(
272
  inp,
273
  num_eig=n_eigs,
274
  num_sample_ncut=num_sample_ncut,
275
- affinity_focal_gamma=affinity_focal_gamma,
276
  knn_ncut=knn_ncut,
277
  knn_tsne=knn_tsne,
278
  num_sample_tsne=num_sample_tsne,
@@ -352,9 +353,14 @@ def ncut_run(
352
 
353
  def _ncut_run(*args, **kwargs):
354
  try:
 
 
 
355
  ret = ncut_run(*args, **kwargs)
 
356
  if torch.cuda.is_available():
357
  torch.cuda.empty_cache()
 
358
  return ret
359
  except Exception as e:
360
  gr.Error(str(e))
 
265
 
266
  if recursion:
267
  rgbs = []
268
+ recursion_gammas = [recursion_l1_gamma, recursion_l2_gamma, recursion_l3_gamma]
269
  inp = features
270
  for i, n_eigs in enumerate([num_eig, recursion_l2_n_eigs, recursion_l3_n_eigs]):
271
  logging_str += f"Recursion #{i+1}\n"
 
273
  inp,
274
  num_eig=n_eigs,
275
  num_sample_ncut=num_sample_ncut,
276
+ affinity_focal_gamma=recursion_gammas[i],
277
  knn_ncut=knn_ncut,
278
  knn_tsne=knn_tsne,
279
  num_sample_tsne=num_sample_tsne,
 
353
 
354
  def _ncut_run(*args, **kwargs):
355
  try:
356
+ if torch.cuda.is_available():
357
+ torch.cuda.empty_cache()
358
+
359
  ret = ncut_run(*args, **kwargs)
360
+
361
  if torch.cuda.is_available():
362
  torch.cuda.empty_cache()
363
+
364
  return ret
365
  except Exception as e:
366
  gr.Error(str(e))