Pringled commited on
Commit
2827b8a
·
1 Parent(s): 25d2eb7

Updated app with code for deduplication

Browse files
Files changed (2) hide show
  1. app.py +179 -4
  2. requirements.txt +6 -0
app.py CHANGED
@@ -1,7 +1,182 @@
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import gradio as gr
2
+
3
+ # def greet(name):
4
+ # return "Hello " + name + "!!"
5
+
6
+ # demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
+ # demo.launch()
8
+
9
+
10
  import gradio as gr
11
+ from datasets import load_dataset
12
+ import numpy as np
13
+ from model2vec import StaticModel
14
+ from reach import Reach
15
+ from tqdm import tqdm
16
+
17
+ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[np.ndarray, dict[int, int]]:
18
+ """
19
+ Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
20
+ """
21
+ reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
22
+
23
+ # Use a set for deduplicated indices and keep track of duplicates
24
+ deduplicated_indices = set(range(len(embedding_matrix))) # Start with all indices as deduplicated
25
+ duplicate_to_original_mapping = {}
26
+
27
+ results = reach.nearest_neighbor_threshold(
28
+ embedding_matrix,
29
+ threshold=threshold,
30
+ batch_size=batch_size,
31
+ show_progressbar=True
32
+ )
33
+
34
+ # Process duplicates
35
+ for i, similar_items in enumerate(tqdm(results)):
36
+ if i not in deduplicated_indices:
37
+ continue # Skip already marked duplicates
38
+
39
+ # Similar items are returned as (index, score), we are only interested in the index
40
+ similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
41
+
42
+ # Mark similar documents as duplicates and map them to the original
43
+ for sim_idx in similar_indices:
44
+ if sim_idx in deduplicated_indices:
45
+ deduplicated_indices.remove(sim_idx)
46
+ duplicate_to_original_mapping[sim_idx] = i # Map duplicate to original
47
+
48
+ return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
49
+
50
+ def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[list[int], dict[int, int]]:
51
+ """
52
+ Deduplicate embeddings across two datasets and return the indices of duplicates between them.
53
+ """
54
+ reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
55
+
56
+ # Keep track of duplicates in the second dataset
57
+ duplicate_indices_in_test = []
58
+ duplicate_to_original_mapping = {}
59
+
60
+ # Find nearest neighbors from the test set in the train set
61
+ results = reach.nearest_neighbor_threshold(
62
+ embedding_matrix_2,
63
+ threshold=threshold,
64
+ batch_size=batch_size,
65
+ show_progressbar=True
66
+ )
67
+
68
+ # Process duplicates
69
+ for i, similar_items in enumerate(tqdm(results)):
70
+ # Similar items are returned as (index, score), we are only interested in the index
71
+ similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold] # Keep those above the threshold
72
+
73
+ # If we find a similar item in the train set, mark it as a duplicate
74
+ if similar_indices:
75
+ duplicate_indices_in_test.append(i)
76
+ duplicate_to_original_mapping[i] = similar_indices[0] # Map duplicate in test to original in train
77
+
78
+ return duplicate_indices_in_test, duplicate_to_original_mapping
79
 
80
+ def perform_deduplication(
81
+ deduplication_type,
82
+ dataset1_name,
83
+ dataset1_split,
84
+ dataset2_name,
85
+ dataset2_split,
86
+ threshold
87
+ ):
88
+ # Convert threshold to float
89
+ threshold = float(threshold)
90
+
91
+ if deduplication_type == "Single dataset":
92
+ # Load the dataset
93
+ ds = load_dataset(dataset1_name, split=dataset1_split)
94
+
95
+ # Extract texts
96
+ texts = [example['text'] for example in ds]
97
+
98
+ # Compute embeddings
99
+ model = StaticModel.from_pretrained("minishlab/M2V_base_output")
100
+ embedding_matrix = model.encode(texts, show_progressbar=True)
101
+
102
+ # Deduplicate
103
+ deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold)
104
+
105
+ # Prepare the results
106
+ num_duplicates = len(duplicate_to_original_mapping)
107
+ num_total = len(texts)
108
+ num_deduplicated = len(deduplicated_indices)
109
+
110
+ result_text = f"**Total documents:** {num_total}\n"
111
+ result_text += f"**Number of duplicates found:** {num_duplicates}\n"
112
+ result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
113
+ result_text += f"**Deduplicated indices:** {deduplicated_indices.tolist()}\n\n"
114
+ result_text += f"**Duplicate to original mapping:** {duplicate_to_original_mapping}\n"
115
+
116
+ return result_text
117
+
118
+ elif deduplication_type == "Cross-dataset":
119
+ # Load datasets
120
+ ds1 = load_dataset(dataset1_name, split=dataset1_split)
121
+ ds2 = load_dataset(dataset2_name, split=dataset2_split)
122
+
123
+ # Extract texts
124
+ texts1 = [example['text'] for example in ds1]
125
+ texts2 = [example['text'] for example in ds2]
126
+
127
+ # Compute embeddings
128
+ model = StaticModel.from_pretrained("minishlab/M2V_base_output")
129
+ embedding_matrix1 = model.encode(texts1, show_progressbar=True)
130
+ embedding_matrix2 = model.encode(texts2, show_progressbar=True)
131
+
132
+ # Deduplicate across datasets
133
+ duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold)
134
+
135
+ num_duplicates = len(duplicate_indices_in_ds2)
136
+ num_total_ds2 = len(texts2)
137
+ num_unique_ds2 = num_total_ds2 - num_duplicates
138
+
139
+ result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
140
+ result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
141
+ result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
142
+ result_text += f"**Duplicate indices in {dataset2_name}/{dataset2_split}:** {duplicate_indices_in_ds2}\n\n"
143
+ result_text += f"**Duplicate to original mapping:** {duplicate_to_original_mapping}\n"
144
+
145
+ return result_text
146
 
147
+ with gr.Blocks() as demo:
148
+ gr.Markdown("# Semantic Deduplication")
149
+
150
+ deduplication_type = gr.Radio(choices=["Single dataset", "Cross-dataset"], label="Deduplication Type", value="Single dataset")
151
+
152
+ with gr.Row():
153
+ dataset1_name = gr.Textbox(value="ag_news", label="Dataset 1 Name")
154
+ dataset1_split = gr.Textbox(value="train", label="Dataset 1 Split")
155
+
156
+ dataset2_row = gr.Row(visible=False)
157
+ with dataset2_row:
158
+ dataset2_name = gr.Textbox(value="ag_news", label="Dataset 2 Name")
159
+ dataset2_split = gr.Textbox(value="test", label="Dataset 2 Split")
160
+
161
+ threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, label="Similarity Threshold")
162
+
163
+ compute_button = gr.Button("Compute")
164
+
165
+ output = gr.Markdown()
166
+
167
+ # Function to update the visibility of dataset2_row
168
+ def update_visibility(deduplication_type):
169
+ if deduplication_type == "Cross-dataset":
170
+ return {dataset2_row: gr.update(visible=True)}
171
+ else:
172
+ return {dataset2_row: gr.update(visible=False)}
173
+
174
+ deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=[dataset2_row])
175
+
176
+ compute_button.click(
177
+ fn=perform_deduplication,
178
+ inputs=[deduplication_type, dataset1_name, dataset1_split, dataset2_name, dataset2_split, threshold],
179
+ outputs=output
180
+ )
181
+
182
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ reach < 5
2
+ model2vec
3
+ numpy
4
+ datasets
5
+ tqdm
6
+