Spaces:
Running
Running
import dotenv | |
dotenv.load_dotenv() | |
import json | |
import os | |
import random | |
import threading | |
import time | |
from toolformers.base import Tool, parameter_from_openai_api, StringParameter | |
from toolformers.base import Toolformer | |
from toolformers.camel import make_openai_toolformer | |
from toolformers.langchain_agent import LangChainAnthropicToolformer | |
from toolformers.sambanova import SambanovaToolformer | |
from toolformers.gemini import GeminiToolformer | |
from querier import Querier | |
from responder import Responder | |
from negotiator import SenderNegotiator, ReceiverNegotiator | |
from programmer import SenderProgrammer, ReceiverProgrammer | |
from executor import UnsafeExecutor | |
from utils import compute_hash | |
def create_toolformer(model_name) -> Toolformer:
    """Instantiate the toolformer backend that serves ``model_name``.

    Recognised names:
      * ``'gpt-4o'`` / ``'gpt-4o-mini'`` -> ``make_openai_toolformer``
      * any name containing ``'claude'`` -> ``LangChainAnthropicToolformer``,
        given the ``ANTHROPIC_API_KEY`` environment variable (``None`` if unset)
      * ``'llama3-405b'`` -> ``SambanovaToolformer``
      * ``'gemini-1.5-pro'`` -> ``GeminiToolformer``

    Raises:
        ValueError: if ``model_name`` matches none of the names above.
    """
    openai_models = ('gpt-4o', 'gpt-4o-mini')
    if model_name in openai_models:
        return make_openai_toolformer(model_name)

    # Substring match: any Claude variant is routed to the Anthropic backend.
    if 'claude' in model_name:
        anthropic_key = os.environ.get('ANTHROPIC_API_KEY')
        return LangChainAnthropicToolformer(model_name, anthropic_key)

    if model_name in ('llama3-405b',):
        return SambanovaToolformer(model_name)

    if model_name in ('gemini-1.5-pro',):
        return GeminiToolformer(model_name)

    raise ValueError(f"Unknown model name: {model_name}")
def _make_dummy_tool_fn(tool_schema):
    """Return a stub implementation for one of Bob's tools.

    The stub is created by a factory, not as a closure written directly in the
    loop body of ``_build_dummy_tools``. A loop-body closure would late-bind the
    loop variable, so every tool would report the *last* tool's name and return
    the *last* tool's dummy outputs. Here each stub captures its own
    ``tool_schema``.

    Args:
        tool_schema: Dict read (at call time) for ``'name'`` and
            ``'dummy_outputs'``. The latter must be a non-empty sequence.

    Returns:
        A callable accepting arbitrary arguments. It prints them and returns one
        of ``tool_schema['dummy_outputs']`` chosen at random.
    """
    def tool_fn(*args, **kwargs):
        print(f'Bob tool {tool_schema["name"]} called with args {args} and kwargs {kwargs}')
        return random.choice(tool_schema['dummy_outputs'])
    return tool_fn


def _build_dummy_tools(schema):
    """Build Bob's ``Tool`` objects from ``schema['tools']`` with stubbed implementations.

    Each entry is read for ``'name'``, ``'description'``, ``'input'``, ``'output'``
    and ``'dummy_outputs'``. ``'input'`` must contain ``'properties'`` (each
    property is passed to ``parameter_from_openai_api``). It may contain a
    ``'required'`` list of property names.
    """
    tools = []
    for tool_schema in schema['tools']:
        input_schema = tool_schema['input']
        # JSON Schema makes 'required' optional; a missing list means nothing is required.
        required = input_schema.get('required') or []
        parameters = [
            parameter_from_openai_api(param_name, param_schema, param_name in required)
            for param_name, param_schema in input_schema['properties'].items()
        ]
        tool = Tool(
            tool_schema['name'],
            tool_schema['description'],
            parameters,
            _make_dummy_tool_fn(tool_schema),
            tool_schema['output'],
        )
        tools.append(tool)
    return tools


