wendru18 commited on
Commit
6ee98e6
·
1 Parent(s): 13db74a

added all steps, not fully fuctional

Browse files
Files changed (2) hide show
  1. main.ipynb +158 -78
  2. semantic_search.py +3 -3
main.ipynb CHANGED
@@ -2,26 +2,34 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 70,
6
  "metadata": {},
7
  "outputs": [],
8
  "source": [
9
  "from semantic_search import SemanticSearch \n",
10
  "import pandas as pd\n",
 
11
  "import openai\n",
12
  "import praw\n",
13
  "import os\n",
14
  "import re\n",
15
  "\n",
16
- "searcher = SemanticSearch()\n",
17
- "\n",
18
  "pd.set_option('max_colwidth', 100)\n",
19
  "pd.set_option('display.max_columns', None)"
20
  ]
21
  },
22
  {
23
  "cell_type": "code",
24
- "execution_count": 71,
 
 
 
 
 
 
 
 
 
25
  "metadata": {},
26
  "outputs": [],
27
  "source": [
@@ -32,16 +40,24 @@
32
  "reddit = praw.Reddit(client_id=REDDIT_CLIENT_ID, client_secret=REDDIT_CLIENT_SECRET, user_agent=REDDIT_USER_AGENT)"
33
  ]
34
  },
 
 
 
 
 
 
 
 
