Spaces:
Running
on
Zero
Running
on
Zero
fix eig norm
Browse files
app.py
CHANGED
@@ -870,6 +870,7 @@ def ncut_run(
|
|
870 |
if not directed:
|
871 |
only_eigvecs = kwargs.get("only_eigvecs", False)
|
872 |
return_eigvec_and_rgb = kwargs.get("return_eigvec_and_rgb", False)
|
|
|
873 |
|
874 |
rgb, _logging_str, eigvecs = compute_ncut(
|
875 |
features,
|
@@ -893,12 +894,16 @@ def ncut_run(
|
|
893 |
|
894 |
|
895 |
if only_eigvecs:
|
|
|
|
|
896 |
eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
|
897 |
eigvecs = eigvecs.detach().numpy()
|
898 |
logging_str += _logging_str
|
899 |
return eigvecs, logging_str
|
900 |
|
901 |
if return_eigvec_and_rgb:
|
|
|
|
|
902 |
eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
|
903 |
eigvecs = eigvecs.detach().numpy()
|
904 |
rgb = rgb.cpu().numpy()
|
@@ -1249,6 +1254,7 @@ def run_fn(
|
|
1249 |
directed=False,
|
1250 |
only_eigvecs=False,
|
1251 |
return_eigvec_and_rgb=False,
|
|
|
1252 |
):
|
1253 |
# print(node_type2, head_index_text, make_symmetric)
|
1254 |
progress=gr.Progress()
|
@@ -1391,6 +1397,7 @@ def run_fn(
|
|
1391 |
"make_symmetric": make_symmetric,
|
1392 |
"only_eigvecs": only_eigvecs,
|
1393 |
"return_eigvec_and_rgb": return_eigvec_and_rgb,
|
|
|
1394 |
}
|
1395 |
# print(kwargs)
|
1396 |
|
@@ -2232,16 +2239,16 @@ with demo:
|
|
2232 |
def __run_fn(*args, **kwargs):
|
2233 |
eigvecs, rgb, logging_str = run_fn(*args, **kwargs)
|
2234 |
rgb_gallery = to_pil_images(rgb)
|
2235 |
-
# normalize the eigvecs
|
2236 |
-
eigvecs = torch.tensor(eigvecs)
|
2237 |
-
if torch.cuda.is_available():
|
2238 |
-
|
2239 |
-
eigvecs = F.normalize(eigvecs, p=2, dim=-1)
|
2240 |
-
eigvecs = eigvecs.cpu().numpy()
|
2241 |
return eigvecs, rgb, rgb_gallery, logging_str
|
2242 |
|
2243 |
submit_button.click(
|
2244 |
-
partial(__run_fn, n_ret=2, return_eigvec_and_rgb=True),
|
2245 |
inputs=[
|
2246 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
2247 |
positive_prompt, negative_prompt,
|
|
|
870 |
if not directed:
|
871 |
only_eigvecs = kwargs.get("only_eigvecs", False)
|
872 |
return_eigvec_and_rgb = kwargs.get("return_eigvec_and_rgb", False)
|
873 |
+
normalize_eigvec_return = kwargs.get("normalize_eigvec_return", False)
|
874 |
|
875 |
rgb, _logging_str, eigvecs = compute_ncut(
|
876 |
features,
|
|
|
894 |
|
895 |
|
896 |
if only_eigvecs:
|
897 |
+
if normalize_eigvec_return:
|
898 |
+
eigvecs = F.normalize(eigvecs, dim=-1)
|
899 |
eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
|
900 |
eigvecs = eigvecs.detach().numpy()
|
901 |
logging_str += _logging_str
|
902 |
return eigvecs, logging_str
|
903 |
|
904 |
if return_eigvec_and_rgb:
|
905 |
+
if normalize_eigvec_return:
|
906 |
+
eigvecs = F.normalize(eigvecs, dim=-1)
|
907 |
eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
|
908 |
eigvecs = eigvecs.detach().numpy()
|
909 |
rgb = rgb.cpu().numpy()
|
|
|
1254 |
directed=False,
|
1255 |
only_eigvecs=False,
|
1256 |
return_eigvec_and_rgb=False,
|
1257 |
+
normalize_eigvec_return=False,
|
1258 |
):
|
1259 |
# print(node_type2, head_index_text, make_symmetric)
|
1260 |
progress=gr.Progress()
|
|
|
1397 |
"make_symmetric": make_symmetric,
|
1398 |
"only_eigvecs": only_eigvecs,
|
1399 |
"return_eigvec_and_rgb": return_eigvec_and_rgb,
|
1400 |
+
"normalize_eigvec_return": normalize_eigvec_return,
|
1401 |
}
|
1402 |
# print(kwargs)
|
1403 |
|
|
|
2239 |
def __run_fn(*args, **kwargs):
|
2240 |
eigvecs, rgb, logging_str = run_fn(*args, **kwargs)
|
2241 |
rgb_gallery = to_pil_images(rgb)
|
2242 |
+
# # normalize the eigvecs
|
2243 |
+
# eigvecs = torch.tensor(eigvecs)
|
2244 |
+
# if torch.cuda.is_available():
|
2245 |
+
# eigvecs = eigvecs.cuda()
|
2246 |
+
# eigvecs = F.normalize(eigvecs, p=2, dim=-1)
|
2247 |
+
# eigvecs = eigvecs.cpu().numpy()
|
2248 |
return eigvecs, rgb, rgb_gallery, logging_str
|
2249 |
|
2250 |
submit_button.click(
|
2251 |
+
partial(__run_fn, n_ret=2, return_eigvec_and_rgb=True, normalize_eigvec_return=True),
|
2252 |
inputs=[
|
2253 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
2254 |
positive_prompt, negative_prompt,
|