File size: 5,734 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import re
import unittest
from typing import Tuple

import pytest

from langchain_experimental.tot.base import ToTChain
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.controller import ToTController
from langchain_experimental.tot.memory import ToTDFSMemory
from langchain_experimental.tot.thought import Thought, ThoughtValidity
from langchain_experimental.tot.thought_generation import SampleCoTStrategy
from tests.unit_tests.fake_llm import FakeLLM

# Starting 4x4 grid: cells are comma-separated, rows are separated by "|",
# and "*" marks an empty cell.
# NOTE(review): not referenced by any test in this file.
sudoku_puzzle = "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1"
# Scripted LLM responses, keyed by position and handed to FakeLLM with
# sequential_responses=True by the fake_llm_sudoku fixture.
# Leading spaces are cosmetic: the fixture calls .strip() on every entry and
# SudokuChecker removes all spaces before grading.
# Inline comments give the intended grade of each entry; the "c=N" notes appear
# to count sibling attempts against the chain's c=4 limit — TODO confirm.
solutions = [
    # NOTE(review): SudokuChecker as written grades this entry INVALID — its
    # first row "3,*,4,2" disagrees with sudoku_solution's "3,4,1,2", and the
    # full-length pattern can only align at offset 0. Confirm the label below.
    "3,*,4,2|1,*,3,*|*,1,*,3|4,*,*,1",  # VALID_INTERMEDIATE
    "   3,4,1,2|1,6,3,*|*,1,*,3|4,*,*,1",  # INVALID c=1
    "   3,4,1,2|1,7,3,*|*,1,*,3|4,*,*,1",  # INVALID c=2
    "   3,4,1,2|1,8,3,*|*,1,*,3|4,*,*,1",  # INVALID c=3
    "   3,4,1,2|1,2,3,*|*,1,*,3|4,*,*,1",  # VALID_INTERMEDIATE c=4 (rollback)
    "3,1,4,2|1,*,3,*|*,1,*,3|4,*,*,1",  # INVALID (rollback)
    "3,4,1,2|1,2,3,4|*,1,*,3|4,*,*,1",  # VALID_INTERMEDIATE
    "   3,4,1,2|1,2,3,4|4,1,*,3|4,*,*,1",  # INVALID (rollback)
    "   3,4,1,2|1,2,3,4|2,1,4,3|4,*,*,1",  # VALID_INTERMEDIATE
    "       3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1",  # VALID_INTERMEDIATE
    "           3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1",  # VALID_FINAL
]
# Expected completed grid; SudokuChecker grades every thought against it and
# the end-to-end tests compare the chain output to it.
sudoku_solution = "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1"


@pytest.fixture
def fake_llm_sudoku() -> FakeLLM:
    """Provide a FakeLLM scripted with the entries of ``solutions``.

    Each entry is whitespace-stripped and keyed by its position in the list;
    with ``sequential_responses=True`` the fake presumably replays them in key
    order (see FakeLLM).
    """
    stripped_steps = (step.strip() for step in solutions)
    return FakeLLM(queries=dict(enumerate(stripped_steps)), sequential_responses=True)


class SudokuChecker(ToTChecker):
    """Grade candidate 4x4 sudoku grids against the module-level ``sudoku_solution``.

    Grids are strings of comma-separated cells with rows separated by ``|``;
    ``*`` marks a cell that has not been filled in yet.
    """

    def evaluate(
        self, problem_description: str, thoughts: Tuple[str, ...] = ()
    ) -> ThoughtValidity:
        """Classify the most recent thought.

        Args:
            problem_description: Not used; the target grid is ``sudoku_solution``.
            thoughts: Thoughts produced so far; only the last one is graded.

        Returns:
            ``VALID_FINAL`` if the last thought contains the full solution,
            ``VALID_INTERMEDIATE`` if it is a partially filled grid consistent
            with the solution, otherwise ``INVALID`` (also when ``thoughts`` is
            empty).
        """
        if not thoughts:
            # Nothing to grade: treat an empty history as invalid instead of
            # raising IndexError on thoughts[-1].
            return ThoughtValidity.INVALID
        last_thought = thoughts[-1]
        # Drop spaces and stray quotes the LLM may wrap around its answer.
        clean_solution = last_thought.replace(" ", "").replace('"', "")
        if sudoku_solution in clean_solution:
            return ThoughtValidity.VALID_FINAL
        # The thought is untrusted LLM text: escape every regex metacharacter
        # (so "." or "(" cannot act as syntax), then turn the escaped "*"
        # placeholders back into single-character wildcards.
        regex_solution = re.escape(clean_solution).replace(r"\*", ".")
        if re.search(regex_solution, sudoku_solution):
            return ThoughtValidity.VALID_INTERMEDIATE
        return ThoughtValidity.INVALID


