Spaces:
Runtime error
Runtime error
wendru18
commited on
Commit
·
7da8c71
1
Parent(s):
6ee98e6
added langchain
Browse files- main.ipynb +158 -141
- semantic_search.py +3 -3
main.ipynb
CHANGED
@@ -2,13 +2,11 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
"metadata": {},
|
7 |
"outputs": [],
|
8 |
"source": [
|
9 |
-
"from semantic_search import SemanticSearch \n",
|
10 |
"import pandas as pd\n",
|
11 |
-
"import tiktoken\n",
|
12 |
"import openai\n",
|
13 |
"import praw\n",
|
14 |
"import os\n",
|
@@ -20,16 +18,23 @@
|
|
20 |
},
|
21 |
{
|
22 |
"cell_type": "code",
|
23 |
-
"execution_count":
|
24 |
"metadata": {},
|
25 |
"outputs": [],
|
26 |
"source": [
|
27 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
]
|
29 |
},
|
30 |
{
|
31 |
"cell_type": "code",
|
32 |
-
"execution_count":
|
33 |
"metadata": {},
|
34 |
"outputs": [],
|
35 |
"source": [
|
@@ -50,19 +55,20 @@
|
|
50 |
},
|
51 |
{
|
52 |
"cell_type": "code",
|
53 |
-
"execution_count":
|
54 |
"metadata": {},
|
55 |
"outputs": [],
|
56 |
"source": [
|
57 |
"def generate_topics(query, model=\"gpt-3.5-turbo\"):\n",
|
58 |
"\n",
|
59 |
" messages = [\n",
|
60 |
-
" {\"role\": \"user\", \"content\": f\"Take this query '{query}' and return a list of
|
61 |
" ]\n",
|
62 |
"\n",
|
63 |
" response = openai.ChatCompletion.create(\n",
|
64 |
" model=model,\n",
|
65 |
-
" messages=messages
|
|
|
66 |
" )\n",
|
67 |
"\n",
|
68 |
" response_message = response[\"choices\"][0][\"message\"][\"content\"]\n",
|
@@ -74,21 +80,30 @@
|
|
74 |
},
|
75 |
{
|
76 |
"cell_type": "code",
|
77 |
-
"execution_count":
|
78 |
"metadata": {},
|
79 |
"outputs": [],
|
80 |
"source": [
|
81 |
-
"query = \"
|
82 |
]
|
83 |
},
|
84 |
{
|
85 |
"cell_type": "code",
|
86 |
-
"execution_count":
|
87 |
"metadata": {},
|
88 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
"source": [
|
90 |
"topics = generate_topics(query)\n",
|
91 |
"topics = [topic.strip() for topic in topics]\n",
|
|
|
92 |
"print(topics)"
|
93 |
]
|
94 |
},
|
@@ -97,50 +112,37 @@
|
|
97 |
"cell_type": "markdown",
|
98 |
"metadata": {},
|
99 |
"source": [
|
100 |
-
"## Relevant
|
101 |
]
|
102 |
},
|
103 |
{
|
104 |
"cell_type": "code",
|
105 |
-
"execution_count":
|
106 |
"metadata": {},
|
107 |
"outputs": [],
|
108 |
"source": [
|
109 |
"posts = []\n",
|
|
|
110 |
"\n",
|
111 |
"for topic in topics:\n",
|
112 |
" for post in reddit.subreddit(\"all\").search(\n",
|
113 |
-
" topic, limit=
|
114 |
-
" posts.append([post.title, post.subreddit, post.selftext])\n",
|
|
|
115 |
"\n",
|
116 |
-
"
|
|
|
117 |
"\n",
|
118 |
-
"
|
119 |
-
"segments = (posts['title'] + ' ' + posts['subreddit'].astype(str)).tolist()"
|
120 |
]
|
121 |
},
|
122 |
{
|
123 |
"cell_type": "code",
|
124 |
-
"execution_count":
|
125 |
"metadata": {},
|
126 |
"outputs": [],
|
127 |
"source": [
|
128 |
-
"
|
129 |
-
]
|
130 |
-
},
|
131 |
-
{
|
132 |
-
"cell_type": "code",
|
133 |
-
"execution_count": null,
|
134 |
-
"metadata": {},
|
135 |
-
"outputs": [],
|
136 |
-
"source": [
|
137 |
-
"# TODO: Add distance check here\n",
|
138 |
-
"subreddits = set([result.split()[-1] for result in searcher(query)])\n",
|
139 |
-
"\n",
|
140 |
-
"# Convert to string and \"+\" in between\n",
|
141 |
-
"subreddits = \"+\".join(subreddits)\n",
|
142 |
-
"\n",
|
143 |
-
"print(f\"Relevant subreddits: {subreddits}\")"
|
144 |
]
|
145 |
},
|
146 |
{
|
@@ -148,148 +150,163 @@
|
|
148 |
"cell_type": "markdown",
|
149 |
"metadata": {},
|
150 |
"source": [
|
151 |
-
"##
|
152 |
]
|
153 |
},
|
154 |
{
|
155 |
"cell_type": "code",
|
156 |
-
"execution_count":
|
157 |
"metadata": {},
|
158 |
"outputs": [],
|
159 |
"source": [
|
160 |
-
"
|
161 |
-
"
|
162 |
-
"\n",
|
163 |
-
"\n",
|
164 |
-
"for topic in topics:\n",
|
165 |
-
" for post in reddit.subreddit(subreddits).search(\n",
|
166 |
-
" topic, limit=50):\n",
|
167 |
-
" \n",
|
168 |
-
" comments = \"\"\n",
|
169 |
-
"\n",
|
170 |
-
" post.comments.replace_more(limit=3)\n",
|
171 |
-
" for comment in post.comments.list():\n",
|
172 |
-
" if comment.body != \"[deleted]\":\n",
|
173 |
-
" comments += comment.body + \"\\n\"\n",
|
174 |
-
"\n",
|
175 |
-
" words = comments.split()\n",
|
176 |
-
" segments.extend([post.title + \" \" + post.id + \"\\n\" + ' '.join(words[i:i+segment_length]) for i in range(0, len(words), segment_length)])"
|
177 |
]
|
178 |
},
|
179 |
{
|
180 |
"cell_type": "code",
|
181 |
-
"execution_count":
|
182 |
"metadata": {},
|
183 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
"source": [
|
185 |
-
"
|
186 |
-
|
187 |
-
},
|
188 |
-
{
|
189 |
-
"attachments": {},
|
190 |
-
"cell_type": "markdown",
|
191 |
-
"metadata": {},
|
192 |
-
"source": [
|
193 |
-
"## Answering the Query"
|
194 |
-
]
|
195 |
-
},
|
196 |
-
{
|
197 |
-
"cell_type": "code",
|
198 |
-
"execution_count": null,
|
199 |
-
"metadata": {},
|
200 |
-
"outputs": [],
|
201 |
-
"source": [
|
202 |
-
"def num_tokens(text, model):\n",
|
203 |
-
" encoding = tiktoken.encoding_for_model(model)\n",
|
204 |
-
" return len(encoding.encode(text))"
|
205 |
-
]
|
206 |
-
},
|
207 |
-
{
|
208 |
-
"cell_type": "code",
|
209 |
-
"execution_count": null,
|
210 |
-
"metadata": {},
|
211 |
-
"outputs": [],
|
212 |
-
"source": [
|
213 |
-
"def form_query(query, model, token_budget):\n",
|
214 |
-
"\n",
|
215 |
-
" relevant_segments = searcher(query)\n",
|
216 |
-
"\n",
|
217 |
-
" introduction = 'Use the below segments from multiple Reddit posts to answer the subsequent question. If the answer cannot be found in the articles, write \"I could not find an answer.\" Cite each sentence using the [postid] notation found at the start of each segment. Every sentence MUST have a citation!\\n\\n'\n",
|
218 |
-
"\n",
|
219 |
-
" message = introduction\n",
|
220 |
-
"\n",
|
221 |
-
" query = f\"\\n\\nQuestion: {query}\"\n",
|
222 |
-
"\n",
|
223 |
-
" evidence = []\n",
|
224 |
-
"\n",
|
225 |
-
" for i, result in enumerate(relevant_segments):\n",
|
226 |
-
" if (\n",
|
227 |
-
" num_tokens(message + result + query, model=model)\n",
|
228 |
-
" > token_budget\n",
|
229 |
-
" ):\n",
|
230 |
-
" break\n",
|
231 |
-
" else:\n",
|
232 |
-
" result = result + \"\\n\\n\"\n",
|
233 |
-
" message += result\n",
|
234 |
-
" evidence.append(result.split(\"\\n\")[0])\n",
|
235 |
-
"\n",
|
236 |
-
" evidence = list(set(evidence))\n",
|
237 |
"\n",
|
238 |
-
"
|
239 |
]
|
240 |
},
|
241 |
{
|
242 |
"cell_type": "code",
|
243 |
-
"execution_count":
|
244 |
"metadata": {},
|
245 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
"source": [
|
247 |
-
"
|
248 |
-
" \n",
|
249 |
-
" message, evidence = form_query(query, model, token_budget)\n",
|
250 |
-
"\n",
|
251 |
-
" messages = [\n",
|
252 |
-
" {\"role\": \"user\", \"content\": message},\n",
|
253 |
-
" ]\n",
|
254 |
-
"\n",
|
255 |
-
" print(message)\n",
|
256 |
-
"\n",
|
257 |
-
" response = openai.ChatCompletion.create(\n",
|
258 |
-
" model=model,\n",
|
259 |
-
" messages=messages,\n",
|
260 |
-
" temperature=temperature\n",
|
261 |
-
" )\n",
|
262 |
-
" \n",
|
263 |
-
" response_message = response[\"choices\"][0][\"message\"][\"content\"]\n",
|
264 |
-
"\n",
|
265 |
-
" return response_message, evidence"
|
266 |
]
|
267 |
},
|
268 |
{
|
269 |
"cell_type": "code",
|
270 |
-
"execution_count":
|
271 |
"metadata": {},
|
272 |
"outputs": [],
|
273 |
"source": [
|
274 |
-
"
|
275 |
]
|
276 |
},
|
277 |
{
|
278 |
"cell_type": "code",
|
279 |
-
"execution_count":
|
280 |
"metadata": {},
|
281 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
"source": [
|
283 |
-
"
|
284 |
]
|
285 |
},
|
286 |
{
|
287 |
"cell_type": "code",
|
288 |
-
"execution_count":
|
289 |
"metadata": {},
|
290 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
"source": [
|
292 |
-
"
|
|
|
293 |
]
|
294 |
}
|
295 |
],
|
@@ -309,7 +326,7 @@
|
|
309 |
"name": "python",
|
310 |
"nbconvert_exporter": "python",
|
311 |
"pygments_lexer": "ipython3",
|
312 |
-
"version": "3.
|
313 |
},
|
314 |
"orig_nbformat": 4
|
315 |
},
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": 94,
|
6 |
"metadata": {},
|
7 |
"outputs": [],
|
8 |
"source": [
|
|
|
9 |
"import pandas as pd\n",
|
|
|
10 |
"import openai\n",
|
11 |
"import praw\n",
|
12 |
"import os\n",
|
|
|
18 |
},
|
19 |
{
|
20 |
"cell_type": "code",
|
21 |
+
"execution_count": 95,
|
22 |
"metadata": {},
|
23 |
"outputs": [],
|
24 |
"source": [
|
25 |
+
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
26 |
+
"from langchain.text_splitter import CharacterTextSplitter\n",
|
27 |
+
"from langchain.vectorstores import Chroma\n",
|
28 |
+
"from langchain.docstore.document import Document\n",
|
29 |
+
"from langchain.prompts import PromptTemplate\n",
|
30 |
+
"from langchain.indexes.vectorstore import VectorstoreIndexCreator\n",
|
31 |
+
"from langchain.chains.qa_with_sources import load_qa_with_sources_chain\n",
|
32 |
+
"from langchain.llms import OpenAI"
|
33 |
]
|
34 |
},
|
35 |
{
|
36 |
"cell_type": "code",
|
37 |
+
"execution_count": 96,
|
38 |
"metadata": {},
|
39 |
"outputs": [],
|
40 |
"source": [
|
|
|
55 |
},
|
56 |
{
|
57 |
"cell_type": "code",
|
58 |
+
"execution_count": 97,
|
59 |
"metadata": {},
|
60 |
"outputs": [],
|
61 |
"source": [
|
62 |
"def generate_topics(query, model=\"gpt-3.5-turbo\"):\n",
|
63 |
"\n",
|
64 |
" messages = [\n",
|
65 |
+
" {\"role\": \"user\", \"content\": f\"Take this query '{query}' and return a list of 10 simple to understand topics (3 words or less) to input in Search so it returns good results.\"},\n",
|
66 |
" ]\n",
|
67 |
"\n",
|
68 |
" response = openai.ChatCompletion.create(\n",
|
69 |
" model=model,\n",
|
70 |
+
" messages=messages,\n",
|
71 |
+
" temperature=0\n",
|
72 |
" )\n",
|
73 |
"\n",
|
74 |
" response_message = response[\"choices\"][0][\"message\"][\"content\"]\n",
|
|
|
80 |
},
|
81 |
{
|
82 |
"cell_type": "code",
|
83 |
+
"execution_count": 108,
|
84 |
"metadata": {},
|
85 |
"outputs": [],
|
86 |
"source": [
|
87 |
+
"query = \"Are we in a recession right now?\""
|
88 |
]
|
89 |
},
|
90 |
{
|
91 |
"cell_type": "code",
|
92 |
+
"execution_count": 109,
|
93 |
"metadata": {},
|
94 |
+
"outputs": [
|
95 |
+
{
|
96 |
+
"name": "stdout",
|
97 |
+
"output_type": "stream",
|
98 |
+
"text": [
|
99 |
+
"['Current economic status', 'Recession indicators', 'Unemployment rates', 'GDP growth rate', 'Consumer spending trends', 'Stock market performance', 'Federal Reserve actions', 'Economic stimulus packages', 'Business closures impact', 'Housing market trends']\n"
|
100 |
+
]
|
101 |
+
}
|
102 |
+
],
|
103 |
"source": [
|
104 |
"topics = generate_topics(query)\n",
|
105 |
"topics = [topic.strip() for topic in topics]\n",
|
106 |
+
"topics = [topic[1:-1] if (topic.startswith('\"') and topic.endswith('\"')) or (topic.startswith(\"'\") and topic.endswith(\"'\")) else topic for topic in topics]\n",
|
107 |
"print(topics)"
|
108 |
]
|
109 |
},
|
|
|
112 |
"cell_type": "markdown",
|
113 |
"metadata": {},
|
114 |
"source": [
|
115 |
+
"## Relevant Subreddit Retrieval"
|
116 |
]
|
117 |
},
|
118 |
{
|
119 |
"cell_type": "code",
|
120 |
+
"execution_count": 110,
|
121 |
"metadata": {},
|
122 |
"outputs": [],
|
123 |
"source": [
|
124 |
"posts = []\n",
|
125 |
+
"comments = []\n",
|
126 |
"\n",
|
127 |
"for topic in topics:\n",
|
128 |
" for post in reddit.subreddit(\"all\").search(\n",
|
129 |
+
" topic, limit=5):\n",
|
130 |
+
" posts.append([post.id, post.title, post.subreddit, post.selftext])\n",
|
131 |
+
" post.comments.replace_more(limit=1)\n",
|
132 |
"\n",
|
133 |
+
" for comment in post.comments.list():\n",
|
134 |
+
" posts.append([post.id, post.title, post.subreddit, comment.body])\n",
|
135 |
"\n",
|
136 |
+
"posts = pd.DataFrame(posts,columns=['source', 'title', 'subreddit', 'text'])"
|
|
|
137 |
]
|
138 |
},
|
139 |
{
|
140 |
"cell_type": "code",
|
141 |
+
"execution_count": 111,
|
142 |
"metadata": {},
|
143 |
"outputs": [],
|
144 |
"source": [
|
145 |
+
"posts[\"subreddit\"] = posts[\"subreddit\"].apply(lambda x: x.display_name)\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
]
|
147 |
},
|
148 |
{
|
|
|
150 |
"cell_type": "markdown",
|
151 |
"metadata": {},
|
152 |
"source": [
|
153 |
+
"## Answering Query with Langchain"
|
154 |
]
|
155 |
},
|
156 |
{
|
157 |
"cell_type": "code",
|
158 |
+
"execution_count": 112,
|
159 |
"metadata": {},
|
160 |
"outputs": [],
|
161 |
"source": [
|
162 |
+
"text = posts[\"text\"].tolist()\n",
|
163 |
+
"text = \" \".join(text)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
]
|
165 |
},
|
166 |
{
|
167 |
"cell_type": "code",
|
168 |
+
"execution_count": 113,
|
169 |
"metadata": {},
|
170 |
+
"outputs": [
|
171 |
+
{
|
172 |
+
"name": "stderr",
|
173 |
+
"output_type": "stream",
|
174 |
+
"text": [
|
175 |
+
"Created a chunk of size 1635, which is longer than the specified 1000\n",
|
176 |
+
"Created a chunk of size 1298, which is longer than the specified 1000\n",
|
177 |
+
"Created a chunk of size 1109, which is longer than the specified 1000\n",
|
178 |
+
"Created a chunk of size 2072, which is longer than the specified 1000\n",
|
179 |
+
"Created a chunk of size 1498, which is longer than the specified 1000\n",
|
180 |
+
"Created a chunk of size 1419, which is longer than the specified 1000\n",
|
181 |
+
"Created a chunk of size 1127, which is longer than the specified 1000\n",
|
182 |
+
"Created a chunk of size 1576, which is longer than the specified 1000\n",
|
183 |
+
"Created a chunk of size 1314, which is longer than the specified 1000\n",
|
184 |
+
"Created a chunk of size 2563, which is longer than the specified 1000\n",
|
185 |
+
"Created a chunk of size 1287, which is longer than the specified 1000\n",
|
186 |
+
"Created a chunk of size 1649, which is longer than the specified 1000\n",
|
187 |
+
"Created a chunk of size 1616, which is longer than the specified 1000\n",
|
188 |
+
"Created a chunk of size 1573, which is longer than the specified 1000\n",
|
189 |
+
"Created a chunk of size 1024, which is longer than the specified 1000\n",
|
190 |
+
"Created a chunk of size 1395, which is longer than the specified 1000\n",
|
191 |
+
"Created a chunk of size 1712, which is longer than the specified 1000\n",
|
192 |
+
"Created a chunk of size 1175, which is longer than the specified 1000\n",
|
193 |
+
"Created a chunk of size 3872, which is longer than the specified 1000\n",
|
194 |
+
"Created a chunk of size 1098, which is longer than the specified 1000\n",
|
195 |
+
"Created a chunk of size 1429, which is longer than the specified 1000\n",
|
196 |
+
"Created a chunk of size 1002, which is longer than the specified 1000\n",
|
197 |
+
"Created a chunk of size 2241, which is longer than the specified 1000\n",
|
198 |
+
"Created a chunk of size 1923, which is longer than the specified 1000\n",
|
199 |
+
"Created a chunk of size 1716, which is longer than the specified 1000\n",
|
200 |
+
"Created a chunk of size 2563, which is longer than the specified 1000\n",
|
201 |
+
"Created a chunk of size 1221, which is longer than the specified 1000\n",
|
202 |
+
"Created a chunk of size 2449, which is longer than the specified 1000\n",
|
203 |
+
"Created a chunk of size 1321, which is longer than the specified 1000\n",
|
204 |
+
"Created a chunk of size 1302, which is longer than the specified 1000\n",
|
205 |
+
"Created a chunk of size 2182, which is longer than the specified 1000\n",
|
206 |
+
"Created a chunk of size 1027, which is longer than the specified 1000\n",
|
207 |
+
"Created a chunk of size 1156, which is longer than the specified 1000\n",
|
208 |
+
"Created a chunk of size 7334, which is longer than the specified 1000\n",
|
209 |
+
"Created a chunk of size 1849, which is longer than the specified 1000\n",
|
210 |
+
"Created a chunk of size 2829, which is longer than the specified 1000\n",
|
211 |
+
"Created a chunk of size 1567, which is longer than the specified 1000\n",
|
212 |
+
"Created a chunk of size 1245, which is longer than the specified 1000\n",
|
213 |
+
"Created a chunk of size 1299, which is longer than the specified 1000\n",
|
214 |
+
"Created a chunk of size 1003, which is longer than the specified 1000\n",
|
215 |
+
"Created a chunk of size 1327, which is longer than the specified 1000\n",
|
216 |
+
"Created a chunk of size 2079, which is longer than the specified 1000\n",
|
217 |
+
"Created a chunk of size 2780, which is longer than the specified 1000\n",
|
218 |
+
"Created a chunk of size 1522, which is longer than the specified 1000\n",
|
219 |
+
"Created a chunk of size 1766, which is longer than the specified 1000\n",
|
220 |
+
"Created a chunk of size 1079, which is longer than the specified 1000\n",
|
221 |
+
"Created a chunk of size 1080, which is longer than the specified 1000\n",
|
222 |
+
"Created a chunk of size 1755, which is longer than the specified 1000\n",
|
223 |
+
"Created a chunk of size 1232, which is longer than the specified 1000\n",
|
224 |
+
"Created a chunk of size 1279, which is longer than the specified 1000\n",
|
225 |
+
"Created a chunk of size 3189, which is longer than the specified 1000\n",
|
226 |
+
"Created a chunk of size 1549, which is longer than the specified 1000\n",
|
227 |
+
"Created a chunk of size 1124, which is longer than the specified 1000\n",
|
228 |
+
"Created a chunk of size 1033, which is longer than the specified 1000\n",
|
229 |
+
"Created a chunk of size 1676, which is longer than the specified 1000\n",
|
230 |
+
"Created a chunk of size 1011, which is longer than the specified 1000\n",
|
231 |
+
"Created a chunk of size 1723, which is longer than the specified 1000\n"
|
232 |
+
]
|
233 |
+
}
|
234 |
+
],
|
235 |
"source": [
|
236 |
+
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
237 |
+
"texts = text_splitter.split_text(text)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
"\n",
|
239 |
+
"embeddings = OpenAIEmbeddings()"
|
240 |
]
|
241 |
},
|
242 |
{
|
243 |
"cell_type": "code",
|
244 |
+
"execution_count": 114,
|
245 |
"metadata": {},
|
246 |
+
"outputs": [
|
247 |
+
{
|
248 |
+
"name": "stderr",
|
249 |
+
"output_type": "stream",
|
250 |
+
"text": [
|
251 |
+
"Using embedded DuckDB without persistence: data will be transient\n"
|
252 |
+
]
|
253 |
+
}
|
254 |
+
],
|
255 |
"source": [
|
256 |
+
"docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{\"source\": str(i)} for i in range(len(texts))]).as_retriever()"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
]
|
258 |
},
|
259 |
{
|
260 |
"cell_type": "code",
|
261 |
+
"execution_count": 115,
|
262 |
"metadata": {},
|
263 |
"outputs": [],
|
264 |
"source": [
|
265 |
+
"docs = docsearch.get_relevant_documents(query)"
|
266 |
]
|
267 |
},
|
268 |
{
|
269 |
"cell_type": "code",
|
270 |
+
"execution_count": 116,
|
271 |
"metadata": {},
|
272 |
+
"outputs": [
|
273 |
+
{
|
274 |
+
"data": {
|
275 |
+
"text/plain": [
|
276 |
+
"[Document(page_content=\"A recession is always just around the corner, meanwhile the S&P has been making higher lows since October 2022, unemployed in Canada is 5.0%, and GDP hasn’t been negative for two quarters. It is likely to happen. And economists will have to work fast to figure out a solution. You need demand to sustain supply. And without jobs there isn't demand, nor tax revenue.\\n\\nSo short term: pain is possible.\\n\\nMedium term: we might benefit from doing less work and getting paid the same.\", metadata={'source': '39'}),\n",
|
277 |
+
" Document(page_content=\"Don't listen to people who say GDP is the only thing that matters for recessions. The NBER's definition -- which hasn't changed -- requires economic downturn across the economy in a broad sense. We probably can't have a recession and very low unemployment at the same time. So if we have a recession, we will have layoffs, too. [deleted] Most layoffs I've seen headlines for are in tech. While tech is a huge sector, it is important to look at the market as a whole. Job growth and job demand is still high. There are many many different indicators people look at. Housing prices is another. The difference between the bond curve. 3 month compared to the 1 year is another common one. When that inverts it means the economy isn't in a great place. It's so hard to predict which is why people constantly say to dollar cost average. Continuously buy and don't try to time anything. This time unemployment rate will be affected.\", metadata={'source': '13'}),\n",
|
278 |
+
" Document(page_content='High tech layoffs often triggers layoff in retail as affected people no longer have the means to spend. Restaurant closing. Meta has to write off $0.67 billion dollars or more next qtr as 1 time charge. Tweeter let go 50 pct employees so SF is already slow business wise. Housing industry is going to get hit this winter and even spring. Be honest with ourselves we have been in recession most of the year. Not too many mfg jobs are in the US so you do not see that many job losses.\\n\\nAfter Xmas I imagine companies like mail order delivery will slow down also. Your pocketbook. Negative GDP, inflation, decrease in purchases, growing unemployment... decrease of consumer spending; and increase of umemployment Honestly fear of a recession is the best indicator. People and companies start holding their money close to be safe and it becomes a self fulfilling prophecy. Most reliable indicator is when people are yelling for higher prices when their asset already doubled, tripled in value.', metadata={'source': '14'}),\n",
|
279 |
+
" Document(page_content='?? Agreed. Investors are jumping at the slightest bit of good news. But the job market is still strong, Fed wont stop raising rates until that changes. Not to mention the fact that oil is going up again. Which was a big factor in how we got to 8% - 9% inflation in the first place. Gas is almost $6 in my area now. More than it has ever been where I live. We are at the beginning of a lot of pain. Not the middle or the end. Recession is here, and it will deepen and broaden until everyone finally sees it. I mean australia raised under expectations but I dont think that matter I think the jump was - Australia only bumped by 25 basis pts and for some reason the market is in love with the fed pivot idea, jobs openings reduced something like 10%, Bond yields went down as well. But no, nothing fundamentally changed. The bank of England has started QE >didn\\'t interest rates hike again?\\n\\nYou \"call a trap\" but aren\\'t entirely sure if interest rates were hiked (they were)?', metadata={'source': '179'})]"
|
280 |
+
]
|
281 |
+
},
|
282 |
+
"execution_count": 116,
|
283 |
+
"metadata": {},
|
284 |
+
"output_type": "execute_result"
|
285 |
+
}
|
286 |
+
],
|
287 |
"source": [
|
288 |
+
"docs"
|
289 |
]
|
290 |
},
|
291 |
{
|
292 |
"cell_type": "code",
|
293 |
+
"execution_count": 117,
|
294 |
"metadata": {},
|
295 |
+
"outputs": [
|
296 |
+
{
|
297 |
+
"data": {
|
298 |
+
"text/plain": [
|
299 |
+
"' It is likely that we are in a recession right now.\\nSOURCES: 39, 13, 14, 179'"
|
300 |
+
]
|
301 |
+
},
|
302 |
+
"execution_count": 117,
|
303 |
+
"metadata": {},
|
304 |
+
"output_type": "execute_result"
|
305 |
+
}
|
306 |
+
],
|
307 |
"source": [
|
308 |
+
"chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type=\"stuff\")\n",
|
309 |
+
"chain.run(input_documents=docs, question=query)"
|
310 |
]
|
311 |
}
|
312 |
],
|
|
|
326 |
"name": "python",
|
327 |
"nbconvert_exporter": "python",
|
328 |
"pygments_lexer": "ipython3",
|
329 |
+
"version": "3.10.0"
|
330 |
},
|
331 |
"orig_nbformat": 4
|
332 |
},
|
semantic_search.py
CHANGED
@@ -21,12 +21,12 @@ class SemanticSearch:
|
|
21 |
|
22 |
def __call__(self, text, return_data=True):
|
23 |
inp_emb = self.use([text])
|
24 |
-
|
25 |
|
26 |
if return_data:
|
27 |
-
return [self.data[i] for i in neighbors], distances
|
28 |
else:
|
29 |
-
return neighbors, distances
|
30 |
|
31 |
|
32 |
def get_text_embedding(self, texts, batch=1000):
|
|
|
21 |
|
22 |
def __call__(self, text, return_data=True):
|
23 |
inp_emb = self.use([text])
|
24 |
+
distances, neighbors = self.nn.kneighbors(inp_emb, return_distance=True)
|
25 |
|
26 |
if return_data:
|
27 |
+
return [self.data[i] for i in neighbors[0]], distances
|
28 |
else:
|
29 |
+
return neighbors[0], distances
|
30 |
|
31 |
|
32 |
def get_text_embedding(self, texts, batch=1000):
|