Update app.py
app.py
CHANGED
@@ -3,26 +3,82 @@ import numpy as np
 import h5py
 import faiss
 import json
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
 import re
 from collections import Counter
-import spacy
 import torch
-from nltk.corpus import …
+from nltk.corpus import stopwords
+from nltk.tokenize import word_tokenize
 import nltk

-# Download …
-nltk.download('…
+# Download necessary NLTK data
+nltk.download('stopwords', quiet=True)
+nltk.download('punkt', quiet=True)

-# Load …
-…
-…
-…
-…
-…
-…
+# Load BERT model for lemmatization
+bert_model_name = "bert-base-uncased"
+bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
+bert_model = AutoModelForMaskedLM.from_pretrained(bert_model_name).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+
+# Load BERT model for encoding search queries
+tokenizer = AutoTokenizer.from_pretrained('anferico/bert-for-patents')
+bert_model = AutoModel.from_pretrained('anferico/bert-for-patents')
+
+def bert_lemmatize(text):
+    tokens = bert_tokenizer.tokenize(text)
+    input_ids = bert_tokenizer.convert_tokens_to_ids(tokens)
+    input_tensor = torch.tensor([input_ids]).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+    with torch.no_grad():
+        outputs = bert_model(input_tensor)
+    predictions = outputs.logits.argmax(dim=-1)
+    lemmatized_tokens = bert_tokenizer.convert_ids_to_tokens(predictions[0])
+    return ' '.join([token for token in lemmatized_tokens if token not in ['[CLS]', '[SEP]', '[PAD]']])
+
+def preprocess_query(text):
+    # Convert to lowercase
+    text = text.lower()
+
+    # Remove any HTML tags (if present)
+    text = re.sub('<.*?>', '', text)
+
+    # Remove special characters, but keep hyphens, periods, and commas
+    text = re.sub(r'[^a-zA-Z0-9\s\-\.\,]', '', text)
+
+    # Tokenize
+    tokens = word_tokenize(text)
+
+    # Remove stopwords, but keep all other words
+    stop_words = set(stopwords.words('english'))
+    tokens = [word for word in tokens if word not in stop_words]
+
+    # Join tokens back into a string
+    processed_text = ' '.join(tokens)
+
+    # Apply BERT lemmatization
+    processed_text = bert_lemmatize(processed_text)
+
+    return processed_text
+
+def extract_key_features(text):
+    # For queries, we'll just preprocess and return all non-stopword terms
+    processed_text = preprocess_query(text)
+
+    # Split the processed text into individual terms
+    features = processed_text.split()
+
+    # Remove duplicates while preserving order
+    features = list(dict.fromkeys(features))
+
+    return features
+
+def encode_texts(texts, max_length=512):
+    inputs = tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors='pt')
+    with torch.no_grad():
+        outputs = bert_model(**inputs)
+    embeddings = outputs.last_hidden_state.mean(dim=1)
+    return embeddings.numpy()

 def load_data():
     try:
@@ -50,50 +106,6 @@ def load_data():
         print(f"An unexpected error occurred while loading data: {e}")
         raise

-embeddings, patent_numbers, metadata, texts = load_data()
-
-# Load BERT model for encoding search queries
-tokenizer = AutoTokenizer.from_pretrained('anferico/bert-for-patents')
-bert_model = AutoModel.from_pretrained('anferico/bert-for-patents')
-
-def encode_texts(texts, max_length=512):
-    inputs = tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors='pt')
-    with torch.no_grad():
-        outputs = bert_model(**inputs)
-    embeddings = outputs.last_hidden_state.mean(dim=1)
-    return embeddings.numpy()
-
-# Check if the embedding dimensions match
-if embeddings.shape[1] != encode_texts(["test"]).shape[1]:
-    print("Embedding dimensions do not match. Rebuilding FAISS index.")
-    # Rebuild embeddings using the new model
-    embeddings = encode_texts(texts)
-    embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
-
-# Normalize embeddings for cosine similarity
-embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
-
-# Create FAISS index for cosine similarity
-index = faiss.IndexFlatIP(embeddings.shape[1])
-index.add(embeddings)
-
-# Create TF-IDF vectorizer
-tfidf_vectorizer = TfidfVectorizer(stop_words='english')
-tfidf_matrix = tfidf_vectorizer.fit_transform(texts)
-
-def extract_key_features(text):
-    # Use Spacy to extract technical terms and phrases
-    doc = nlp(text)
-    technical_terms = []
-    for token in doc:
-        if token.dep_ in ('amod', 'compound') or token.ent_type_ in ('PRODUCT', 'ORG', 'GPE', 'NORP'):
-            technical_terms.append(token.text.lower())
-    noun_phrases = [chunk.text.lower() for chunk in doc.noun_chunks]
-    feature_phrases = [sent.text.lower() for sent in doc.sents if re.search(r'(comprising|including|consisting of|deformable|insulation|heat-resistant|memory foam|high-temperature)', sent.text, re.IGNORECASE)]
-
-    all_features = technical_terms + noun_phrases + feature_phrases
-    return list(set(all_features))
-
 def compare_features(query_features, patent_features):
     common_features = set(query_features) & set(patent_features)
     similarity_score = len(common_features) / max(len(query_features), len(patent_features))
@@ -102,17 +114,18 @@ def compare_features(query_features, patent_features):
 def hybrid_search(query, top_k=5):
     print(f"Original query: {query}")

-    …
+    processed_query = preprocess_query(query)
+    query_features = extract_key_features(processed_query)

-    # Encode the query using the transformer model
-    query_embedding = encode_texts([…
+    # Encode the processed query using the transformer model
+    query_embedding = encode_texts([processed_query])[0]
     query_embedding = query_embedding / np.linalg.norm(query_embedding)

     # Perform semantic similarity search
     semantic_distances, semantic_indices = index.search(np.array([query_embedding]).astype('float32'), top_k * 2)

     # Perform TF-IDF based search
-    query_tfidf = tfidf_vectorizer.transform([…
+    query_tfidf = tfidf_vectorizer.transform([processed_query])
     tfidf_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
     tfidf_indices = tfidf_similarities.argsort()[-top_k * 2:][::-1]

@@ -154,6 +167,27 @@ def hybrid_search(query, top_k=5):

     return "\n".join(results)

+# Load data and prepare the FAISS index
+embeddings, patent_numbers, metadata, texts = load_data()
+
+# Check if the embedding dimensions match
+if embeddings.shape[1] != encode_texts(["test"]).shape[1]:
+    print("Embedding dimensions do not match. Rebuilding FAISS index.")
+    # Rebuild embeddings using the new model
+    embeddings = encode_texts(texts)
+    embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
+
+# Normalize embeddings for cosine similarity
+embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
+
+# Create FAISS index for cosine similarity
+index = faiss.IndexFlatIP(embeddings.shape[1])
+index.add(embeddings)
+
+# Create TF-IDF vectorizer
+tfidf_vectorizer = TfidfVectorizer(stop_words='english')
+tfidf_matrix = tfidf_vectorizer.fit_transform(texts)
+
 # Create Gradio interface with additional input fields
 iface = gr.Interface(
     fn=hybrid_search,
@@ -167,4 +201,4 @@ iface = gr.Interface(
 )

 if __name__ == "__main__":
-    iface.launch()
+    iface.launch()
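
A note on the new model setup: bert_model is assigned twice at import time, first to the bert-base-uncased masked-LM meant for bert_lemmatize, then to the anferico/bert-for-patents encoder used by encode_texts. By the time a query reaches bert_lemmatize, bert_model refers to the plain encoder, whose output has no .logits (and on a GPU machine the input tensor is placed on CUDA while that encoder stays on CPU), so the call fails. A minimal sketch of keeping the two models under separate names; lemma_model and device are names introduced here for illustration, not part of the commit:

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Hypothetical follow-up, not in the commit: give the masked-LM its own name
# so the 'anferico/bert-for-patents' encoder loaded later cannot shadow it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
lemma_model = AutoModelForMaskedLM.from_pretrained(bert_model_name).to(device)

def bert_lemmatize(text):
    tokens = bert_tokenizer.tokenize(text)
    input_ids = bert_tokenizer.convert_tokens_to_ids(tokens)
    input_tensor = torch.tensor([input_ids]).to(device)
    with torch.no_grad():
        outputs = lemma_model(input_tensor)
    # Masked-LM outputs expose .logits; a plain AutoModel output does not.
    predictions = outputs.logits.argmax(dim=-1)
    lemmatized_tokens = bert_tokenizer.convert_ids_to_tokens(predictions[0])
    return ' '.join(t for t in lemmatized_tokens if t not in ('[CLS]', '[SEP]', '[PAD]'))

With that split, the rest of preprocess_query can stay as committed.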
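For background on the index block that now runs after hybrid_search is defined: faiss.IndexFlatIP ranks by raw inner product, so it only behaves like cosine similarity because both the stored rows and the query vector are divided by their L2 norms first. A self-contained toy run of that pattern, on random float32 vectors rather than the patent embeddings:

import numpy as np
import faiss

rng = np.random.default_rng(0)
docs = rng.standard_normal((100, 64)).astype('float32')
docs /= np.linalg.norm(docs, axis=1, keepdims=True)   # unit-length rows

index = faiss.IndexFlatIP(docs.shape[1])              # inner-product index
index.add(docs)

query = rng.standard_normal((1, 64)).astype('float32')
query /= np.linalg.norm(query)

scores, ids = index.search(query, 5)                  # scores are cosine similarities
print(ids[0], scores[0])

The float32 cast matters as well: the FAISS Python bindings only accept float32 arrays, so if load_data returns float64 embeddings they would need an .astype('float32') before index.add.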