# agora-demo / flow.py
# Last change by samuelemarro, commit 7b2f713: "Fixed & improved error messages." (6.38 kB)
import dotenv
dotenv.load_dotenv()
import json
import os
import random
import threading
import time
from toolformers.base import Tool, parameter_from_openai_api, StringParameter
from toolformers.base import Toolformer
from toolformers.camel import make_openai_toolformer
from toolformers.langchain_agent import LangChainAnthropicToolformer
from toolformers.sambanova import SambanovaToolformer
from toolformers.gemini import GeminiToolformer
from querier import Querier
from responder import Responder
from negotiator import SenderNegotiator, ReceiverNegotiator
from programmer import SenderProgrammer, ReceiverProgrammer
from executor import UnsafeExecutor
from utils import compute_hash
def create_toolformer(model_name) -> Toolformer:
    """Instantiate the toolformer backend that serves ``model_name``.

    Recognised names:
      * ``'gpt-4o'`` / ``'gpt-4o-mini'`` -> ``make_openai_toolformer``
      * any name containing ``'claude'`` -> ``LangChainAnthropicToolformer``,
        given the ``ANTHROPIC_API_KEY`` environment variable (``None`` if unset)
      * ``'llama3-405b'`` -> ``SambanovaToolformer``
      * ``'gemini-1.5-pro'`` -> ``GeminiToolformer``

    Raises:
        ValueError: if ``model_name`` matches none of the names above.
    """
    if model_name in ('gpt-4o', 'gpt-4o-mini'):
        return make_openai_toolformer(model_name)

    if 'claude' in model_name:
        anthropic_key = os.environ.get('ANTHROPIC_API_KEY')
        return LangChainAnthropicToolformer(model_name, anthropic_key)

    if model_name in ('llama3-405b',):
        return SambanovaToolformer(model_name)

    if model_name in ('gemini-1.5-pro',):
        return GeminiToolformer(model_name)

    raise ValueError(f"Unknown model name: {model_name}")
def _make_dummy_tool_fn(tool_schema):
    """Build a stand-in implementation for one of Bob's tools.

    The returned callable prints its arguments and returns a random entry of
    ``tool_schema['dummy_outputs']``.

    Binding ``tool_schema`` through this factory (rather than closing over the
    loop variable in ``full_flow``) gives each tool its own schema. A plain
    closure inside the loop is late-binding: every tool would report the name
    of, and return outputs from, the *last* schema iterated.
    """
    def tool_fn(*args, **kwargs):
        print(f'Bob tool {tool_schema["name"]} called with args {args} and kwargs {kwargs}')
        return random.choice(tool_schema['dummy_outputs'])

    return tool_fn


