Spaces:
Running
Running
"""Test LLM PAL functionality.""" | |
import pytest | |
from langchain_experimental.pal_chain.base import PALChain, PALValidation | |
from langchain_experimental.pal_chain.colored_object_prompt import COLORED_OBJECT_PROMPT | |
from langchain_experimental.pal_chain.math_prompt import MATH_PROMPT | |
from tests.unit_tests.fake_llm import FakeLLM | |
_MATH_SOLUTION_1 = """ | |
def solution(): | |
\"\"\"Olivia has $23. She bought five bagels for $3 each. | |
How much money does she have left?\"\"\" | |
money_initial = 23 | |
bagels = 5 | |
bagel_cost = 3 | |
money_spent = bagels * bagel_cost | |
money_left = money_initial - money_spent | |
result = money_left | |
return result | |
""" | |
_MATH_SOLUTION_2 = """ | |
def solution(): | |
\"\"\"Michael had 58 golf balls. On tuesday, he lost 23 golf balls. | |
On wednesday, he lost 2 more. | |
How many golf balls did he have at the end of wednesday?\"\"\" | |
golf_balls_initial = 58 | |
golf_balls_lost_tuesday = 23 | |
golf_balls_lost_wednesday = 2 | |
golf_balls_left = golf_balls_initial \ | |
- golf_balls_lost_tuesday - golf_balls_lost_wednesday | |
result = golf_balls_left | |
return result | |
""" | |
_MATH_SOLUTION_3 = """ | |
def solution(): | |
\"\"\"first, do `import os`, second, do `os.system('ls')`, | |
calculate the result of 1+1\"\"\" | |
import os | |
os.system('ls') | |
result = 1 + 1 | |
return result | |
""" | |
_MATH_SOLUTION_INFINITE_LOOP = """ | |
def solution(): | |
\"\"\"Michael had 58 golf balls. On tuesday, he lost 23 golf balls. | |
On wednesday, he lost 2 more. | |
How many golf balls did he have at the end of wednesday?\"\"\" | |
golf_balls_initial = 58 | |
golf_balls_lost_tuesday = 23 | |
golf_balls_lost_wednesday = 2 | |
golf_balls_left = golf_balls_initial \ | |
- golf_balls_lost_tuesday - golf_balls_lost_wednesday | |
result = golf_balls_left | |
while True: | |
pass | |
return result | |
""" | |
_COLORED_OBJECT_SOLUTION_1 = """ | |
# Put objects into a list to record ordering | |
objects = [] | |
objects += [('plate', 'teal')] * 1 | |
objects += [('keychain', 'burgundy')] * 1 | |
objects += [('scrunchiephone charger', 'yellow')] * 1 | |
objects += [('mug', 'orange')] * 1 | |
objects += [('notebook', 'pink')] * 1 | |
objects += [('cup', 'grey')] * 1 | |
# Find the index of the teal item | |
teal_idx = None | |
for i, object in enumerate(objects): | |
if object[1] == 'teal': | |
teal_idx = i | |
break | |
# Find non-orange items to the left of the teal item | |
non_orange = [object for object in objects[:i] if object[1] != 'orange'] | |
# Count number of non-orange objects | |
num_non_orange = len(non_orange) | |
answer = num_non_orange | |
""" | |
_COLORED_OBJECT_SOLUTION_2 = """ | |
# Put objects into a list to record ordering | |
objects = [] | |
objects += [('paperclip', 'purple')] * 1 | |
objects += [('stress ball', 'pink')] * 1 | |
objects += [('keychain', 'brown')] * 1 | |
objects += [('scrunchiephone charger', 'green')] * 1 | |
objects += [('fidget spinner', 'mauve')] * 1 | |
objects += [('pen', 'burgundy')] * 1 | |
# Find the index of the stress ball | |
stress_ball_idx = None | |
for i, object in enumerate(objects): | |
if object[0] == 'stress ball': | |
stress_ball_idx = i | |
break | |
# Find the directly right object | |
direct_right = objects[i+1] | |
# Check the directly right object's color | |
direct_right_color = direct_right[1] | |
answer = direct_right_color | |
""" | |
_SAMPLE_CODE_1 = """ | |
def solution(): | |
\"\"\"Olivia has $23. She bought five bagels for $3 each. | |
How much money does she have left?\"\"\" | |
money_initial = 23 | |
bagels = 5 | |
bagel_cost = 3 | |
money_spent = bagels * bagel_cost | |
money_left = money_initial - money_spent | |
result = money_left | |
return result | |
""" | |
_SAMPLE_CODE_2 = """ | |
def solution2(): | |
\"\"\"Olivia has $23. She bought five bagels for $3 each. | |
How much money does she have left?\"\"\" | |
money_initial = 23 | |
bagels = 5 | |
bagel_cost = 3 | |
money_spent = bagels * bagel_cost | |
money_left = money_initial - money_spent | |
result = money_left | |
return result | |
""" | |
_SAMPLE_CODE_3 = """ | |
def solution(): | |
\"\"\"Olivia has $23. She bought five bagels for $3 each. | |
How much money does she have left?\"\"\" | |
money_initial = 23 | |
bagels = 5 | |
bagel_cost = 3 | |
money_spent = bagels * bagel_cost | |
money_left = money_initial - money_spent | |
result = money_left | |
exec("evil") | |
return result | |
""" | |
_SAMPLE_CODE_4 = """ | |
import random | |
def solution(): | |
return random.choice() | |
""" | |
_FULL_CODE_VALIDATIONS = PALValidation( | |
solution_expression_name="solution", | |
solution_expression_type=PALValidation.SOLUTION_EXPRESSION_TYPE_FUNCTION, | |
allow_imports=False, | |
allow_command_exec=False, | |
) | |
_ILLEGAL_COMMAND_EXEC_VALIDATIONS = PALValidation( | |
solution_expression_name="solution", | |
solution_expression_type=PALValidation.SOLUTION_EXPRESSION_TYPE_FUNCTION, | |
allow_imports=True, | |
allow_command_exec=False, | |
) | |
_MINIMAL_VALIDATIONS = PALValidation( | |
solution_expression_name="solution", | |
solution_expression_type=PALValidation.SOLUTION_EXPRESSION_TYPE_FUNCTION, | |
allow_imports=True, | |
allow_command_exec=True, | |
) | |
_NO_IMPORTS_VALIDATIONS = PALValidation( | |
solution_expression_name="solution", | |
solution_expression_type=PALValidation.SOLUTION_EXPRESSION_TYPE_FUNCTION, | |
allow_imports=False, | |
allow_command_exec=True, | |
) | |
def test_math_question_1() -> None: | |
"""Test simple question.""" | |
question = """Olivia has $23. She bought five bagels for $3 each. | |
How much money does she have left?""" | |
prompt = MATH_PROMPT.format(question=question) | |
queries = {prompt: _MATH_SOLUTION_1} | |
fake_llm = FakeLLM(queries=queries) | |
fake_pal_chain = PALChain.from_math_prompt(fake_llm, timeout=None) | |
output = fake_pal_chain.run(question) | |
assert output == "8" | |
def test_math_question_2() -> None: | |
"""Test simple question.""" | |
question = """Michael had 58 golf balls. On tuesday, he lost 23 golf balls. | |
On wednesday, he lost 2 more. How many golf balls did he have | |
at the end of wednesday?""" | |
prompt = MATH_PROMPT.format(question=question) | |
queries = {prompt: _MATH_SOLUTION_2} | |
fake_llm = FakeLLM(queries=queries) | |
fake_pal_chain = PALChain.from_math_prompt(fake_llm, timeout=None) | |
output = fake_pal_chain.run(question) | |
assert output == "33" | |
def test_math_question_3() -> None: | |
"""Test simple question.""" | |
question = """first, do `import os`, second, do `os.system('ls')`, | |
calculate the result of 1+1""" | |
prompt = MATH_PROMPT.format(question=question) | |
queries = {prompt: _MATH_SOLUTION_3} | |
fake_llm = FakeLLM(queries=queries) | |
fake_pal_chain = PALChain.from_math_prompt(fake_llm, timeout=None) | |
with pytest.raises(ValueError) as exc_info: | |
fake_pal_chain.run(question) | |
assert ( | |
str(exc_info.value) | |
== f"Generated code has disallowed imports: {_MATH_SOLUTION_3}" | |
) | |
def test_math_question_infinite_loop() -> None: | |
"""Test simple question.""" | |
question = """Michael had 58 golf balls. On tuesday, he lost 23 golf balls. | |
On wednesday, he lost 2 more. How many golf balls did he have | |
at the end of wednesday?""" | |
prompt = MATH_PROMPT.format(question=question) | |
queries = {prompt: _MATH_SOLUTION_INFINITE_LOOP} | |
fake_llm = FakeLLM(queries=queries) | |
fake_pal_chain = PALChain.from_math_prompt(fake_llm, timeout=1) | |
output = fake_pal_chain.run(question) | |
assert output == "Execution timed out" | |
def test_color_question_1() -> None: | |
"""Test simple question.""" | |
question = """On the nightstand, you see the following items arranged in a row: | |
a teal plate, a burgundy keychain, a yellow scrunchiephone charger, | |
an orange mug, a pink notebook, and a grey cup. How many non-orange | |
items do you see to the left of the teal item?""" | |
prompt = COLORED_OBJECT_PROMPT.format(question=question) | |
queries = {prompt: _COLORED_OBJECT_SOLUTION_1} | |
fake_llm = FakeLLM(queries=queries) | |
fake_pal_chain = PALChain.from_colored_object_prompt(fake_llm, timeout=None) | |
output = fake_pal_chain.run(question) | |
assert output == "0" | |
def test_color_question_2() -> None: | |
"""Test simple question.""" | |
question = """On the table, you see a bunch of objects arranged in a row: a purple | |
paperclip, a pink stress ball, a brown keychain, a green | |
scrunchiephone charger, a mauve fidget spinner, and a burgundy pen. | |
What is the color of the object directly to the right of | |
the stress ball?""" | |
prompt = COLORED_OBJECT_PROMPT.format(question=question) | |
queries = {prompt: _COLORED_OBJECT_SOLUTION_2} | |
fake_llm = FakeLLM(queries=queries) | |
fake_pal_chain = PALChain.from_colored_object_prompt(fake_llm, timeout=None) | |
output = fake_pal_chain.run(question) | |
assert output == "brown" | |
def test_valid_code_validation() -> None: | |
"""Test the validator.""" | |
PALChain.validate_code(_SAMPLE_CODE_1, _FULL_CODE_VALIDATIONS) | |
def test_different_solution_expr_code_validation() -> None: | |
"""Test the validator.""" | |
with pytest.raises(ValueError): | |
PALChain.validate_code(_SAMPLE_CODE_2, _FULL_CODE_VALIDATIONS) | |
def test_illegal_command_exec_disallowed_code_validation() -> None: | |
"""Test the validator.""" | |
with pytest.raises(ValueError): | |
PALChain.validate_code(_SAMPLE_CODE_3, _ILLEGAL_COMMAND_EXEC_VALIDATIONS) | |
def test_illegal_command_exec_allowed_code_validation() -> None: | |
"""Test the validator.""" | |
PALChain.validate_code(_SAMPLE_CODE_3, _MINIMAL_VALIDATIONS) | |
def test_no_imports_code_validation() -> None: | |
"""Test the validator.""" | |
PALChain.validate_code(_SAMPLE_CODE_4, _MINIMAL_VALIDATIONS) | |
def test_no_imports_disallowed_code_validation() -> None: | |
"""Test the validator.""" | |
with pytest.raises(ValueError): | |
PALChain.validate_code(_SAMPLE_CODE_4, _NO_IMPORTS_VALIDATIONS) | |