Mark7549 commited on
Commit
910ea33
·
1 Parent(s): 6cd05ce

put nearest neighbours function into a form (input process is faster now)

Browse files
Files changed (1) hide show
  1. app.py +41 -52
app.py CHANGED
@@ -24,26 +24,19 @@ lemma_dict = json.load(open('lsj_dict.json', 'r'))
24
 
25
  # Nearest neighbours tab
26
  if active_tab == "Nearest neighbours":
27
- st.write("### TO DO: add description of function")
28
- col1, col2 = st.columns(2)
29
 
30
  # Load the compressed word list
31
  compressed_word_list_filename = 'corpora/compass_filtered.pkl.gz'
32
  all_words = load_compressed_word_list(compressed_word_list_filename)
33
  eligible_models = ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"]
34
 
35
- if 'nearest_neighbours' not in st.session_state:
36
- st.session_state.nearest_neighbours = False
37
-
38
- with st.container():
39
-
40
- word = st.multiselect("Enter a word", all_words, max_selections=1)
41
- if len(word) > 0:
42
- word = word[0]
43
-
44
- # Check which models contain the word
45
- eligible_models = check_word_in_models(word)
46
 
 
47
 
48
  models = st.multiselect(
49
  "Select models to search for neighbours",
@@ -51,49 +44,45 @@ if active_tab == "Nearest neighbours":
51
  )
52
  n = st.slider("Number of neighbours", 1, 50, 15)
53
 
54
- nearest_neighbours_button = st.button("Find nearest neighbours", on_click = click_nn_button)
55
 
56
- # If the button to calculate nearest neighbours is clicked
57
- if st.session_state.nearest_neighbours:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- # Check if all fields are filled in
60
- if validate_nearest_neighbours(word, n, models) == False:
61
- st.error('Please fill in all fields')
62
- else:
63
- # Rewrite models to list of all loaded models
64
- models = load_selected_models(models)
65
-
66
- nearest_neighbours = get_nearest_neighbours(word, n, models)
67
 
68
- all_dfs = []
69
-
70
- # Create dataframes
71
- for model in nearest_neighbours.keys():
72
- st.write(f"### {model}")
73
- df = pd.DataFrame(
74
- nearest_neighbours[model],
75
- columns = ['Word', 'Cosine Similarity']
76
  )
77
-
78
- all_dfs.append((model, df))
79
- st.table(df)
80
-
81
-
82
- # Store content in a temporary file
83
- tmp_file = store_df_in_temp_file(all_dfs)
84
-
85
- # Open the temporary file and read its content
86
- with open(tmp_file, "rb") as file:
87
- file_byte = file.read()
88
-
89
- # Create download button
90
- st.download_button(
91
- "Download results",
92
- data=file_byte,
93
- file_name = f'nearest_neighbours_{word}_TEST.xlsx',
94
- mime='application/octet-stream'
95
- )
96
-
97
 
98
 
99
  # Cosine similarity tab
 
24
 
25
  # Nearest neighbours tab
26
  if active_tab == "Nearest neighbours":
 
 
27
 
28
  # Load the compressed word list
29
  compressed_word_list_filename = 'corpora/compass_filtered.pkl.gz'
30
  all_words = load_compressed_word_list(compressed_word_list_filename)
31
  eligible_models = ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"]
32
 
33
+ with st.form("nn_form"):
34
+ st.markdown("## Nearest Neighbours")
35
+ target_word = st.multiselect("Enter a word", all_words, max_selections=1)
36
+ if len(target_word) > 0:
37
+ target_word = target_word[0]
 
 
 
 
 
 
38
 
39
+ eligible_models = check_word_in_models(target_word)
40
 
41
  models = st.multiselect(
42
  "Select models to search for neighbours",
 
44
  )
45
  n = st.slider("Number of neighbours", 1, 50, 15)
46
 
47
+ nearest_neighbours_button = st.form_submit_button("Find nearest neighbours", on_click = click_nn_button)
48
 
49
+ if nearest_neighbours_button:
50
+ if validate_nearest_neighbours(target_word, n, models) == False:
51
+ st.error('Please fill in all fields')
52
+ else:
53
+ # Rewrite models to list of all loaded models
54
+ models = load_selected_models(models)
55
+
56
+ nearest_neighbours = get_nearest_neighbours(target_word, n, models)
57
+
58
+ all_dfs = []
59
+
60
+ # Create dataframes
61
+ for model in nearest_neighbours.keys():
62
+ st.write(f"### {model}")
63
+ df = pd.DataFrame(
64
+ nearest_neighbours[model],
65
+ columns = ['Word', 'Cosine Similarity']
66
+ )
67
+
68
+ all_dfs.append((model, df))
69
+ st.table(df)
70
 
71
+
72
+ # Store content in a temporary file
73
+ tmp_file = store_df_in_temp_file(all_dfs)
74
+
75
+ # Open the temporary file and read its content
76
+ with open(tmp_file, "rb") as file:
77
+ file_byte = file.read()
 
78
 
79
+ # Create download button
80
+ st.download_button(
81
+ "Download results",
82
+ data=file_byte,
83
+ file_name = f'nearest_neighbours_{target_word}_TEST.xlsx',
84
+ mime='application/octet-stream'
 
 
85
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
 
88
  # Cosine similarity tab