def full_flow(schema, alice_model, bob_model):
    """Run the Alice/Bob demo in a worker thread and stream its progress.

    The worker thread performs, in order:
      1. Alice sends a natural-language query (no protocol) to Bob.
      2. Alice and Bob negotiate a protocol; its document is hashed.
      3. Routines implementing the protocol are written for both sides.
      4. Alice's routine is executed; each ``send_to_server`` call runs Bob's
         routine against Bob's (dummy) tools.

    Args:
        schema: Task description dict. Keys read here: ``'tools'`` (each with
            ``'name'``, ``'description'``, ``'input'`` holding ``'properties'``
            and ``'required'``, ``'output'`` and ``'dummy_outputs'``) and
            ``'examples'``. The dict is also handed to the querier, the sender
            negotiator and the sender programmer.
        alice_model: Model name for Alice, resolved by ``create_toolformer``.
        bob_model: Model name for Bob, resolved by ``create_toolformer``.

    Yields:
        Tuples ``(nl_messages, negotiation_messages, structured_messages,
        protocol_document, implementation_sender, implementation_receiver)``
        at most every 0.2 s while the worker runs, then one final tuple once
        it has finished. The three lists are the live lists the worker appends
        to; the string entries are ``''`` until they have been produced.
    """
    import traceback

    nl_messages = []
    negotiation_messages = []
    structured_messages = []
    artifacts = {}

    toolformer_alice = create_toolformer(alice_model)
    toolformer_bob = create_toolformer(bob_model)

    querier = Querier(toolformer_alice)
    responder = Responder(toolformer_bob)

    # Bob's tools are built from the schema and backed by dummy outputs.
    tools = []
    for tool_schema in schema['tools']:
        input_schema = tool_schema['input']
        parameters = [
            parameter_from_openai_api(param_name, param_schema, param_name in input_schema['required'])
            for param_name, param_schema in input_schema['properties'].items()
        ]
        tools.append(Tool(
            tool_schema['name'],
            tool_schema['description'],
            parameters,
            _make_dummy_tool_fn(tool_schema),
            tool_schema['output'],
        ))

    def nl_callback_fn(query):
        # Bob answers Alice's free-form query; both sides are logged for display.
        print(query)
        nl_messages.append({
            'role': 'assistant',
            'body': query['body'],
            'protocolHash': None,
        })
        response = responder.reply_to_query(query['body'], query['protocolHash'], tools, '')
        nl_messages.append({
            'role': 'user',
            'status': 'success',
            'body': response['body'],
        })
        return response

    negotiator_sender = SenderNegotiator(toolformer_alice)
    negotiator_receiver = ReceiverNegotiator(toolformer_bob, tools, '')

    def negotiation_callback_fn(query):
        print(query)
        negotiation_messages.append({'role': 'assistant', 'content': query})
        response = negotiator_receiver.handle_negotiation(query)
        negotiation_messages.append({'role': 'user', 'content': response})
        return response

    def final_message_callback_fn(query):
        # Alice's closing negotiation message is only logged; the receiver
        # negotiator is not invoked for it.
        negotiation_messages.append({'role': 'assistant', 'content': query})

    sender_programmer = SenderProgrammer(toolformer_alice)
    receiver_programmer = ReceiverProgrammer(toolformer_bob)
    executor = UnsafeExecutor()

    def structured_callback_fn(query):
        # Invoked by Alice's routine through the send_to_server tool: log the
        # message, then run Bob's routine on it.
        structured_messages.append({
            'role': 'assistant',
            'body': json.dumps(query) if isinstance(query, dict) else query,
            'protocolHash': artifacts['protocol']['hash'],
            'protocolSources': ['https://...'],
        })
        try:
            response = executor.run_routine(
                artifacts['protocol']['hash'], artifacts['implementation_receiver'], query, tools)
        except Exception as e:
            # Broad catch on purpose: Bob's routine is generated code and any
            # failure should be reported back to Alice rather than crash her run.
            traceback.print_exc()
            structured_messages.append({'role': 'user', 'status': 'error', 'message': str(e)})
            return 'Error'
        structured_messages.append({
            'role': 'user',
            'status': 'success',
            'body': json.dumps(response) if isinstance(response, dict) else response,
        })
        return response

    def flow():
        task_data = random.choice(schema['examples'])

        # Phase 1: plain natural-language exchange, no protocol.
        querier.send_query_without_protocol(schema, task_data, nl_callback_fn)

        # Phase 2: negotiate a protocol and identify it by its hash.
        res = negotiator_sender.negotiate_protocol_for_task(
            schema, negotiation_callback_fn, final_message_callback_fn=final_message_callback_fn)
        protocol_hash = compute_hash(res['protocol'])
        res['hash'] = protocol_hash
        artifacts['protocol'] = res
        protocol_document = res['protocol']

        # Phase 3: write both sides' implementations of the protocol.
        implementation_sender = sender_programmer.write_routine_for_task(schema, protocol_document)
        artifacts['implementation_sender'] = implementation_sender
        implementation_receiver = receiver_programmer.write_routine_for_tools(tools, protocol_document, '')
        artifacts['implementation_receiver'] = implementation_receiver

        # Phase 4: run Alice's routine. Tool parameters are passed as a list,
        # matching how Bob's tools are constructed above.
        send_tool = Tool(
            'send_to_server',
            'Send to server',
            [StringParameter('query', 'The query', True)],
            structured_callback_fn,
        )
        try:
            executor.run_routine(protocol_hash, implementation_sender, task_data, [send_tool])
        except Exception as e:
            # Alice's routine is generated code; surface its failure in the UI log.
            traceback.print_exc()
            structured_messages.append({'role': 'assistant', 'status': 'error', 'message': str(e)})

    def get_info():
        return (
            nl_messages,
            negotiation_messages,
            structured_messages,
            artifacts.get('protocol', {}).get('protocol', ''),
            artifacts.get('implementation_sender', ''),
            artifacts.get('implementation_receiver', ''),
        )

    thread = threading.Thread(target=flow)
    thread.start()

    while thread.is_alive():
        yield get_info()
        # Wait up to 0.2 s, but wake immediately once the worker finishes.
        thread.join(timeout=0.2)

    yield get_info()