bhlewis commited on
Commit
8e3cac5
·
verified ·
1 Parent(s): 61b33b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -187
app.py CHANGED
@@ -1,229 +1,205 @@
1
  import gradio as gr
2
- import pandas as pd
3
  import numpy as np
4
  import h5py
 
5
  import json
6
- import os
7
- import tempfile
 
8
  import re
9
- import time
10
- import logging
11
- from sentence_transformers import SentenceTransformer
12
  from nltk.corpus import stopwords
13
  from nltk.tokenize import word_tokenize
14
  import nltk
15
- import torch
16
- from sklearn.feature_extraction.text import CountVectorizer
17
 
18
- # Set up logging
19
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
-
21
- # Ensure you have downloaded the necessary NLTK data
22
  nltk.download('stopwords', quiet=True)
23
  nltk.download('punkt', quiet=True)
24
 
25
- # Disable tokenizer parallelism warning
26
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
27
 
28
- # Check for GPU availability
29
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
30
 
31
- # Load pre-trained model from Hugging Face
32
- logging.info("Loading SentenceTransformer model...")
33
- model = SentenceTransformer('anferico/bert-for-patents').to(device)
34
- logging.info("SentenceTransformer model loaded successfully.")
 
 
 
 
 
35
 
36
- def preprocess_text(text):
37
- # Remove "[EN]" label and claim numbers
38
- text = re.sub(r'\[EN\]\s*', '', text)
39
- text = re.sub(r'^\d+\.\s*', '', text, flags=re.MULTILINE)
40
-
41
- # Convert to lowercase while preserving acronyms and units
42
- words = text.split()
43
- text = ' '.join(word if word.isupper() or re.match(r'^\d+(\.\d+)?[a-zA-Z]+$', word) else word.lower() for word in words)
44
 
45
- # Remove special characters except hyphens and periods in numbers
46
- text = re.sub(r'[^\w\s\-.]', ' ', text)
47
- text = re.sub(r'(?<!\d)\.(?!\d)', ' ', text) # Remove periods not in numbers
48
 
49
- # Normalize spaces
50
- text = re.sub(r'\s+', ' ', text).strip()
51
 
52
  # Tokenize
53
  tokens = word_tokenize(text)
54
 
55
- # Remove stopwords
56
  stop_words = set(stopwords.words('english'))
57
- tokens = [word for word in tokens if word.lower() not in stop_words]
58
 
59
- # Join tokens back into text
60
- text = ' '.join(tokens)
61
 
62
- # Preserve numerical values with units
63
- text = re.sub(r'(\d+(\.\d+)?)([a-zA-Z]+)', r'\1_\3', text)
64
 
65
- # Handle ranges and measurements
66
- text = re.sub(r'(\d+(\.\d+)?)(\s*to\s*)(\d+(\.\d+)?)(\s*[a-zA-Z]+)', r'\1_to_\4_\6', text)
67
- text = re.sub(r'between\s*(\d+(\.\d+)?)(\s*and\s*)(\d+(\.\d+)?)\s*([a-zA-Z]+)', r'between_\1_and_\4_\5', text)
68
-
69
- # Preserve chemical formulas
70
- text = re.sub(r'\b([A-Z][a-z]?\d*)+\b', lambda m: m.group().replace(' ', ''), text)
71
-
72
- return text
73
 
74
- def filter_common_terms(texts, threshold=0.10):
75
- vectorizer = CountVectorizer()
76
- X = vectorizer.fit_transform(texts)
77
- term_frequencies = np.sum(X.toarray(), axis=0)
78
- document_frequencies = np.sum(X.toarray() > 0, axis=0)
79
- num_documents = X.shape[0]
80
 
81
- common_terms = set()
82
- removed_words = {}
83
- for term, doc_freq in zip(vectorizer.get_feature_names_out(), document_frequencies):
84
- if doc_freq / num_documents > threshold:
85
- common_terms.add(term)
86
- removed_words[term] = doc_freq
87
 