@pytest.mark.requires("jinja2")
def test_solve_sudoku(fake_llm_sudoku: FakeLLM) -> None:
    """ToTChain reaches the sudoku solution when ``k`` covers every scripted step.

    ``k`` is set to the number of scripted responses in ``solutions``
    (presumably the chain's step budget), so the chain can reach the last
    entry, which SudokuChecker grades VALID_FINAL.
    """
    tot_chain = ToTChain(
        llm=fake_llm_sudoku,
        checker=SudokuChecker(),
        k=len(solutions),
        c=4,
        tot_strategy_class=SampleCoTStrategy,
    )
    output = tot_chain.run({"problem_description": ""})
    assert output == sudoku_solution


@pytest.mark.requires("jinja2")
def test_solve_sudoku_k_too_small(fake_llm_sudoku: FakeLLM) -> None:
    """ToTChain does not reach the solution when ``k`` is one step short.

    With ``k`` one less than the number of scripted responses, the chain is
    expected to stop before the final (VALID_FINAL) entry, so its output must
    differ from ``sudoku_solution``.
    """
    tot_chain = ToTChain(
        llm=fake_llm_sudoku,
        checker=SudokuChecker(),
        k=len(solutions) - 1,
        c=4,
        tot_strategy_class=SampleCoTStrategy,
    )
    output = tot_chain.run({"problem_description": ""})
    assert output != sudoku_solution


@pytest.fixture
def fake_llm_checker() -> FakeLLM:
    """Provide a FakeLLM scripted with assorted validity verdicts.

    The script covers upper- and lower-case spellings of VALID, INVALID and
    INTERMEDIATE, followed by the catch-all answer "SOMETHING ELSE".
    NOTE(review): no test visible in this file requests this fixture.
    """
    verdicts = (
        "VALID",
        "valid",
        "INVALID",
        "invalid",
        "INTERMEDIATE",
        "intermediate",
        "SOMETHING ELSE",
    )
    scripted = dict(enumerate(verdicts))
    return FakeLLM(queries=scripted, sequential_responses=True)


class ControllerTestCase(unittest.TestCase):
    """Unit tests for ``ToTController`` built with ``c=3``.

    Each test wraps a list of thoughts in a ``ToTDFSMemory`` and checks the
    tuple of thought texts the controller returns for it.
    """

    def setUp(self) -> None:
        # Fresh controller per test; the rollback tests give ``b`` exactly
        # three children to line up with c=3.
        self.controller = ToTController(c=3)

    def test_empty(self) -> None:
        # No thoughts at all: the controller yields an empty tuple.
        self.assertEqual(self.controller(ToTDFSMemory([])), ())

    def test_one_thoghts(self) -> None:
        # NOTE(review): "thoghts" is a typo, kept so the test ID stays stable.
        only = Thought(text="a", validity=ThoughtValidity.VALID_FINAL)
        self.assertEqual(self.controller(ToTDFSMemory([only])), ("a",))

    def test_two_thoghts(self) -> None:
        # NOTE(review): "thoghts" is a typo, kept so the test ID stays stable.
        first, second = (
            Thought(text=label, validity=ThoughtValidity.VALID_INTERMEDIATE)
            for label in ("a", "b")
        )
        memory = ToTDFSMemory([first, second])
        self.assertEqual(self.controller(memory), ("a", "b"))

    def test_two_thoughts_invalid(self) -> None:
        path = [
            Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE),
            Thought(text="b", validity=ThoughtValidity.INVALID),
        ]
        # The trailing INVALID thought is expected to be excluded.
        self.assertEqual(self.controller(ToTDFSMemory(path)), ("a",))

    def test_thoughts_rollback(self) -> None:
        a, b, c_1, c_2, c_3 = (
            Thought(text=label, validity=ThoughtValidity.VALID_INTERMEDIATE)
            for label in ("a", "b", "c_1", "c_2", "c_3")
        )
        a.children = {b}
        b.children = {c_1, c_2, c_3}
        # The path ends at b's third child; the controller is expected to
        # back off to "a".
        self.assertEqual(self.controller(ToTDFSMemory([a, b, c_3])), ("a",))

    def test_thoughts_rollback_invalid(self) -> None:
        ok = ThoughtValidity.VALID_INTERMEDIATE
        a = Thought(text="a", validity=ok)
        b = Thought(text="b", validity=ok)
        c_1 = Thought(text="c_1", validity=ok)
        c_2 = Thought(text="c_2", validity=ok)
        c_3 = Thought(text="c_3", validity=ThoughtValidity.INVALID)
        a.children = {b}
        b.children = {c_1, c_2, c_3}
        # Same shape as test_thoughts_rollback, but the last child is INVALID.
        self.assertEqual(self.controller(ToTDFSMemory([a, b, c_3])), ("a",))