def full_flow(schema, alice_model, bob_model):
    """Run the Alice/Bob demo in a background thread and stream its progress.

    Steps performed by the worker thread:
      1. Alice sends a natural-language query (no protocol) that Bob answers.
      2. Alice negotiates a protocol with Bob. The protocol document is
         identified by ``compute_hash`` of its text.
      3. Alice writes a sender routine and Bob writes a receiver routine for it.
      4. Alice's routine is executed. Its ``send_to_server`` tool runs Bob's
         routine on each message.

    Alice's messages are logged with role ``'assistant'`` and Bob's with role
    ``'user'``.

    Args:
        schema: Task description. This function reads ``schema['tools']`` and
            ``schema['examples']``. The whole dict is also handed to the agents.
        alice_model: Model name for the querying side (see ``create_toolformer``).
        bob_model: Model name for the responding side.

    Yields:
        ``(nl_messages, negotiation_messages, structured_messages,
        protocol_document, implementation_sender, implementation_receiver)``
        about every 0.2 s while the worker runs, then once more at the end.
        The three lists are the same live objects on every yield. The strings
        are ``''`` until the corresponding step has completed.
    """
    import traceback

    NL_MESSAGES = []
    NEGOTIATION_MESSAGES = []
    STRUCTURED_MESSAGES = []
    ARTIFACTS = {}

    toolformer_alice = create_toolformer(alice_model)
    toolformer_bob = create_toolformer(bob_model)

    querier = Querier(toolformer_alice)
    responder = Responder(toolformer_bob)

    tools = _build_dummy_tools(schema)

    def nl_callback_fn(query):
        """Forward Alice's natural-language query to Bob and log both sides."""
        print(query)
        NL_MESSAGES.append({
            'role': 'assistant',
            'body': query['body'],
            'protocolHash': None,
        })
        response = responder.reply_to_query(query['body'], query['protocolHash'], tools, '')
        NL_MESSAGES.append({
            'role': 'user',
            'status': 'success',
            'body': response['body'],
        })
        return response

    negotiator_sender = SenderNegotiator(toolformer_alice)
    negotiator_receiver = ReceiverNegotiator(toolformer_bob, tools, '')

    def negotiation_callback_fn(query):
        """Relay one negotiation message from Alice to Bob and log the exchange."""
        print(query)
        NEGOTIATION_MESSAGES.append({
            'role': 'assistant',
            'content': query,
        })
        response = negotiator_receiver.handle_negotiation(query)
        NEGOTIATION_MESSAGES.append({
            'role': 'user',
            'content': response,
        })
        return response

    def final_message_callback_fn(query):
        """Log Alice's final negotiation message. Bob is not called."""
        NEGOTIATION_MESSAGES.append({
            'role': 'assistant',
            'content': query,
        })

    sender_programmer = SenderProgrammer(toolformer_alice)
    receiver_programmer = ReceiverProgrammer(toolformer_bob)
    executor = UnsafeExecutor()

    def structured_callback_fn(query):
        """Run Bob's generated routine on one structured message from Alice.

        Returns Bob's response. If Bob's routine raises, returns the string
        ``'Error'`` instead; the error is logged rather than propagated into
        Alice's routine.
        """
        STRUCTURED_MESSAGES.append({
            'role': 'assistant',
            'body': json.dumps(query) if isinstance(query, dict) else query,
            'protocolHash': ARTIFACTS['protocol']['hash'],
            'protocolSources': ['https://...'],  # placeholder, no real source URL recorded
        })
        try:
            response = executor.run_routine(
                ARTIFACTS['protocol']['hash'], ARTIFACTS['implementation_receiver'], query, tools)
        except Exception as e:
            traceback.print_exc()
            STRUCTURED_MESSAGES.append({
                'role': 'user',
                'status': 'error',
                'message': str(e),
            })
            return 'Error'
        STRUCTURED_MESSAGES.append({
            'role': 'user',
            'status': 'success',
            'body': json.dumps(response) if isinstance(response, dict) else response,
        })
        return response

    def flow():
        """Worker body: NL exchange, negotiation, code generation, execution."""
        task_data = random.choice(schema['examples'])

        # 1. Plain natural-language exchange (no protocol yet).
        querier.send_query_without_protocol(schema, task_data, nl_callback_fn)

        # 2. Negotiate a protocol and tag it with the hash of its document.
        res = negotiator_sender.negotiate_protocol_for_task(
            schema, negotiation_callback_fn, final_message_callback_fn=final_message_callback_fn)
        protocol_hash = compute_hash(res['protocol'])
        res['hash'] = protocol_hash
        ARTIFACTS['protocol'] = res
        protocol_document = res['protocol']

        # 3. Each side writes code implementing the protocol.
        implementation_sender = sender_programmer.write_routine_for_task(schema, protocol_document)
        ARTIFACTS['implementation_sender'] = implementation_sender
        implementation_receiver = receiver_programmer.write_routine_for_tools(tools, protocol_document, '')
        ARTIFACTS['implementation_receiver'] = implementation_receiver

        # 4. Run Alice's routine; its only tool forwards each message to Bob's routine.
        #    Parameters are passed as a list, matching how Bob's tools are built.
        send_tool = Tool(
            'send_to_server',
            'Send to server',
            [StringParameter('query', 'The query', True)],
            structured_callback_fn,
        )
        try:
            executor.run_routine(protocol_hash, implementation_sender, task_data, [send_tool])
        except Exception as e:
            traceback.print_exc()
            STRUCTURED_MESSAGES.append({
                'role': 'assistant',
                'status': 'error',
                'message': str(e),
            })

    def get_info():
        """Return the current state tuple yielded by ``full_flow`` (lists are not copied)."""
        return (
            NL_MESSAGES,
            NEGOTIATION_MESSAGES,
            STRUCTURED_MESSAGES,
            ARTIFACTS.get('protocol', {}).get('protocol', ''),
            ARTIFACTS.get('implementation_sender', ''),
            ARTIFACTS.get('implementation_receiver', ''),
        )

    thread = threading.Thread(target=flow)
    thread.start()
    # Poll the shared state while the worker runs, then emit the final state once.
    while thread.is_alive():
        yield get_info()
        time.sleep(0.2)
    yield get_info()