Spaces:
Running
Running
"""Test SmartLLM.""" | |
from langchain_community.chat_models import FakeListChatModel | |
from langchain_community.llms import FakeListLLM | |
from langchain_core.prompts.prompt import PromptTemplate | |
from langchain_experimental.smart_llm import SmartLLMChain | |
def test_ideation() -> None:
    """Check that the ideation step returns the LLM's ideas verbatim and in order.

    First, the chain is built with its default ``n_ideas`` and three canned
    responses, which are expected to come back unchanged. Then ``n_ideas`` is
    varied from 1 to 4: each time exactly that many ideas must be returned,
    and they must be the fake LLM's responses in the order it produced them
    (checking only the count would let wrong or reordered ideas slip through).
    """
    prompt = PromptTemplate(
        input_variables=["product"],
        template="What is a good name for a company that makes {product}?",
    )

    # Default n_ideas: every canned response should be returned as-is.
    responses = ["Idea 1", "Idea 2", "Idea 3"]
    chain = SmartLLMChain(llm=FakeListLLM(responses=responses), prompt=prompt)
    prompt_value, _ = chain.prep_prompts({"product": "socks"})
    # The ideation step is called directly, so seed the question on the
    # chain's history first.
    chain.history.question = prompt_value.to_string()
    assert chain._ideate() == responses

    # Explicit n_ideas: exactly that many ideas, matching the responses.
    for n_ideas in range(1, 5):
        responses = [f"Idea {j + 1}" for j in range(n_ideas)]
        chain = SmartLLMChain(
            llm=FakeListLLM(responses=responses), prompt=prompt, n_ideas=n_ideas
        )
        prompt_value, _ = chain.prep_prompts({"product": "socks"})
        chain.history.question = prompt_value.to_string()
        results = chain._ideate()
        assert len(results) == n_ideas
        assert results == responses


def test_critique() -> None:
    """The critique step should return the fake LLM's single response unchanged."""
    expected = "Test Critique"
    chain = SmartLLMChain(
        llm=FakeListLLM(responses=[expected]),
        prompt=PromptTemplate(
            input_variables=["product"],
            template="What is a good name for a company that makes {product}?",
        ),
        n_ideas=2,
    )

    # Seed the history the critique step runs against: the formatted
    # question plus two candidate ideas (matching n_ideas=2).
    question, _ = chain.prep_prompts({"product": "socks"})
    chain.history.question = question.to_string()
    chain.history.ideas = ["Test Idea 1", "Test Idea 2"]

    assert chain._critique() == expected


def test_resolver() -> None:
    """The resolve step should return the fake LLM's single response unchanged."""
    expected = "Test resolution"
    company_prompt = PromptTemplate(
        input_variables=["product"],
        template="What is a good name for a company that makes {product}?",
    )
    chain = SmartLLMChain(
        llm=FakeListLLM(responses=[expected]), prompt=company_prompt, n_ideas=2
    )

    # Seed everything the resolve step is exercised against: the question,
    # two candidate ideas, and a critique of them.
    formatted, _ = chain.prep_prompts({"product": "socks"})
    chain.history.question = formatted.to_string()
    chain.history.ideas = ["Test Idea 1", "Test Idea 2"]
    chain.history.critique = "Test Critique"

    assert chain._resolve() == expected


def test_all_steps() -> None:
    """Run the whole chain with a separate fake LLM for each step.

    The output must echo the input under ``"joke"``, and its ``"resolution"``
    must come from the resolver LLM rather than the ideation or critique LLM.
    """
    joke = "Why did the chicken cross the Mobius strip?"
    response = "Resolution response"
    # Each step gets a generous supply of identical canned answers.
    ideation_llm = FakeListLLM(responses=["Ideation response"] * 20)
    critique_llm = FakeListLLM(responses=["Critique response"] * 20)
    resolver_llm = FakeListLLM(responses=[response] * 20)
    prompt = PromptTemplate(
        input_variables=["joke"],
        template="Explain this joke to me: {joke}?",
    )
    chain = SmartLLMChain(
        ideation_llm=ideation_llm,
        critique_llm=critique_llm,
        resolver_llm=resolver_llm,
        prompt=prompt,
    )
    # ``invoke`` replaces the deprecated ``Chain.__call__``; the bare string
    # is mapped onto the prompt's single input variable, as before.
    result = chain.invoke(joke)
    assert result["joke"] == joke
    assert result["resolution"] == response


def test_intermediate_output() -> None:
    """With ``return_intermediate_steps=True`` the output includes every stage.

    One fake LLM serves all steps, so its responses are consumed in order:
    the first three become the ideas (the chain's default ``n_ideas``), the
    fourth is the critique and the fifth is the resolution.
    """
    joke = "Why did the chicken cross the Mobius strip?"
    llm = FakeListLLM(responses=[f"Response {i + 1}" for i in range(5)])
    prompt = PromptTemplate(
        input_variables=["joke"],
        template="Explain this joke to me: {joke}?",
    )
    chain = SmartLLMChain(llm=llm, prompt=prompt, return_intermediate_steps=True)
    # ``invoke`` replaces the deprecated ``Chain.__call__``.
    result = chain.invoke(joke)
    assert result["joke"] == joke
    assert result["ideas"] == [f"Response {i + 1}" for i in range(3)]
    assert result["critique"] == "Response 4"
    assert result["resolution"] == "Response 5"


def test_all_steps_with_chat_model() -> None:
    """Same as the per-step LLM test, but with fake chat models for every step.

    The output must echo the input under ``"joke"``, and its ``"resolution"``
    must be the resolver chat model's response.
    """
    joke = "Why did the chicken cross the Mobius strip?"
    response = "Resolution response"
    # Each step gets a generous supply of identical canned answers.
    ideation_llm = FakeListChatModel(responses=["Ideation response"] * 20)
    critique_llm = FakeListChatModel(responses=["Critique response"] * 20)
    resolver_llm = FakeListChatModel(responses=[response] * 20)
    prompt = PromptTemplate(
        input_variables=["joke"],
        template="Explain this joke to me: {joke}?",
    )
    chain = SmartLLMChain(
        ideation_llm=ideation_llm,
        critique_llm=critique_llm,
        resolver_llm=resolver_llm,
        prompt=prompt,
    )
    # ``invoke`` replaces the deprecated ``Chain.__call__``.
    result = chain.invoke(joke)
    assert result["joke"] == joke
    assert result["resolution"] == response