35
  {
36
  "cell_type": "code",
37
- "execution_count": 72,
38
  "metadata": {},
39
  "outputs": [],
40
  "source": [
41
  "def generate_topics(query, model=\"gpt-3.5-turbo\"):\n",
42
  "\n",
43
  " messages = [\n",
44
- " {\"role\": \"user\", \"content\": f\"Take this query '{query}' and return a list of topics to input in Search so it returns good results. Each topic must stand on its own with respect to the relation of the question.\"},\n",
45
  " ]\n",
46
  "\n",
47
  " response = openai.ChatCompletion.create(\n",
@@ -58,26 +74,35 @@
58
  },
59
  {
60
  "cell_type": "code",
61
- "execution_count": 73,
62
  "metadata": {},
63
- "outputs": [
64
- {
65
- "name": "stdout",
66
- "output_type": "stream",
67
- "text": [
68
- "['Top-rated pizza places in NYC', 'Best New York-style pizza restaurants', 'Highest reviewed pizzerias in New York City', 'Best pizza joints in Manhattan', 'Award-winning pizza spots in NYC', 'Famous pizza places in New York', 'Top 10 pizza places to try in NYC', 'Authentic pizza restaurants in New York', 'Best pizza slices in NYC', 'Gourmet pizza places in New York City']\n"
69
- ]
70
- }
71
- ],
 
72
  "source": [
73
- "query = \"Best pizza place in NYC\"\n",
74
  "topics = generate_topics(query)\n",
 
75
  "print(topics)"
76
  ]
77
  },
 
 
 
 
 
 
 
 
78
  {
79
  "cell_type": "code",
80
- "execution_count": 74,
81
  "metadata": {},
82
  "outputs": [],
83
  "source": [
@@ -88,128 +113,183 @@
88
  " topic, limit=200):\n",
89
  " posts.append([post.title, post.subreddit, post.selftext])\n",
90
  "\n",
91
- "posts = pd.DataFrame(posts,columns=['title', 'subreddit', 'text'])"
 
 
 
92
  ]
93
  },
94
  {
95
  "cell_type": "code",
96
- "execution_count": 75,
97
  "metadata": {},
98
  "outputs": [],
99
  "source": [
100
- "# Segments is title, text and subreddit at the end\n",
101
- "segments = (posts['title'] + ' ' + posts['subreddit'].astype(str)).tolist()"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  ]
103
  },
104
  {
105
  "cell_type": "code",
106
- "execution_count": 76,
107
  "metadata": {},
108
- "outputs": [
109
- {
110
- "name": "stdout",
111
- "output_type": "stream",
112
- "text": [
113
- "Fitting with n=5...\n"
114
- ]
115
- }
116
- ],
117
  "source": [
118
- "searcher.fit(segments)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  ]
120
  },
121
  {
122
  "cell_type": "code",
123
- "execution_count": 77,
124
  "metadata": {},
125
  "outputs": [],
126
  "source": [
127
- "# TODO: Add distance check here\n",
128
- "subreddits = set([result.split()[-1] for result in searcher(query)])"
 
 
 
 
 
 
 
129
  ]
130
  },
131
  {
132
  "cell_type": "code",
133
- "execution_count": 78,
134
  "metadata": {},
135
  "outputs": [],
136
  "source": [
137
- "# Convert to string and \"+\" in between\n",
138
- "subreddits = \"+\".join(subreddits)"
 
139
  ]
140
  },
141
  {
142
  "cell_type": "code",
143
- "execution_count": 79,
144
  "metadata": {},
145
  "outputs": [],
146
  "source": [
147
- "final_posts = []\n",
 
 
 
 
 
 
 
 
148
  "\n",
149
- "for post in reddit.subreddit(subreddits).search(\n",
150
- " query, limit=100):\n",
151
- " final_posts.append([post.title, post.subreddit, post.selftext])\n",
152
  "\n",
153
- "final_posts = pd.DataFrame(final_posts,columns=['title', 'subreddit', 'text'])"
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  ]
155
  },
156
  {
157
  "cell_type": "code",
158
- "execution_count": 80,
159
  "metadata": {},
160
  "outputs": [],
161
  "source": [
162
- "final_segments = (final_posts['title'] + ' ' + final_posts['subreddit'].astype(str)).tolist()"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  ]
164
  },
165
  {
166
  "cell_type": "code",
167
- "execution_count": 81,
168
  "metadata": {},
169
  "outputs": [],
170
  "source": [
171
- "final_searcher = SemanticSearch()"
172
  ]
173
  },
174
  {
175
  "cell_type": "code",
176
- "execution_count": 82,
177
  "metadata": {},
178
- "outputs": [
179
- {
180
- "name": "stdout",
181
- "output_type": "stream",
182
- "text": [
183
- "Fitting with n=5...\n"
184
- ]
185
- }
186
- ],
187
  "source": [
188
- "final_searcher.fit(final_segments)"
189
  ]
190
  },
191
  {
192
  "cell_type": "code",
193
- "execution_count": 83,
194
  "metadata": {},
195
- "outputs": [
196
- {
197
- "data": {
198
- "text/plain": [
199
- "['Best Pizza in Manhattan? FoodNYC',\n",
200
- " 'Best Pizza slice in NYC? FoodNYC',\n",
201
- " 'Best Pizza, Manhattan NYC at 40.713050, -74.007230 The8BitRyanReddit',\n",
202
- " \"Controversial: What's the best pizza place in all of NYC? newyorkcity\",\n",
203
- " \"Best pizza place in NYC that tourists DON'T know? circlejerknyc\"]"
204
- ]
205
- },
206
- "execution_count": 83,
207
- "metadata": {},
208
- "output_type": "execute_result"
209
- }
210
- ],
211
  "source": [
212
- "final_searcher(query)"
213
  ]
214
  }
215
  ],
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": null,
6
  "metadata": {},
7
  "outputs": [],
8
  "source": [
9
  "from semantic_search import SemanticSearch \n",
10
  "import pandas as pd\n",
11
+ "import tiktoken\n",
12
  "import openai\n",
13
  "import praw\n",
14
  "import os\n",
15
  "import re\n",
16
  "\n",
 
 
17
  "pd.set_option('max_colwidth', 100)\n",
18
  "pd.set_option('display.max_columns', None)"
19
  ]
20
  },
21
  {
22
  "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "searcher = SemanticSearch()"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
  "metadata": {},
34
  "outputs": [],
35
  "source": [
 
40
  "reddit = praw.Reddit(client_id=REDDIT_CLIENT_ID, client_secret=REDDIT_CLIENT_SECRET, user_agent=REDDIT_USER_AGENT)"
41
  ]
42
  },
