huzey commited on
Commit
d90b17f
1 Parent(s): a5faf16

fix eig norm

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -870,6 +870,7 @@ def ncut_run(
870
  if not directed:
871
  only_eigvecs = kwargs.get("only_eigvecs", False)
872
  return_eigvec_and_rgb = kwargs.get("return_eigvec_and_rgb", False)
 
873
 
874
  rgb, _logging_str, eigvecs = compute_ncut(
875
  features,
@@ -893,12 +894,16 @@ def ncut_run(
893
 
894
 
895
  if only_eigvecs:
 
 
896
  eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
897
  eigvecs = eigvecs.detach().numpy()
898
  logging_str += _logging_str
899
  return eigvecs, logging_str
900
 
901
  if return_eigvec_and_rgb:
 
 
902
  eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
903
  eigvecs = eigvecs.detach().numpy()
904
  rgb = rgb.cpu().numpy()
@@ -1249,6 +1254,7 @@ def run_fn(
1249
  directed=False,
1250
  only_eigvecs=False,
1251
  return_eigvec_and_rgb=False,
 
1252
  ):
1253
  # print(node_type2, head_index_text, make_symmetric)
1254
  progress=gr.Progress()
@@ -1391,6 +1397,7 @@ def run_fn(
1391
  "make_symmetric": make_symmetric,
1392
  "only_eigvecs": only_eigvecs,
1393
  "return_eigvec_and_rgb": return_eigvec_and_rgb,
 
1394
  }
1395
  # print(kwargs)
1396
 
@@ -2232,16 +2239,16 @@ with demo:
2232
  def __run_fn(*args, **kwargs):
2233
  eigvecs, rgb, logging_str = run_fn(*args, **kwargs)
2234
  rgb_gallery = to_pil_images(rgb)
2235
- # normalize the eigvecs
2236
- eigvecs = torch.tensor(eigvecs)
2237
- if torch.cuda.is_available():
2238
- eigvecs = eigvecs.cuda()
2239
- eigvecs = F.normalize(eigvecs, p=2, dim=-1)
2240
- eigvecs = eigvecs.cpu().numpy()
2241
  return eigvecs, rgb, rgb_gallery, logging_str
2242
 
2243
  submit_button.click(
2244
- partial(__run_fn, n_ret=2, return_eigvec_and_rgb=True),
2245
  inputs=[
2246
  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
2247
  positive_prompt, negative_prompt,
 
870
  if not directed:
871
  only_eigvecs = kwargs.get("only_eigvecs", False)
872
  return_eigvec_and_rgb = kwargs.get("return_eigvec_and_rgb", False)
873
+ normalize_eigvec_return = kwargs.get("normalize_eigvec_return", False)
874
 
875
  rgb, _logging_str, eigvecs = compute_ncut(
876
  features,
 
894
 
895
 
896
  if only_eigvecs:
897
+ if normalize_eigvec_return:
898
+ eigvecs = F.normalize(eigvecs, dim=-1)
899
  eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
900
  eigvecs = eigvecs.detach().numpy()
901
  logging_str += _logging_str
902
  return eigvecs, logging_str
903
 
904
  if return_eigvec_and_rgb:
905
+ if normalize_eigvec_return:
906
+ eigvecs = F.normalize(eigvecs, dim=-1)
907
  eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
908
  eigvecs = eigvecs.detach().numpy()
909
  rgb = rgb.cpu().numpy()
 
1254
  directed=False,
1255
  only_eigvecs=False,
1256
  return_eigvec_and_rgb=False,
1257
+ normalize_eigvec_return=False,
1258
  ):
1259
  # print(node_type2, head_index_text, make_symmetric)
1260
  progress=gr.Progress()
 
1397
  "make_symmetric": make_symmetric,
1398
  "only_eigvecs": only_eigvecs,
1399
  "return_eigvec_and_rgb": return_eigvec_and_rgb,
1400
+ "normalize_eigvec_return": normalize_eigvec_return,
1401
  }
1402
  # print(kwargs)
1403
 
 
2239
  def __run_fn(*args, **kwargs):
2240
  eigvecs, rgb, logging_str = run_fn(*args, **kwargs)
2241
  rgb_gallery = to_pil_images(rgb)
2242
+ # # normalize the eigvecs
2243
+ # eigvecs = torch.tensor(eigvecs)
2244
+ # if torch.cuda.is_available():
2245
+ # eigvecs = eigvecs.cuda()
2246
+ # eigvecs = F.normalize(eigvecs, p=2, dim=-1)
2247
+ # eigvecs = eigvecs.cpu().numpy()
2248
  return eigvecs, rgb, rgb_gallery, logging_str
2249
 
2250
  submit_button.click(
2251
+ partial(__run_fn, n_ret=2, return_eigvec_and_rgb=True, normalize_eigvec_return=True),
2252
  inputs=[
2253
  input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
2254
  positive_prompt, negative_prompt,