Spaces: Runtime error
wendru18 committed · Commit 1b5e4ab · Parent(s): bffb796
method works well
Files changed:
- __pycache__/semantic_search.cpython-38.pyc +0 -0
- main.ipynb +176 -8
- semantic_search.py +39 -0
__pycache__/semantic_search.cpython-38.pyc
ADDED
Binary file (1.75 kB)
main.ipynb
CHANGED
@@ -2,13 +2,18 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 70,
    "metadata": {},
    "outputs": [],
    "source": [
+    "from semantic_search import SemanticSearch \n",
     "import pandas as pd\n",
+    "import openai\n",
     "import praw\n",
     "import os\n",
+    "import re\n",
+    "\n",
+    "searcher = SemanticSearch()\n",
     "\n",
     "pd.set_option('max_colwidth', 100)\n",
     "pd.set_option('display.max_columns', None)"
@@ -16,7 +21,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 71,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -29,19 +34,182 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 72,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def generate_topics(query, model=\"gpt-3.5-turbo\"):\n",
+    "\n",
+    "    messages = [\n",
+    "        {\"role\": \"user\", \"content\": f\"Take this query '{query}' and return a list of topics to input in Search so it returns good results. Each topic must stand on its own with respect to the relation of the question.\"},\n",
+    "    ]\n",
+    "\n",
+    "    response = openai.ChatCompletion.create(\n",
+    "        model=model,\n",
+    "        messages=messages\n",
+    "    )\n",
+    "\n",
+    "    response_message = response[\"choices\"][0][\"message\"][\"content\"]\n",
+    "\n",
+    "    topics = re.sub(r'^\\d+\\.\\s*', '', response_message, flags=re.MULTILINE).split(\"\\n\")\n",
+    "\n",
+    "    return topics"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 73,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "['Top-rated pizza places in NYC', 'Best New York-style pizza restaurants', 'Highest reviewed pizzerias in New York City', 'Best pizza joints in Manhattan', 'Award-winning pizza spots in NYC', 'Famous pizza places in New York', 'Top 10 pizza places to try in NYC', 'Authentic pizza restaurants in New York', 'Best pizza slices in NYC', 'Gourmet pizza places in New York City']\n"
+     ]
+    }
+   ],
+   "source": [
+    "query = \"Best pizza place in NYC\"\n",
+    "topics = generate_topics(query)\n",
+    "print(topics)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 74,
    "metadata": {},
    "outputs": [],
    "source": [
     "posts = []\n",
-    "comments = []\n",
     "\n",
-    "
+    "for topic in topics:\n",
+    "    for post in reddit.subreddit(\"all\").search(\n",
+    "        topic, limit=200):\n",
+    "        posts.append([post.title, post.subreddit, post.selftext])\n",
+    "\n",
+    "posts = pd.DataFrame(posts,columns=['title', 'subreddit', 'text'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 75,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Segments is title, text and subreddit at the end\n",
+    "segments = (posts['title'] + ' ' + posts['subreddit'].astype(str)).tolist()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 76,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Fitting with n=5...\n"
+     ]
+    }
+   ],
+   "source": [
+    "searcher.fit(segments)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 77,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# TODO: Add distance check here\n",
+    "subreddits = set([result.split()[-1] for result in searcher(query)])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 78,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Convert to string and \"+\" in between\n",
+    "subreddits = \"+\".join(subreddits)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 79,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "final_posts = []\n",
     "\n",
-    "for post in reddit.subreddit(
-    "
+    "for post in reddit.subreddit(subreddits).search(\n",
+    "    query, limit=100):\n",
+    "    final_posts.append([post.title, post.subreddit, post.selftext])\n",
     "\n",
-    "
+    "final_posts = pd.DataFrame(final_posts,columns=['title', 'subreddit', 'text'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 80,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "final_segments = (final_posts['title'] + ' ' + final_posts['subreddit'].astype(str)).tolist()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 81,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "final_searcher = SemanticSearch()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 82,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Fitting with n=5...\n"
+     ]
+    }
+   ],
+   "source": [
+    "final_searcher.fit(final_segments)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 83,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['Best Pizza in Manhattan? FoodNYC',\n",
+       " 'Best Pizza slice in NYC? FoodNYC',\n",
+       " 'Best Pizza, Manhattan NYC at 40.713050, -74.007230 The8BitRyanReddit',\n",
+       " \"Controversial: What's the best pizza place in all of NYC? newyorkcity\",\n",
+       " \"Best pizza place in NYC that tourists DON'T know? circlejerknyc\"]"
+      ]
+     },
+     "execution_count": 83,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "final_searcher(query)"
     ]
    }
   ],
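Taken together, the notebook changes implement a two-pass retrieval loop: expand the user query into search topics with gpt-3.5-turbo, search r/all for each topic, semantically rank the "title + subreddit" segments to pick candidate subreddits, then search only those subreddits and rank the results again. Below is a condensed, script-style sketch of that flow, not the committed code: the diff never shows how `reddit` or the OpenAI key are configured, so those setup lines are placeholders, and the pandas DataFrame bookkeeping is folded into plain lists. It targets the pre-1.0 openai package implied by `openai.ChatCompletion`.

import os
import re

import openai
import praw

from semantic_search import SemanticSearch

openai.api_key = os.environ["OPENAI_API_KEY"]  # assumed; not shown in the diff
reddit = praw.Reddit(client_id="...", client_secret="...",
                     user_agent="...")         # assumed; not shown in the diff

def generate_topics(query, model="gpt-3.5-turbo"):
    # Ask the chat model for standalone search topics, one per line.
    response = openai.ChatCompletion.create(
        model=model,
        messages=[{"role": "user", "content":
                   f"Take this query '{query}' and return a list of topics "
                   "to input in Search so it returns good results."}],
    )
    text = response["choices"][0]["message"]["content"]
    # Strip any leading "1. ", "2. " numbering, then split one topic per line.
    return re.sub(r'^\d+\.\s*', '', text, flags=re.MULTILINE).split("\n")

query = "Best pizza place in NYC"

# Pass 1: broad search over r/all for every generated topic.
segments = []
for topic in generate_topics(query):
    for post in reddit.subreddit("all").search(topic, limit=200):
        segments.append(f"{post.title} {post.subreddit}")

# Rank the segments against the query; the subreddit is the last token of
# each segment, so the nearest neighbours vote for candidate subreddits.
searcher = SemanticSearch()
searcher.fit(segments)
subreddits = "+".join({result.split()[-1] for result in searcher(query)})

# Pass 2: search only the candidate subreddits, then rank again.
final_segments = [f"{post.title} {post.subreddit}"
                  for post in reddit.subreddit(subreddits).search(query, limit=100)]
final_searcher = SemanticSearch()
final_searcher.fit(final_segments)
print(final_searcher(query))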
semantic_search.py
ADDED
@@ -0,0 +1,39 @@
+from sklearn.neighbors import NearestNeighbors
+import tensorflow_hub as hub
+import numpy as np
+
+class SemanticSearch:
+
+    def __init__(self):
+        self.use = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
+        self.fitted = False
+
+
+    def fit(self, data, batch=1000, n_neighbors=5):
+        print(f"Fitting with n={n_neighbors}...")
+        self.data = data
+        self.embeddings = self.get_text_embedding(data, batch=batch)
+        n_neighbors = min(n_neighbors, len(self.embeddings))
+        self.nn = NearestNeighbors(n_neighbors=n_neighbors)
+        self.nn.fit(self.embeddings)
+        self.fitted = True
+
+
+    def __call__(self, text, return_data=True):
+        inp_emb = self.use([text])
+        neighbors = self.nn.kneighbors(inp_emb, return_distance=False)[0]
+
+        if return_data:
+            return [self.data[i] for i in neighbors]
+        else:
+            return neighbors
+
+
+    def get_text_embedding(self, texts, batch=1000):
+        embeddings = []
+        for i in range(0, len(texts), batch):
+            text_batch = texts[i:(i+batch)]
+            emb_batch = self.use(text_batch)
+            embeddings.append(emb_batch)
+        embeddings = np.vstack(embeddings)
+        return embeddings
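SemanticSearch wraps the Universal Sentence Encoder plus a scikit-learn NearestNeighbors index. A minimal usage sketch, with hypothetical strings standing in for the notebook's "title + subreddit" segments:

from semantic_search import SemanticSearch

segments = [
    "Best Pizza slice in NYC? FoodNYC",        # hypothetical example data
    "Best bagels in Brooklyn? AskNYC",
    "Top ramen spots in Tokyo? JapanTravel",
]

searcher = SemanticSearch()          # loads USE v4 from TF Hub
searcher.fit(segments)               # embeds the strings, fits NearestNeighbors
print(searcher("great pizza in New York"))   # nearest segments, as strings

One caveat: fit() sets self.fitted = True, but __call__ never checks the flag, so querying an unfitted instance fails with an AttributeError on self.nn rather than a clear error message.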