CodeWriterFlowModule / CodeWriterCtrlFlow.py
Tachi67's picture
change links from Tachi repo to aiflows
1206897 verified
Raw
History Blame Contribute Delete
4.63 kB
import json
from copy import deepcopy
from typing import Any, Dict, List
from flow_modules.aiflows.ChatFlowModule import ChatAtomicFlow
from dataclasses import dataclass
@dataclass()
class Command:
    """One action the controller LLM is allowed to pick.

    Positional construction order is ``Command(name, description, input_args)``.
    """

    # Identifier of the command; printed first on its line of the commands manual.
    name: str
    # Free-text explanation printed after the name in the commands manual.
    description: str
    # Names of the arguments; each is shown to the model as a ``YOUR_<ARG>`` placeholder.
    input_args: List[str]
class CodeWriterCtrlFlow(ChatAtomicFlow):
    """refer to https://huggingface.co/aiflows/JarvisFlowModule/blob/main/Controller_JarvisFlow.py
    This class controls the execution of the CodeWriterFlow.

    *Input Interface Non Initialized*:
    - `goal`

    *Input Interface Initialized*:
    - `goal`
    - `code`
    - `feedback`

    *Output Interface*:
    - `command`
    - `command_args`

    *Config Parameters*:
    - `backend`: the backend used to call the LLM.
    - `commands`: a list of commands that the controller can use.
    - `system_message_prompt_template`: the template of the system message prompt.
    - `init_human_message_prompt_template`: the template of the init human (user) message prompt.
    - `human_message_prompt_template`: the template of the human (user) message prompt.
    - `previous_messages`: the sliding window of previous messages.
    - `max_json_parse_retries` (optional): how many times to re-query the LLM after a reply that
      is not a JSON object before raising ``ValueError``. Absent or ``None`` means retry forever.
    """

    def __init__(
            self,
            commands: List[Command],
            **kwargs):
        """Create the controller and bind the commands manual into the system prompt.

        :param commands: commands the LLM may choose from; rendered by ``_build_commands_manual``.
        :param kwargs: forwarded unchanged to ``ChatAtomicFlow.__init__``.
        """
        super().__init__(**kwargs)
        # Bind the `commands` template variable once; the rest is filled per call in `run`.
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            commands=self._build_commands_manual(commands),
        )
        # Format reminder appended to `goal` / `feedback` before each LLM call.
        self.hint_for_model = (
            "\n"
            "Make sure your response is in the following format:\n"
            "Response Format:\n"
            "{\n"
            '"command": "call code writer, the tester, or to finish",\n'
            '"command_args": {\n'
            '"arg name": "value"\n'
            "}\n"
            "}\n"
        )

    @staticmethod
    def _build_commands_manual(commands: List[Command]) -> str:
        """Render the commands as a numbered list, one command per line.

        Each line looks like
        ``"<n>. <name>: <description> Input arguments (given in the JSON schema): {...}"``
        where the JSON maps every input argument to a ``YOUR_<ARG>`` placeholder.
        Returns ``""`` for an empty list.
        """
        lines = []
        for index, command in enumerate(commands, start=1):
            command_input_json_schema = json.dumps(
                {input_arg: f"YOUR_{input_arg.upper()}" for input_arg in command.input_args})
            lines.append(
                f"{index}. {command.name}: {command.description} "
                f"Input arguments (given in the JSON schema): {command_input_json_schema}\n"
            )
        return "".join(lines)

    @classmethod
    def instantiate_from_config(cls, config):
        """Build the flow from a config mapping.

        The config is deep-copied, so the caller's object is never modified. Its ``commands``
        entry is expected to map each command name to a mapping with ``description`` and
        ``input_args`` keys.
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))

        # ~~~ Set up backend ~~~
        kwargs.update(cls._set_up_backend(flow_config))

        # ~~~ Set up commands ~~~
        commands = [
            Command(name, command_conf["description"], command_conf["input_args"])
            for name, command_conf in flow_config["commands"].items()
        ]
        kwargs.update({"commands": commands})

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def _update_prompts_and_input(self, input_data: Dict[str, Any]):
        """Append ``hint_for_model`` to ``goal`` and/or ``feedback`` in place, when present."""
        if 'goal' in input_data:
            input_data['goal'] += self.hint_for_model
        if 'feedback' in input_data:
            input_data['feedback'] += self.hint_for_model

    def _add_system_message_to_history(self, input_data: Dict[str, Any]) -> None:
        """Render the system prompt with ``input_data`` and append it to the chat history."""
        updated_system_message_content = self._get_message(self.system_message_prompt_template, input_data)
        self._state_update_add_chat_message(content=updated_system_message_content,
                                            role=self.flow_config["system_name"])

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Remove a surrounding Markdown code fence (e.g. ```` ```json ... ``` ````), if any.

        Text that does not start with a fence is returned unchanged.
        """
        if not text.startswith("```"):
            return text
        # Drop the opening fence line; it may carry a language tag such as ```json.
        lines = text.splitlines()[1:]
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        return "\n".join(lines)

    @classmethod
    def _parse_response(cls, api_output: str):
        """Parse the LLM reply into a dict.

        Returns ``None`` when the reply is not valid JSON or is JSON but not an object.
        Callers index the result by ``command``/``command_args``, so a bare string or list
        is treated as a failure.
        """
        try:
            response = json.loads(cls._strip_code_fences(api_output))
        except json.JSONDecodeError:
            return None
        return response if isinstance(response, dict) else None

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Query the LLM until it returns a JSON object, then return that object.

        :param input_data: flow inputs (see the class docstring). It is not mutated.
        :return: the parsed reply, expected to carry ``command`` and ``command_args``.
        :raises ValueError: if ``max_json_parse_retries`` is configured and every attempt
            produced a reply that is not a JSON object.
        """
        # Work on a copy so the format hint does not leak into (or pile up in) the caller's dict.
        input_data = input_data.copy()
        self._update_prompts_and_input(input_data)

        # ~~~ when conversation is initialized, append the updated system prompts to the chat history ~~~
        if self._is_conversation_initialized():
            self._add_system_message_to_history(input_data)

        max_retries = self.flow_config.get("max_json_parse_retries")
        failed_attempts = 0
        while True:
            api_output = super().run(input_data)["api_output"].strip()
            response = self._parse_response(api_output)
            if response is not None:
                return response

            failed_attempts += 1
            if max_retries is not None and failed_attempts > max_retries:
                raise ValueError(
                    f"LLM output could not be parsed as a JSON object after {failed_attempts} "
                    f"attempt(s); last output: {api_output!r}"
                )

            # Re-state the system prompt, then ask explicitly for pure JSON.
            self._add_system_message_to_history(input_data)
            new_goal = "The previous respond cannot be parsed with json.loads. Next time, do not provide any comments or code blocks. Make sure your next response is purely json parsable."
            new_input_data = input_data.copy()
            new_input_data['feedback'] = new_goal
            input_data = new_input_data