88
- filtered_texts = []
89
- for text in texts:
90
- filtered_text = ' '.join([word for word in text.split() if word not in common_terms])
91
- filtered_texts.append(filtered_text)
92
 
93
- return filtered_texts, removed_words
94
 
95
- def encode_texts(texts, progress=gr.Progress(), batch_size=64):
96
- embeddings = []
97
- total_batches = len(texts) // batch_size + (1 if len(texts) % batch_size != 0 else 0)
98
-
99
- for i in range(0, len(texts), batch_size):
100
- batch_texts = texts[i:i+batch_size]
101
- batch_texts = [str(text) for text in batch_texts]
102
- batch_embeddings = model.encode(batch_texts, show_progress_bar=True)
103
- embeddings.extend(batch_embeddings)
104
- progress((i // batch_size + 1) / total_batches, f"Processing batch {i // batch_size + 1}/{total_batches}")
105
-
106
- embeddings = np.array(embeddings)
107
- embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
108
- return embeddings
109
 
110
- def process_file(file, progress=gr.Progress()):
111
  try:
112
- start_time = time.time()
113
-
114
- # Read CSV file
115
- df = pd.read_csv(file.name, encoding='utf-8')
116
- logging.info(f"CSV file read successfully. Shape: {df.shape}")
117
-
118
- required_columns = ['Master Patent Number', 'Abstract', 'Claims']
119
- missing_columns = [col for col in required_columns if col not in df.columns]
120
- if missing_columns:
121
- return None, None, None, f"Error: Missing columns: {', '.join(missing_columns)}"
122
-
123
- valid_texts = []
124
- valid_patent_numbers = []
125
- skipped_rows = []
126
- error_rows = []
127
- total_rows = len(df)
128
-
129
- for index, row in df.iterrows():
130
- try:
131
- progress((index + 1) / total_rows, f"Processing row {index + 1}/{total_rows}")
132
- logging.info(f"Processing row {index + 1}/{total_rows}")
133
-
134
- abstract = row['Abstract'] if pd.notna(row['Abstract']) else ''
135
- claims = row['Claims'] if pd.notna(row['Claims']) else ''
136
-
137
- if not abstract and not claims:
138
- skipped_rows.append(row['Master Patent Number'])
139
- continue
140
-
141
- # Preprocess the abstract and claims separately
142
- preprocessed_abstract = preprocess_text(abstract)
143
- preprocessed_claims = preprocess_text(claims)
144
-
145
- # Combine preprocessed abstract and claims
146
- combined_text = preprocessed_abstract + ' ' + preprocessed_claims
147
-
148
- valid_texts.append(combined_text)
149
- valid_patent_numbers.append(str(row['Master Patent Number']))
150
-
151
- except Exception as e:
152
- error_message = f"Error processing row {index + 1}: {str(e)}"
153
- logging.error(error_message)
154
- error_rows.append((index, row['Master Patent Number'], error_message))
155
- continue
156
-
157
- logging.info(f"Preprocessed abstracts and claims. Number of valid texts: {len(valid_texts)}")
158
-
159
- if skipped_rows:
160
- logging.info(f"Skipped {len(skipped_rows)} rows due to missing abstract and claims.")
161
- if error_rows:
162
- logging.info(f"Encountered errors in {len(error_rows)} rows.")
163
 
164
- # Filter out common terms
165
- logging.info("Filtering common terms...")
166
- filtered_texts, removed_words = filter_common_terms(valid_texts, threshold=0.10)
 
 
 
 
167
 
168
- # Generate removed words file
169
- removed_words_file = 'removed_words.txt'
170
- with open(removed_words_file, 'w', encoding='utf-8') as f:
171
- for word, count in sorted(removed_words.items(), key=lambda x: x[1], reverse=True):
172
- f.write(f"{word}: {count}\n")
173
 
174
- logging.info("Encoding texts...")
175
- embeddings = encode_texts(filtered_texts, progress)
176
- logging.info("Texts encoded successfully.")
177
-
178
- # Save embeddings and metadata
179
- embeddings_file = tempfile.NamedTemporaryFile(delete=False, suffix='.h5').name
180
- with h5py.File(embeddings_file, 'w') as f:
181
- f.create_dataset('embeddings', data=embeddings)
182
- f.create_dataset('patent_numbers', data=valid_patent_numbers)
183
-
184
- metadata_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jsonl').name
185
- with open(metadata_file, 'w', encoding='utf-8') as f:
186
- for index, (patent_number, text) in enumerate(zip(valid_patent_numbers, filtered_texts)):
187
- json.dump({
188
- 'index': index,
189
- 'patent_number': patent_number,
190
- 'text': text,
191
- 'embedding_index': index
192
- }, f, ensure_ascii=False)
193
- f.write('\n')
194
-
195
- end_time = time.time()
196
- total_time = end_time - start_time
197
- logging.info(f"Processing completed in {total_time:.2f} seconds.")
198
-
199
- # Save error log
200
- error_log_file = 'error_log.txt'
201
- with open(error_log_file, 'w', encoding='utf-8') as f:
202
- for row in error_rows:
203
- f.write(f"Row {row[0]}, Patent {row[1]}: {row[2]}\n")
204
-
205
- return embeddings_file, metadata_file, removed_words_file, f"Processing complete. Encoded {len(filtered_texts)} patents. Skipped {len(skipped_rows)} patents due to missing data. Errors in {len(error_rows)} rows. See error_log.txt for details."
206
-
207
  except Exception as e:
208
- logging.error(f"An error occurred: {e}")
209
- import traceback
210
- traceback.print_exc()
211
- return None, None, None, f"An error occurred: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
 
213
  iface = gr.Interface(
214
- fn=process_file,
215
- inputs=gr.File(label="Upload a CSV file with patent data"),
216
- outputs=[
217
- gr.File(label="Patent Embeddings (HDF5)"),
218
- gr.File(label="Patent Metadata (JSONL)"),
219
- gr.File(label="Removed Words List (TXT)"),
220
- gr.Textbox(label="Processing Status")
221
  ],
222
- title="Patent Text Encoder",
223
- description="Upload a CSV file containing patent data (must include 'Master Patent Number', 'Abstract', and 'Claims' columns). The app will generate embeddings and save them along with metadata as downloadable files.",
224
- allow_flagging="never",
225
- cache_examples=False,
226
  )
227
 
228
  if __name__ == "__main__":
229
- iface.launch()
 
1
  import gradio as gr
 
2
  import numpy as np
3
  import h5py
4
+ import faiss
5
  import json
6
+ from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
7
+ from sklearn.feature_extraction.text import TfidfVectorizer
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
  import re
10
+ from collections import Counter
11
+ import torch
 
12
  from nltk.corpus import stopwords
13
  from nltk.tokenize import word_tokenize
14
  import nltk
 
 
15
 
16
+ # Download necessary NLTK data
 
 
 
17
  nltk.download('stopwords', quiet=True)
18
  nltk.download('punkt', quiet=True)
19
 
20
+ # Load BERT model for lemmatization
21
+ bert_lemma_model_name = "bert-base-uncased"
22
+ bert_lemma_tokenizer = AutoTokenizer.from_pretrained(bert_lemma_model_name)
23
+ bert_lemma_model = AutoModelForMaskedLM.from_pretrained(bert_lemma_model_name).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
24
 
25
+ # Load BERT model for encoding search queries
26
+ bert_encode_model_name = 'anferico/bert-for-patents'
27
+ bert_encode_tokenizer = AutoTokenizer.from_pretrained(bert_encode_model_name)
28
+ bert_encode_model = AutoModel.from_pretrained(bert_encode_model_name)
29
 
30
+ def bert_lemmatize(text):
31
+ tokens = bert_lemma_tokenizer.tokenize(text)
32
+ input_ids = bert_lemma_tokenizer.convert_tokens_to_ids(tokens)
33
+ input_tensor = torch.tensor([input_ids]).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
34
+ with torch.no_grad():
35
+ outputs = bert_lemma_model(input_tensor)
36
+ predictions = outputs.logits.argmax(dim=-1)
37
+ lemmatized_tokens = bert_lemma_tokenizer.convert_ids_to_tokens(predictions[0])
38
+ return ' '.join([token for token in lemmatized_tokens if token not in ['[CLS]', '[SEP]', '[PAD]']])
39
 
40
+ def preprocess_query(text):
41
+ # Convert to lowercase
42
+ text = text.lower()
 
 
 
 
 
43
 
44
+ # Remove any HTML tags (if present)
45
+ text = re.sub('<.*?>', '', text)
 
46
 
47
+ # Remove special characters, but keep hyphens, periods, and commas
48
+ text = re.sub(r'[^a-zA-Z0-9\s\-\.\,]', '', text)
49
 
50
  # Tokenize
51
  tokens = word_tokenize(text)
52
 
53
+ # Remove stopwords, but keep all other words
54
  stop_words = set(stopwords.words('english'))
55
+ tokens = [word for word in tokens if word not in stop_words]
56
 
57
+ # Join tokens back into a string
58
+ processed_text = ' '.join(tokens)
59
 
60
+ # Apply BERT lemmatization
61
+ processed_text = bert_lemmatize(processed_text)
62
 
63
+ return processed_text
 
 
 
 
 
 
 
64
 
65
+ def extract_key_features(text):
66
+ # For queries, we'll just preprocess and return all non-stopword terms
67
+ processed_text = preprocess_query(text)
 
 
 
68
 
69
+ # Split the processed text into individual terms
70
+ features = processed_text.split()
 
 
 
 
71
 
72
+ # Remove duplicates while preserving order
73
+ features = list(dict.fromkeys(features))
 
 
74
 
75
+ return features
76
 
77
+ def encode_texts(texts, max_length=512):
78
+ inputs = bert_encode_tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors='pt')
79
+ with torch.no_grad():
80
+ outputs = bert_encode_model(**inputs)
81
+ embeddings = outputs.last_hidden_state.mean(dim=1)
82
+ return embeddings.numpy()
 
 
 
 
 
 
 
 
83
 
84
+ def load_data():
85
  try:
86
+ with h5py.File('patent_embeddings.h5', 'r') as f:
87
+ embeddings = f['embeddings'][:]
88
+ patent_numbers = f['patent_numbers'][:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ metadata = {}
91
+ texts = []
92
+ with open('patent_metadata.jsonl', 'r') as f:
93
+ for line in f:
94
+ data = json.loads(line)
95
+ metadata[data['patent_number']] = data
96
+ texts.append(data['text'])
97
 
98
+ print(f"Embedding shape: {embeddings.shape}")
99
+ print(f"Number of patent numbers: {len(patent_numbers)}")
100
+ print(f"Number of metadata entries: {len(metadata)}")
 
 
101
 
102
+ return embeddings, patent_numbers, metadata, texts
103
+ except FileNotFoundError as e:
104
+ print(f"Error: Could not find file. {e}")
105
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  except Exception as e:
107
+ print(f"An unexpected error occurred while loading data: {e}")
108
+ raise
109
+
110
+ def compare_features(query_features, patent_features):
111
+ common_features = set(query_features) & set(patent_features)
112
+ similarity_score = len(common_features) / max(len(query_features), len(patent_features))
113
+ return common_features, similarity_score
114
+
115
+ def hybrid_search(query, top_k=5):
116
+ print(f"Original query: {query}")
117
+
118
+ processed_query = preprocess_query(query)
119
+ query_features = extract_key_features(processed_query)
120
+
121
+ # Encode the processed query using the transformer model
122
+ query_embedding = encode_texts([processed_query])[0]
123
+ query_embedding = query_embedding / np.linalg.norm(query_embedding)
124
+
125
+ # Perform semantic similarity search
126
+ semantic_distances, semantic_indices = index.search(np.array([query_embedding]).astype('float32'), top_k * 2)
127
+
128
+ # Perform TF-IDF based search
129
+ query_tfidf = tfidf_vectorizer.transform([processed_query])
130
+ tfidf_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
131
+ tfidf_indices = tfidf_similarities.argsort()[-top_k * 2:][::-1]
132
+
133
+ # Combine and rank results
134
+ combined_results = {}
135
+ for i, idx in enumerate(semantic_indices[0]):
136
+ patent_number = patent_numbers[idx].decode('utf-8')
137
+ text = metadata[patent_number]['text']
138
+ patent_features = extract_key_features(text)
139
+ common_features, feature_similarity = compare_features(query_features, patent_features)
140
+ combined_results[patent_number] = {
141
+ 'score': semantic_distances[0][i] * 1.0 + tfidf_similarities[idx] * 0.5 + feature_similarity,
142
+ 'common_features': common_features,
143
+ 'text': text
144
+ }
145
+
146
+ for idx in tfidf_indices:
147
+ patent_number = patent_numbers[idx].decode('utf-8')
148
+ if patent_number not in combined_results:
149
+ text = metadata[patent_number]['text']
150
+ patent_features = extract_key_features(text)
151
+ common_features, feature_similarity = compare_features(query_features, patent_features)
152
+ combined_results[patent_number] = {
153
+ 'score': tfidf_similarities[idx] * 1.0 + feature_similarity,
154
+ 'common_features': common_features,
155
+ 'text': text
156
+ }
157
+
158
+ # Sort and get top results
159
+ top_results = sorted(combined_results.items(), key=lambda x: x[1]['score'], reverse=True)[:top_k]
160
+
161
+ results = []
162
+ for patent_number, data in top_results:
163
+ result = f"Patent Number: {patent_number}\n"
164
+ result += f"Text: {data['text'][:200]}...\n"
165
+ result += f"Combined Score: {data['score']:.4f}\n"
166
+ result += f"Common Key Features: {', '.join(data['common_features'])}\n\n"
167
+ results.append(result)
168
+
169
+ return "\n".join(results)
170
+
171
+ # Load data and prepare the FAISS index
172
+ embeddings, patent_numbers, metadata, texts = load_data()
173
+
174
+ # Check if the embedding dimensions match
175
+ if embeddings.shape[1] != encode_texts(["test"]).shape[1]:
176
+ print("Embedding dimensions do not match. Rebuilding FAISS index.")
177
+ # Rebuild embeddings using the new model
178
+ embeddings = encode_texts(texts)
179
+ embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
180
+
181
+ # Normalize embeddings for cosine similarity
182
+ embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
183
+
184
+ # Create FAISS index for cosine similarity
185
+ index = faiss.IndexFlatIP(embeddings.shape[1])
186
+ index.add(embeddings)
187
+
188
+ # Create TF-IDF vectorizer
189
+ tfidf_vectorizer = TfidfVectorizer(stop_words='english')
190
+ tfidf_matrix = tfidf_vectorizer.fit_transform(texts)
191
 
192
+ # Create Gradio interface with additional input fields
193
  iface = gr.Interface(
194
+ fn=hybrid_search,
195
+ inputs=[
196
+ gr.Textbox(lines=2, placeholder="Enter your patent query here..."),
197
+ gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Top K Results"),
 
 
 
198
  ],
199
+ outputs=gr.Textbox(lines=10, label="Search Results"),
200
+ title="Patent Similarity Search",
201
+ description="Enter a patent description to find similar patents based on key features."
 
202
  )
203
 
204
  if __name__ == "__main__":
205
+ iface.launch()