43
+ {
44
+ "attachments": {},
45
+ "cell_type": "markdown",
46
+ "metadata": {},
47
+ "source": [
48
+ "## Topic Retrieval"
49
+ ]
50
+ },
51
  {
52
  "cell_type": "code",
53
+ "execution_count": null,
54
  "metadata": {},
55
  "outputs": [],
56
  "source": [
57
  "def generate_topics(query, model=\"gpt-3.5-turbo\"):\n",
58
  "\n",
59
  " messages = [\n",
60
+ " {\"role\": \"user\", \"content\": f\"Take this query '{query}' and return a list of short topics to input in Search so it returns good results. Each topic must stand on its own with respect to the relation of the question.\"},\n",
61
  " ]\n",
62
  "\n",
63
  " response = openai.ChatCompletion.create(\n",
 
74
  },
75
  {
76
  "cell_type": "code",
77
+ "execution_count": null,
78
  "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "query = \"Where are some nice places where I can work remotely in Malta?\""
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": null,
87
+ "metadata": {},
88
+ "outputs": [],
89
  "source": [
 
90
  "topics = generate_topics(query)\n",
91
+ "topics = [topic.strip() for topic in topics]\n",
92
  "print(topics)"
93
  ]
94
  },
95
+ {
96
+ "attachments": {},
97
+ "cell_type": "markdown",
98
+ "metadata": {},
99
+ "source": [
100
+ "## Relevant Subreddits Retrieval"
101
+ ]
102
+ },
103
  {
104
  "cell_type": "code",
105
+ "execution_count": null,
106
  "metadata": {},
107
  "outputs": [],
108
  "source": [
 
113
  " topic, limit=200):\n",
114
  " posts.append([post.title, post.subreddit, post.selftext])\n",
115
  "\n",
116
+ "posts = pd.DataFrame(posts,columns=['title', 'subreddit', 'text'])\n",
117
+ "\n",
118
+ "# Segments is title, text and subreddit at the end\n",
119
+ "segments = (posts['title'] + ' ' + posts['subreddit'].astype(str)).tolist()"
120
  ]
121
  },
122
  {
123
  "cell_type": "code",
124
+ "execution_count": null,
125
  "metadata": {},
126
  "outputs": [],
127
  "source": [
128
+ "searcher.fit(segments, n_neighbors=5)"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": [
137
+ "# TODO: Add distance check here\n",
138
+ "subreddits = set([result.split()[-1] for result in searcher(query)])\n",
139
+ "\n",
140
+ "# Convert to string and \"+\" in between\n",
141
+ "subreddits = \"+\".join(subreddits)\n",
142
+ "\n",
143
+ "print(f\"Relevant subreddits: {subreddits}\")"
144
+ ]
145
+ },
146
+ {
147
+ "attachments": {},
148
+ "cell_type": "markdown",
149
+ "metadata": {},
150
+ "source": [
151
+ "## Relevant Posts Retrieval"
152
  ]
153
  },
154
  {
155
  "cell_type": "code",
156
+ "execution_count": null,
157
  "metadata": {},
158
+ "outputs": [],
 
 
 
 
 
 
 
 
159
  "source": [
160
+ "segments = []\n",
161
+ "segment_length = 100\n",
162
+ "\n",
163
+ "\n",
164
+ "for topic in topics:\n",
165
+ " for post in reddit.subreddit(subreddits).search(\n",
166
+ " topic, limit=50):\n",
167
+ " \n",
168
+ " comments = \"\"\n",
169
+ "\n",
170
+ " post.comments.replace_more(limit=3)\n",
171
+ " for comment in post.comments.list():\n",
172
+ " if comment.body != \"[deleted]\":\n",
173
+ " comments += comment.body + \"\\n\"\n",
174
+ "\n",
175
+ " words = comments.split()\n",
176
+ " segments.extend([post.title + \" \" + post.id + \"\\n\" + ' '.join(words[i:i+segment_length]) for i in range(0, len(words), segment_length)])"
177
  ]
178
  },
179
  {
180
  "cell_type": "code",
181
+ "execution_count": null,
182
  "metadata": {},
183
  "outputs": [],
184
  "source": [
185
+ "searcher.fit(segments, n_neighbors=5)"
186
+ ]
187
+ },
188
+ {
189
+ "attachments": {},
190
+ "cell_type": "markdown",
191
+ "metadata": {},
192
+ "source": [
193
+ "## Answering the Query"
194
  ]
195
  },
196
  {
197
  "cell_type": "code",
198
+ "execution_count": null,
199
  "metadata": {},
200
  "outputs": [],
201
  "source": [
202
+ "def num_tokens(text, model):\n",
203
+ " encoding = tiktoken.encoding_for_model(model)\n",
204
+ " return len(encoding.encode(text))"
205
  ]
206
  },
207
  {
208
  "cell_type": "code",
209
+ "execution_count": null,
210
  "metadata": {},
211
  "outputs": [],
212
  "source": [
213
+ "def form_query(query, model, token_budget):\n",
214
+ "\n",
215
+ " relevant_segments = searcher(query)\n",
216
+ "\n",
217
+ " introduction = 'Use the below segments from multiple Reddit posts to answer the subsequent question. If the answer cannot be found in the articles, write \"I could not find an answer.\" Cite each sentence using the [postid] notation found at the start of each segment. Every sentence MUST have a citation!\\n\\n'\n",
218
+ "\n",
219
+ " message = introduction\n",
220
+ "\n",
221
+ " query = f\"\\n\\nQuestion: {query}\"\n",
222
  "\n",
223
+ " evidence = []\n",
 
 
224
  "\n",
225
+ " for i, result in enumerate(relevant_segments):\n",
226
+ " if (\n",
227
+ " num_tokens(message + result + query, model=model)\n",
228
+ " > token_budget\n",
229
+ " ):\n",
230
+ " break\n",
231
+ " else:\n",
232
+ " result = result + \"\\n\\n\"\n",
233
+ " message += result\n",
234
+ " evidence.append(result.split(\"\\n\")[0])\n",
235
+ "\n",
236
+ " evidence = list(set(evidence))\n",
237
+ "\n",
238
+ " return message + query, evidence"
239
  ]
240
  },
241
  {
242
  "cell_type": "code",
243
+ "execution_count": null,
244
  "metadata": {},
245
  "outputs": [],
246
  "source": [
247
+ "def generate_answer(query, model, token_budget, temperature):\n",
248
+ " \n",
249
+ " message, evidence = form_query(query, model, token_budget)\n",
250
+ "\n",
251
+ " messages = [\n",
252
+ " {\"role\": \"user\", \"content\": message},\n",
253
+ " ]\n",
254
+ "\n",
255
+ " print(message)\n",
256
+ "\n",
257
+ " response = openai.ChatCompletion.create(\n",
258
+ " model=model,\n",
259
+ " messages=messages,\n",
260
+ " temperature=temperature\n",
261
+ " )\n",
262
+ " \n",
263
+ " response_message = response[\"choices\"][0][\"message\"][\"content\"]\n",
264
+ "\n",
265
+ " return response_message, evidence"
266
  ]
267
  },
268
  {
269
  "cell_type": "code",
270
+ "execution_count": null,
271
  "metadata": {},
272
  "outputs": [],
273
  "source": [
274
+ "answer, evidence = generate_answer(query, \"gpt-3.5-turbo\", 1000, 0)"
275
  ]
276
  },
277
  {
278
  "cell_type": "code",
279
+ "execution_count": null,
280
  "metadata": {},
281
+ "outputs": [],
 
 
 
 
 
 
 
 
282
  "source": [
283
+ "query"
284
  ]
285
  },
286
  {
287
  "cell_type": "code",
288
+ "execution_count": null,
289
  "metadata": {},
290
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  "source": [
292
+ "answer"
293
  ]
294
  }
295
  ],
semantic_search.py CHANGED
@@ -21,12 +21,12 @@ class SemanticSearch:
21
 
22
  def __call__(self, text, return_data=True):
23
  inp_emb = self.use([text])
24
- neighbors = self.nn.kneighbors(inp_emb, return_distance=False)[0]
25
 
26
  if return_data:
27
- return [self.data[i] for i in neighbors]
28
  else:
29
- return neighbors
30
 
31
 
32
  def get_text_embedding(self, texts, batch=1000):
 
21
 
22
  def __call__(self, text, return_data=True):
23
  inp_emb = self.use([text])
24
+ neighbors, distances = self.nn.kneighbors(inp_emb, return_distance=True)[0]
25
 
26
  if return_data:
27
+ return [self.data[i] for i in neighbors], distances
28
  else:
29
+ return neighbors, distances
30
 
31
 
32
  def get_text_embedding(self, texts, batch=1000):