{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tree of Thought (ToT) example\n",
    "\n",
    "Tree of Thought (ToT) is a technique for querying a Large Language Model (LLM) by exploring a tree of candidate reasoning steps (\"thoughts\") instead of a single chain. This notebook uses the `ToTChain` from `langchain_experimental`, which is based on the paper [\"Large Language Model Guided Tree-of-Thought\"](https://arxiv.org/pdf/2305.08291.pdf)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_openai import OpenAI\n",
    "\n",
    "llm = OpenAI(temperature=1, max_tokens=512, model=\"gpt-3.5-turbo-instruct\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\n",
      "\n",
      "- This is a 4x4 Sudoku puzzle.\n",
      "- The * represents a cell to be filled.\n",
      "- The | character separates rows.\n",
      "- At each step, replace one or more * with digits 1-4.\n",
      "- There must be no duplicate digits in any row, column or 2x2 subgrid.\n",
      "- Keep the known digits from previous valid thoughts in place.\n",
      "- Each thought can be a partial or the final solution.\n"
     ]
    }
   ],
   "source": [
    "sudoku_puzzle = \"3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\"\n",
    "sudoku_solution = \"3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\"\n",
    "problem_description = f\"\"\"\n",
    "{sudoku_puzzle}\n",
    "\n",
    "- This is a 4x4 Sudoku puzzle.\n",
    "- The * represents a cell to be filled.\n",
    "- The | character separates rows.\n",
    "- At each step, replace one or more * with digits 1-4.\n",
    "- There must be no duplicate digits in any row, column or 2x2 subgrid.\n",
    "- Keep the known digits from previous valid thoughts in place.\n",
    "- Each thought can be a partial or the final solution.\n",
    "\"\"\".strip()\n",
    "print(problem_description)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Rule-Based Checker\n",
    "\n",
    "Each thought is evaluated by the thought checker and assigned one of three validity types: valid final, valid intermediate, or invalid. A simple checker can be rule-based. For a Sudoku puzzle, for example, the checker can test whether a thought is the complete solution, a partial grid consistent with it, or a grid that contradicts it.\n",
    "\n",
    "The following code implements a simple rule-based checker for one specific 4x4 Sudoku puzzle by comparing each thought against the known solution.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from typing import Tuple\n",
    "\n",
    "from langchain_experimental.tot.checker import ToTChecker\n",
    "from langchain_experimental.tot.thought import ThoughtValidity\n",
    "\n",
    "\n",
    "class MyChecker(ToTChecker):\n",
    "    def evaluate(\n",
    "        self, problem_description: str, thoughts: Tuple[str, ...] = ()\n",
    "    ) -> ThoughtValidity:\n",
    "        # Only the most recent thought is checked.\n",
    "        last_thought = thoughts[-1]\n",
    "        # Strip spaces and quotes the LLM may have added around the grid.\n",
    "        clean_solution = last_thought.replace(\" \", \"\").replace('\"', \"\")\n",
    "        # Build a regex from the grid: escape it, then let each * match any character.\n",
    "        regex_solution = re.escape(clean_solution).replace(\"\\\\*\", \".\")\n",
    "        if sudoku_solution in clean_solution:\n",
    "            return ThoughtValidity.VALID_FINAL\n",
    "        elif re.search(regex_solution, sudoku_solution):\n",
    "            return ThoughtValidity.VALID_INTERMEDIATE\n",
    "        else:\n",
    "            return ThoughtValidity.INVALID"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of the `MyChecker` class above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "checker = MyChecker()\n",
    "assert (\n",
    "    checker.evaluate(\"\", (\"3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\",))\n",
    "    == ThoughtValidity.VALID_INTERMEDIATE\n",
    ")\n",
    "assert (\n",
    "    checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\",))\n",
    "    == ThoughtValidity.VALID_FINAL\n",
    ")\n",
    "assert (\n",
    "    checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1\",))\n",
    "    == ThoughtValidity.VALID_INTERMEDIATE\n",
    ")\n",
    "assert (\n",
    "    checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,*,3,1\",))\n",
    "    == ThoughtValidity.INVALID\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tree of Thought Chain\n",
    "\n",
    "Initialize and run the ToT chain with the maximum number of interactions `k` set to `30` and the maximum number of child thoughts `c` set to `5`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new ToTChain chain...\u001b[0m\n",
      "Starting the ToT solve procedure.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/harrisonchase/workplace/langchain/libs/langchain/langchain/chains/llm.py:275: UserWarning: The predict_and_parse method is deprecated, instead pass an output parser directly to LLMChain.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31;1m\u001b[1;3mThought: 3*,*,2|1*,3,*|*,1,*,3|4,*,*,1\n",
      "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,*|*,1,*,3|4,*,*,1\n",
      "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|*,1,*,3|4,*,*,1\n",
      "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|*,1,2,3|4,*,*,1\n",
      "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|2,1,*,3|4,*,*,1\n",
      "\u001b[0m"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Type <enum 'ThoughtValidity'> not serializable\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31;1m\u001b[1;3mThought: 3,*,*,2|1,*,3,*|*,1,*,3|4,1,*,*\n",
      "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,*,*,2|*,3,2,*|*,1,*,3|4,1,*,*\n",
      "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,*,3,*|*,1,*,3|4,1,*,*\n",
      "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,*,3,*|1,1,*,3|4,1,*,*\n",
      "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,1,3,*|1,1,*,3|4,1,*,*\n",
      "\u001b[0m\u001b[33;1m\u001b[1;3mThought: 3,*,*,2|1,2,3,*|*,1,*,3|4,*,*,1\n",
      "\u001b[0m\u001b[31;1m\u001b[1;3m    Thought: 3,1,4,2|1,2,3,4|2,1,4,3|4,3,2,1\n",
      "\u001b[0m\u001b[32;1m\u001b[1;3m    Thought: 3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\n",
      "\u001b[0m\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain_experimental.tot.base import ToTChain\n",
    "\n",
    "tot_chain = ToTChain(\n",
    "    llm=llm, checker=MyChecker(), k=30, c=5, verbose=True, verbose_llm=False\n",
    ")\n",
    "tot_chain.run(problem_description=problem_description)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}