bhlewis committed
Commit a1d94cc · verified · 1 Parent(s): ed780d1

Update app.py

Files changed (1)
  1. app.py +47 -70
app.py CHANGED
@@ -3,90 +3,80 @@ import numpy as np
  import h5py
  import faiss
  import json
- from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
- from sklearn.feature_extraction.text import TfidfVectorizer
- from sklearn.metrics.pairwise import cosine_similarity
  import re
  from collections import Counter
  import torch
  from nltk.corpus import stopwords
  from nltk.tokenize import word_tokenize
  import nltk
+ from sentence_transformers import SentenceTransformer
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity

  # Download necessary NLTK data
  nltk.download('stopwords', quiet=True)
  nltk.download('punkt', quiet=True)

- # Load BERT model for lemmatization
- bert_lemma_model_name = "bert-base-uncased"
- bert_lemma_tokenizer = AutoTokenizer.from_pretrained(bert_lemma_model_name)
- bert_lemma_model = AutoModelForMaskedLM.from_pretrained(bert_lemma_model_name).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-
- # Load BERT model for encoding search queries
- bert_encode_model_name = 'anferico/bert-for-patents'
- bert_encode_tokenizer = AutoTokenizer.from_pretrained(bert_encode_model_name)
- bert_encode_model = AutoModel.from_pretrained(bert_encode_model_name)
-
- def bert_lemmatize(text):
-     tokens = bert_lemma_tokenizer.tokenize(text)
-     input_ids = bert_lemma_tokenizer.convert_tokens_to_ids(tokens)
-     input_tensor = torch.tensor([input_ids]).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-     with torch.no_grad():
-         outputs = bert_lemma_model(input_tensor)
-     predictions = outputs.logits.argmax(dim=-1)
-     lemmatized_tokens = bert_lemma_tokenizer.convert_ids_to_tokens(predictions[0])
-     return ' '.join([token for token in lemmatized_tokens if token not in ['[CLS]', '[SEP]', '[PAD]']])
+ # Load SentenceTransformer model
+ model = SentenceTransformer('anferico/bert-for-patents')

  def preprocess_query(text):
-     # Convert to lowercase
-     text = text.lower()
+     # Remove "[EN]" label and claim numbers
+     text = re.sub(r'\[EN\]\s*', '', text)
+     text = re.sub(r'^\d+\.\s*', '', text, flags=re.MULTILINE)

-     # Remove any HTML tags (if present)
-     text = re.sub('<.*?>', '', text)
+     # Convert to lowercase while preserving acronyms and units
+     words = text.split()
+     text = ' '.join(word if word.isupper() or re.match(r'^\d+(\.\d+)?[a-zA-Z]+$', word) else word.lower() for word in words)

-     # Remove special characters, but keep hyphens, periods, and commas
-     text = re.sub(r'[^a-zA-Z0-9\s\-\.\,]', '', text)
+     # Remove special characters except hyphens and periods in numbers
+     text = re.sub(r'[^\w\s\-.]', ' ', text)
+     text = re.sub(r'(?<!\d)\.(?!\d)', ' ', text)  # Remove periods not in numbers
+
+     # Normalize spaces
+     text = re.sub(r'\s+', ' ', text).strip()

      # Tokenize
      tokens = word_tokenize(text)

-     # Remove stopwords, but keep all other words
+     # Remove stopwords
      stop_words = set(stopwords.words('english'))
-     tokens = [word for word in tokens if word not in stop_words]
+     tokens = [word for word in tokens if word.lower() not in stop_words]
+
+     # Join tokens back into text
+     text = ' '.join(tokens)

-     # Join tokens back into a string
-     processed_text = ' '.join(tokens)
+     # Preserve numerical values with units
+     text = re.sub(r'(\d+(\.\d+)?)([a-zA-Z]+)', r'\1_\3', text)

-     # Apply BERT lemmatization
-     processed_text = bert_lemmatize(processed_text)
+     # Handle ranges and measurements
+     text = re.sub(r'(\d+(\.\d+)?)(\s*to\s*)(\d+(\.\d+)?)(\s*[a-zA-Z]+)', r'\1_to_\4_\6', text)
+     text = re.sub(r'between\s*(\d+(\.\d+)?)(\s*and\s*)(\d+(\.\d+)?)\s*([a-zA-Z]+)', r'between_\1_and_\4_\5', text)

-     return processed_text
+     # Preserve chemical formulas
+     text = re.sub(r'\b([A-Z][a-z]?\d*)+\b', lambda m: m.group().replace(' ', ''), text)
+
+     return text

  def extract_key_features(text):
      # For queries, we'll just preprocess and return all non-stopword terms
      processed_text = preprocess_query(text)
-
      # Split the processed text into individual terms
      features = processed_text.split()
-
      # Remove duplicates while preserving order
      features = list(dict.fromkeys(features))
-
      return features

- def encode_texts(texts, max_length=512):
-     inputs = bert_encode_tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors='pt')
-     with torch.no_grad():
-         outputs = bert_encode_model(**inputs)
-     embeddings = outputs.last_hidden_state.mean(dim=1)
-     return embeddings.numpy()
+ def encode_texts(texts):
+     embeddings = model.encode(texts, show_progress_bar=True)
+     return embeddings

  def load_data():
      try:
          with h5py.File('patent_embeddings.h5', 'r') as f:
              embeddings = f['embeddings'][:]
              patent_numbers = f['patent_numbers'][:]
-
+
          metadata = {}
          texts = []
          with open('patent_metadata.jsonl', 'r') as f:
@@ -94,17 +84,13 @@ def load_data():
                  data = json.loads(line)
                  metadata[data['patent_number']] = data
                  texts.append(data['text'])
-
+
          print(f"Embedding shape: {embeddings.shape}")
          print(f"Number of patent numbers: {len(patent_numbers)}")
          print(f"Number of metadata entries: {len(metadata)}")
-
          return embeddings, patent_numbers, metadata, texts
-     except FileNotFoundError as e:
-         print(f"Error: Could not find file. {e}")
-         raise
      except Exception as e:
-         print(f"An unexpected error occurred while loading data: {e}")
+         print(f"An error occurred while loading data: {e}")
          raise

  def compare_features(query_features, patent_features):
@@ -114,22 +100,21 @@ def compare_features(query_features, patent_features):

  def hybrid_search(query, top_k=5):
      print(f"Original query: {query}")
-
      processed_query = preprocess_query(query)
      query_features = extract_key_features(processed_query)
-
-     # Encode the processed query using the transformer model
+
+     # Encode the processed query using the SentenceTransformer model
      query_embedding = encode_texts([processed_query])[0]
      query_embedding = query_embedding / np.linalg.norm(query_embedding)
-
+
      # Perform semantic similarity search
      semantic_distances, semantic_indices = index.search(np.array([query_embedding]).astype('float32'), top_k * 2)
-
+
      # Perform TF-IDF based search
      query_tfidf = tfidf_vectorizer.transform([processed_query])
      tfidf_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
      tfidf_indices = tfidf_similarities.argsort()[-top_k * 2:][::-1]
-
+
      # Combine and rank results
      combined_results = {}
      for i, idx in enumerate(semantic_indices[0]):
@@ -142,7 +127,7 @@ def hybrid_search(query, top_k=5):
              'common_features': common_features,
              'text': text
          }
-
+
      for idx in tfidf_indices:
          patent_number = patent_numbers[idx].decode('utf-8')
          if patent_number not in combined_results:
@@ -154,10 +139,9 @@ def hybrid_search(query, top_k=5):
                  'common_features': common_features,
                  'text': text
              }
-
+
      # Sort and get top results
      top_results = sorted(combined_results.items(), key=lambda x: x[1]['score'], reverse=True)[:top_k]
-
      results = []
      for patent_number, data in top_results:
          result = f"Patent Number: {patent_number}\n"
@@ -165,19 +149,12 @@ def hybrid_search(query, top_k=5):
          result += f"Combined Score: {data['score']:.4f}\n"
          result += f"Common Key Features: {', '.join(data['common_features'])}\n\n"
          results.append(result)
-
+
      return "\n".join(results)

  # Load data and prepare the FAISS index
  embeddings, patent_numbers, metadata, texts = load_data()

- # Check if the embedding dimensions match
- if embeddings.shape[1] != encode_texts(["test"]).shape[1]:
-     print("Embedding dimensions do not match. Rebuilding FAISS index.")
-     # Rebuild embeddings using the new model
-     embeddings = encode_texts(texts)
-     embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
-
  # Normalize embeddings for cosine similarity
  embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

@@ -189,7 +166,7 @@ index.add(embeddings)
  tfidf_vectorizer = TfidfVectorizer(stop_words='english')
  tfidf_matrix = tfidf_vectorizer.fit_transform(texts)

- # Create Gradio interface with additional input fields
+ # Create Gradio interface
  iface = gr.Interface(
      fn=hybrid_search,
      inputs=[
@@ -202,4 +179,4 @@ iface = gr.Interface(
  )

  if __name__ == "__main__":
-     iface.launch()
+     iface.launch()
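
For reference, a minimal, self-contained sketch of the SentenceTransformer + FAISS inner-product pattern that the new encode_texts and the normalized-embedding search rely on. Only the model name 'anferico/bert-for-patents' and the normalize-then-inner-product approach come from app.py; the toy corpus, the query string, and the faiss.IndexFlatIP index type are illustrative assumptions (the index construction in app.py is outside this diff).

# Illustrative sketch only (not part of app.py): encode a tiny made-up corpus,
# L2-normalize, and search with inner product so scores equal cosine similarity.
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('anferico/bert-for-patents')

corpus = [
    "A lithium-ion battery electrode comprising a silicon-carbon composite.",
    "A method for wireless charging of an implantable medical device.",
]

# Encode and normalize the corpus; normalized vectors make inner product == cosine similarity
corpus_emb = model.encode(corpus)
corpus_emb = corpus_emb / np.linalg.norm(corpus_emb, axis=1, keepdims=True)

index = faiss.IndexFlatIP(corpus_emb.shape[1])  # assumed index type, not taken from the diff
index.add(corpus_emb.astype('float32'))

# Encode and normalize a sample query, then retrieve the 2 nearest patents
query_emb = model.encode(["silicon anode for a battery"])[0]
query_emb = query_emb / np.linalg.norm(query_emb)
scores, ids = index.search(np.array([query_emb], dtype='float32'), 2)
print(ids[0], scores[0])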