wendru18 committed
Commit · 6ee98e6 · 1 Parent(s): 13db74a

added all steps, not fully functional

Files changed:
- main.ipynb +158 -78
- semantic_search.py +3 -3
main.ipynb
CHANGED
@@ -2,26 +2,34 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "from semantic_search import SemanticSearch \n",
     "import pandas as pd\n",
+    "import tiktoken\n",
     "import openai\n",
     "import praw\n",
     "import os\n",
     "import re\n",
     "\n",
-    "searcher = SemanticSearch()\n",
-    "\n",
     "pd.set_option('max_colwidth', 100)\n",
     "pd.set_option('display.max_columns', None)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "searcher = SemanticSearch()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -32,16 +40,24 @@
     "reddit = praw.Reddit(client_id=REDDIT_CLIENT_ID, client_secret=REDDIT_CLIENT_SECRET, user_agent=REDDIT_USER_AGENT)"
    ]
   },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Topic Retrieval"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "def generate_topics(query, model=\"gpt-3.5-turbo\"):\n",
     "\n",
     "    messages = [\n",
-    "        {\"role\": \"user\", \"content\": f\"Take this query '{query}' and return a list of topics to input in Search so it returns good results. Each topic must stand on its own with respect to the relation of the question.\"},\n",
+    "        {\"role\": \"user\", \"content\": f\"Take this query '{query}' and return a list of short topics to input in Search so it returns good results. Each topic must stand on its own with respect to the relation of the question.\"},\n",
     "    ]\n",
     "\n",
     "    response = openai.ChatCompletion.create(\n",
@@ -58,26 +74,35 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    …
+   "outputs": [],
+   "source": [
+    "query = \"Where are some nice places where I can work remotely in Malta?\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
-    "query = \"Best pizza place in NYC\"\n",
     "topics = generate_topics(query)\n",
+    "topics = [topic.strip() for topic in topics]\n",
     "print(topics)"
    ]
   },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Relevant Subreddits Retrieval"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -88,128 +113,183 @@
     "        topic, limit=200):\n",
     "        posts.append([post.title, post.subreddit, post.selftext])\n",
     "\n",
-    "posts = pd.DataFrame(posts,columns=['title', 'subreddit', 'text'])"
+    "posts = pd.DataFrame(posts,columns=['title', 'subreddit', 'text'])\n",
+    "\n",
+    "# Segments is title, text and subreddit at the end\n",
+    "segments = (posts['title'] + ' ' + posts['subreddit'].astype(str)).tolist()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "…
-    "…
+    "searcher.fit(segments, n_neighbors=5)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# TODO: Add distance check here\n",
+    "subreddits = set([result.split()[-1] for result in searcher(query)])\n",
+    "\n",
+    "# Convert to string and \"+\" in between\n",
+    "subreddits = \"+\".join(subreddits)\n",
+    "\n",
+    "print(f\"Relevant subreddits: {subreddits}\")"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Relevant Posts Retrieval"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Fitting with n=5...\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
-    "…
+    "segments = []\n",
+    "segment_length = 100\n",
+    "\n",
+    "\n",
+    "for topic in topics:\n",
+    "    for post in reddit.subreddit(subreddits).search(\n",
+    "        topic, limit=50):\n",
+    "        \n",
+    "        comments = \"\"\n",
+    "\n",
+    "        post.comments.replace_more(limit=3)\n",
+    "        for comment in post.comments.list():\n",
+    "            if comment.body != \"[deleted]\":\n",
+    "                comments += comment.body + \"\\n\"\n",
+    "\n",
+    "        words = comments.split()\n",
+    "        segments.extend([post.title + \" \" + post.id + \"\\n\" + ' '.join(words[i:i+segment_length]) for i in range(0, len(words), segment_length)])"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "…
-    "…
+    "searcher.fit(segments, n_neighbors=5)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Answering the Query"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "…
-    "…
+    "def num_tokens(text, model):\n",
+    "    encoding = tiktoken.encoding_for_model(model)\n",
+    "    return len(encoding.encode(text))"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "…
+    "def form_query(query, model, token_budget):\n",
+    "\n",
+    "    relevant_segments = searcher(query)\n",
+    "\n",
+    "    introduction = 'Use the below segments from multiple Reddit posts to answer the subsequent question. If the answer cannot be found in the articles, write \"I could not find an answer.\" Cite each sentence using the [postid] notation found at the start of each segment. Every sentence MUST have a citation!\\n\\n'\n",
+    "\n",
+    "    message = introduction\n",
+    "\n",
+    "    query = f\"\\n\\nQuestion: {query}\"\n",
     "\n",
-    "…
-    "        query, limit=100):\n",
-    "        final_posts.append([post.title, post.subreddit, post.selftext])\n",
+    "    evidence = []\n",
     "\n",
-    "…
+    "    for i, result in enumerate(relevant_segments):\n",
+    "        if (\n",
+    "            num_tokens(message + result + query, model=model)\n",
+    "            > token_budget\n",
+    "        ):\n",
+    "            break\n",
+    "        else:\n",
+    "            result = result + \"\\n\\n\"\n",
+    "            message += result\n",
+    "            evidence.append(result.split(\"\\n\")[0])\n",
+    "\n",
+    "    evidence = list(set(evidence))\n",
+    "\n",
+    "    return message + query, evidence"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "…
+    "def generate_answer(query, model, token_budget, temperature):\n",
+    "    \n",
+    "    message, evidence = form_query(query, model, token_budget)\n",
+    "\n",
+    "    messages = [\n",
+    "        {\"role\": \"user\", \"content\": message},\n",
+    "    ]\n",
+    "\n",
+    "    print(message)\n",
+    "\n",
+    "    response = openai.ChatCompletion.create(\n",
+    "        model=model,\n",
+    "        messages=messages,\n",
+    "        temperature=temperature\n",
+    "    )\n",
+    "    \n",
+    "    response_message = response[\"choices\"][0][\"message\"][\"content\"]\n",
+    "\n",
+    "    return response_message, evidence"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "…
+    "answer, evidence = generate_answer(query, \"gpt-3.5-turbo\", 1000, 0)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Fitting with n=5...\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
-    "…
+    "query"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "['Best Pizza in Manhattan? FoodNYC',\n",
-       " 'Best Pizza slice in NYC? FoodNYC',\n",
-       " 'Best Pizza, Manhattan NYC at 40.713050, -74.007230 The8BitRyanReddit',\n",
-       " \"Controversial: What's the best pizza place in all of NYC? newyorkcity\",\n",
-       " \"Best pizza place in NYC that tourists DON'T know? circlejerknyc\"]"
-      ]
-     },
-     "execution_count": 83,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "…
+    "answer"
    ]
   }
  ],
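Note: the body of generate_topics after the openai.ChatCompletion.create(...) call is truncated in this rendering (shown as "…"). Judging from the later cell that does topics = [topic.strip() for topic in topics], generate_topics returns an iterable of strings; a minimal sketch of how the model reply might be parsed, assuming a newline-separated and possibly numbered list (the helper and regex below are illustrative assumptions, not the committed code):

    import re

    def parse_topics(response_message):
        # Hypothetical parsing of the model reply: one topic per line,
        # stripping "1.", "2)" or "-" prefixes. The committed cell is
        # truncated above, so this is an assumption.
        topics = []
        for line in response_message.splitlines():
            line = re.sub(r"^\s*(?:\d+[.)]|-)\s*", "", line).strip()
            if line:
                topics.append(line)
        return topics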
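The subreddit-retrieval cell leaves a "# TODO: Add distance check here". Since the updated SemanticSearch.__call__ (see semantic_search.py below) is meant to return distances alongside results, the filter might look like the following, assuming searcher(query) returns a (results, distances) pair and that smaller distances mean closer matches; the threshold is illustrative, not from the commit:

    # Hypothetical distance filter for the TODO above.
    MAX_DISTANCE = 0.6  # illustrative cutoff, not part of the commit
    results, distances = searcher(query)
    subreddits = {
        result.split()[-1]
        for result, distance in zip(results, distances)
        if distance <= MAX_DISTANCE
    }
    subreddits = "+".join(subreddits)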
semantic_search.py
CHANGED
@@ -21,12 +21,12 @@ class SemanticSearch:

     def __call__(self, text, return_data=True):
         inp_emb = self.use([text])
-        neighbors = self.nn.kneighbors(inp_emb, return_distance=…
+        neighbors, distances = self.nn.kneighbors(inp_emb, return_distance=True)[0]

         if return_data:
-            return [self.data[i] for i in neighbors]
+            return [self.data[i] for i in neighbors], distances
         else:
-            return neighbors
+            return neighbors, distances


     def get_text_embedding(self, texts, batch=1000):
|