wendru18 commited on
Commit
1b5e4ab
·
1 Parent(s): bffb796

method works well

Browse files
__pycache__/semantic_search.cpython-38.pyc ADDED
Binary file (1.75 kB). View file
 
main.ipynb CHANGED
@@ -2,13 +2,18 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 30,
6
  "metadata": {},
7
  "outputs": [],
8
  "source": [
 
9
  "import pandas as pd\n",
 
10
  "import praw\n",
11
  "import os\n",
 
 
 
12
  "\n",
13
  "pd.set_option('max_colwidth', 100)\n",
14
  "pd.set_option('display.max_columns', None)"
@@ -16,7 +21,7 @@
16
  },
17
  {
18
  "cell_type": "code",
19
- "execution_count": 1,
20
  "metadata": {},
21
  "outputs": [],
22
  "source": [
@@ -29,19 +34,182 @@
29
  },
30
  {
31
  "cell_type": "code",
32
- "execution_count": 233,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  "metadata": {},
34
  "outputs": [],
35
  "source": [
36
  "posts = []\n",
37
- "comments = []\n",
38
  "\n",
39
- "query = \"Which place sells the best pastizzis in Malta?\"\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  "\n",
41
- "for post in reddit.subreddit(\"all\").search(\"Malta cuisine\", limit=200):\n",
42
- " posts.append([post.title, post.subreddit, post.selftext, post.created])\n",
 
43
  "\n",
44
- "posts = pd.DataFrame(posts,columns=['title', 'subreddit', 'text', 'created'])"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  ]
46
  }
47
  ],
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 70,
6
  "metadata": {},
7
  "outputs": [],
8
  "source": [
9
+ "from semantic_search import SemanticSearch \n",
10
  "import pandas as pd\n",
11
+ "import openai\n",
12
  "import praw\n",
13
  "import os\n",
14
+ "import re\n",
15
+ "\n",
16
+ "searcher = SemanticSearch()\n",
17
  "\n",
18
  "pd.set_option('max_colwidth', 100)\n",
19
  "pd.set_option('display.max_columns', None)"
 
21
  },
