huzey commited on
Commit
b7c5735
1 Parent(s): b1e189a

add advanced compare models

Browse files
Files changed (1) hide show
  1. app.py +82 -2
app.py CHANGED
@@ -874,8 +874,12 @@ def load_alignedthreemodel():
874
  model = ThreeAttnNodes(align_weights)
875
 
876
  return model
877
- # pre-load the alignedthree model in case it fails to load
878
- load_alignedthreemodel()
 
 
 
 
879
 
880
  promptable_diffusion_models = ["Diffusion(stabilityai/stable-diffusion-2)", "Diffusion(CompVis/stable-diffusion-v1-4)"]
881
  promptable_segmentation_models = ["LISA(xinlai/LISA-7B-v1)"]
@@ -2029,6 +2033,81 @@ with demo:
2029
  buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
2030
  buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
2031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2032
 
2033
  with gr.Tab('📄About'):
2034
  with gr.Column():
@@ -2087,6 +2166,7 @@ with demo:
2087
  hidden_button.change(unlock_tabs_with_info, n_smiles, tab_alignedcut_advanced)
2088
  hidden_button.change(unlock_tabs, n_smiles, tab_model_aligned_advanced)
2089
  hidden_button.change(unlock_tabs, n_smiles, tab_recursivecut_advanced)
 
2090
 
2091
  with gr.Row():
2092
  with gr.Column():
 
874
  model = ThreeAttnNodes(align_weights)
875
 
876
  return model
877
+
878
+ try:
879
+ # pre-load the alignedthree model in case it fails to load
880
+ load_alignedthreemodel()
881
+ except Exception as e:
882
+ pass
883
 
884
  promptable_diffusion_models = ["Diffusion(stabilityai/stable-diffusion-2)", "Diffusion(CompVis/stable-diffusion-v1-4)"]
885
  promptable_segmentation_models = ["LISA(xinlai/LISA-7B-v1)"]
 
2033
  buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
2034
  buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
2035
 
2036
+ with gr.Tab('Compare Models (Advanced)', visible=False) as tab_compare_models_advanced:
2037
+ def add_one_model(i_model=1):
2038
+ with gr.Column(scale=5, min_width=200) as col:
2039
+ gr.Markdown(f'### Output Images')
2040
+ output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
2041
+ submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
2042
+ add_output_images_buttons(output_gallery)
2043
+ norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
2044
+ cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
2045
+ [
2046
+ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
2047
+ affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
2048
+ embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
2049
+ perplexity_slider, n_neighbors_slider, min_dist_slider,
2050
+ sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
2051
+ ] = make_parameters_section()
2052
+ # logging text box
2053
+ logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
2054
+ false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
2055
+ no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
2056
+
2057
+ submit_button.click(
2058
+ partial(run_fn, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True, advanced=True),
2059
+ inputs=[
2060
+ input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
2061
+ positive_prompt, negative_prompt,
2062
+ false_placeholder, no_prompt, no_prompt, no_prompt,
2063
+ affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
2064
+ embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
2065
+ perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
2066
+ ],
2067
+ outputs=[output_gallery, cluster_gallery, norm_gallery, logging_text]
2068
+ )
2069
+
2070
+ return col
2071
+
2072
+ with gr.Row():
2073
+ with gr.Column(scale=5, min_width=200):
2074
+ input_gallery, submit_button, clear_images_button = make_input_images_section()
2075
+ clear_images_button.click(lambda x: [], outputs=[input_gallery])
2076
+ submit_button.visible = False
2077
+ dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True)
2078
+
2079
+
2080
+ for i in range(2):
2081
+ add_one_model()
2082
+
2083
+ # Create rows and buttons in a loop
2084
+ rows = []
2085
+ buttons = []
2086
+
2087
+ for i in range(4):
2088
+ row = gr.Row(visible=False)
2089
+ rows.append(row)
2090
+
2091
+ with row:
2092
+ for j in range(3):
2093
+ with gr.Column(scale=5, min_width=200):
2094
+ add_one_model()
2095
+
2096
+ button = gr.Button("➕ Add Compare", elem_id=f"add_button_{i}", visible=False if i > 0 else True, scale=3)
2097
+ buttons.append(button)
2098
+
2099
+ if i > 0:
2100
+ # Reveal the current row and next button
2101
+ buttons[i - 1].click(fn=lambda x: gr.update(visible=True), outputs=row)
2102
+ buttons[i - 1].click(fn=lambda x: gr.update(visible=True), outputs=button)
2103
+
2104
+ # Hide the current button
2105
+ buttons[i - 1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[i - 1])
2106
+
2107
+ # Last button only reveals the last row and hides itself
2108
+ buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
2109
+ buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
2110
+
2111
 
2112
  with gr.Tab('📄About'):
2113
  with gr.Column():
 
2166
  hidden_button.change(unlock_tabs_with_info, n_smiles, tab_alignedcut_advanced)
2167
  hidden_button.change(unlock_tabs, n_smiles, tab_model_aligned_advanced)
2168
  hidden_button.change(unlock_tabs, n_smiles, tab_recursivecut_advanced)
2169
+ hidden_button.change(unlock_tabs, n_smiles, tab_compare_models_advanced)
2170
 
2171
  with gr.Row():
2172
  with gr.Column():