22
  {
23
  "cell_type": "code",
24
+ "execution_count": 71,
25
  "metadata": {},
26
  "outputs": [],
27
  "source": [
 
34
  },
35
  {
36
  "cell_type": "code",
37
+ "execution_count": 72,
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "def generate_topics(query, model=\"gpt-3.5-turbo\"):\n",
42
+ "\n",
43
+ " messages = [\n",
44
+ " {\"role\": \"user\", \"content\": f\"Take this query '{query}' and return a list of topics to input in Search so it returns good results. Each topic must stand on its own with respect to the relation of the question.\"},\n",
45
+ " ]\n",
46
+ "\n",
47
+ " response = openai.ChatCompletion.create(\n",
48
+ " model=model,\n",
49
+ " messages=messages\n",
50
+ " )\n",
51
+ "\n",
52
+ " response_message = response[\"choices\"][0][\"message\"][\"content\"]\n",
53
+ "\n",
54
+ " topics = re.sub(r'^\\d+\\.\\s*', '', response_message, flags=re.MULTILINE).split(\"\\n\")\n",
55
+ "\n",
56
+ " return topics"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 73,
62
+ "metadata": {},
63
+ "outputs": [
64
+ {
65
+ "name": "stdout",
66
+ "output_type": "stream",
67
+ "text": [
68
+ "['Top-rated pizza places in NYC', 'Best New York-style pizza restaurants', 'Highest reviewed pizzerias in New York City', 'Best pizza joints in Manhattan', 'Award-winning pizza spots in NYC', 'Famous pizza places in New York', 'Top 10 pizza places to try in NYC', 'Authentic pizza restaurants in New York', 'Best pizza slices in NYC', 'Gourmet pizza places in New York City']\n"
69
+ ]
70
+ }
71
+ ],
72
+ "source": [
73
+ "query = \"Best pizza place in NYC\"\n",
74
+ "topics = generate_topics(query)\n",
75
+ "print(topics)"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": 74,
81
  "metadata": {},
82
  "outputs": [],
83
  "source": [
84
  "posts = []\n",
 
85
  "\n",
86
+ "for topic in topics:\n",
87
+ " for post in reddit.subreddit(\"all\").search(\n",
88
+ " topic, limit=200):\n",
89
+ " posts.append([post.title, post.subreddit, post.selftext])\n",
90
+ "\n",
91
+ "posts = pd.DataFrame(posts,columns=['title', 'subreddit', 'text'])"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 75,
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "# Segments is title, text and subreddit at the end\n",
101
+ "segments = (posts['title'] + ' ' + posts['subreddit'].astype(str)).tolist()"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": 76,
107
+ "metadata": {},
108
+ "outputs": [
109
+ {
110
+ "name": "stdout",
111
+ "output_type": "stream",
112
+ "text": [
113
+ "Fitting with n=5...\n"
114
+ ]
115
+ }
116
+ ],
117
+ "source": [
118
+ "searcher.fit(segments)"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 77,
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "# TODO: Add distance check here\n",
128
+ "subreddits = set([result.split()[-1] for result in searcher(query)])"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": 78,
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": [
137
+ "# Convert to string and \"+\" in between\n",
138
+ "subreddits = \"+\".join(subreddits)"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": 79,
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": [
147
+ "final_posts = []\n",
148
  "\n",
149
+ "for post in reddit.subreddit(subreddits).search(\n",
150
+ " query, limit=100):\n",
151
+ " final_posts.append([post.title, post.subreddit, post.selftext])\n",
152
  "\n",
153
+ "final_posts = pd.DataFrame(final_posts,columns=['title', 'subreddit', 'text'])"
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": 80,
159
+ "metadata": {},
160
+ "outputs": [],
161
+ "source": [
162
+ "final_segments = (final_posts['title'] + ' ' + final_posts['subreddit'].astype(str)).tolist()"
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "code",
167
+ "execution_count": 81,
168
+ "metadata": {},
169
+ "outputs": [],
170
+ "source": [
171
+ "final_searcher = SemanticSearch()"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": 82,
177
+ "metadata": {},
178
+ "outputs": [
179
+ {
180
+ "name": "stdout",
181
+ "output_type": "stream",
182
+ "text": [
183
+ "Fitting with n=5...\n"
184
+ ]
185
+ }
186
+ ],
187
+ "source": [
188
+ "final_searcher.fit(final_segments)"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": 83,
194
+ "metadata": {},
195
+ "outputs": [
196
+ {
197
+ "data": {
198
+ "text/plain": [
199
+ "['Best Pizza in Manhattan? FoodNYC',\n",
200
+ " 'Best Pizza slice in NYC? FoodNYC',\n",
201
+ " 'Best Pizza, Manhattan NYC at 40.713050, -74.007230 The8BitRyanReddit',\n",
202
+ " \"Controversial: What's the best pizza place in all of NYC? newyorkcity\",\n",
203
+ " \"Best pizza place in NYC that tourists DON'T know? circlejerknyc\"]"
204
+ ]
205
+ },
206
+ "execution_count": 83,
207
+ "metadata": {},
208
+ "output_type": "execute_result"
209
+ }
210
+ ],
211
+ "source": [
212
+ "final_searcher(query)"
213
  ]
214
  }
215
  ],
semantic_search.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.neighbors import NearestNeighbors
2
+ import tensorflow_hub as hub
3
+ import numpy as np
4
+
5
+ class SemanticSearch:
6
+
7
+ def __init__(self):
8
+ self.use = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
9
+ self.fitted = False
10
+
11
+
12
+ def fit(self, data, batch=1000, n_neighbors=5):
13
+ print(f"Fitting with n={n_neighbors}...")
14
+ self.data = data
15
+ self.embeddings = self.get_text_embedding(data, batch=batch)
16
+ n_neighbors = min(n_neighbors, len(self.embeddings))
17
+ self.nn = NearestNeighbors(n_neighbors=n_neighbors)
18
+ self.nn.fit(self.embeddings)
19
+ self.fitted = True
20
+
21
+
22
+ def __call__(self, text, return_data=True):
23
+ inp_emb = self.use([text])
24
+ neighbors = self.nn.kneighbors(inp_emb, return_distance=False)[0]
25
+
26
+ if return_data:
27
+ return [self.data[i] for i in neighbors]
28
+ else:
29
+ return neighbors
30
+
31
+
32
+ def get_text_embedding(self, texts, batch=1000):
33
+ embeddings = []
34
+ for i in range(0, len(texts), batch):
35
+ text_batch = texts[i:(i+batch)]
36
+ emb_batch = self.use(text_batch)
37
+ embeddings.append(emb_batch)
38
+ embeddings = np.vstack(embeddings)
39
+ return embeddings