File size: 2,397 Bytes
e854384
6388682
 
9f6641f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cb7519
6388682
9f6641f
 
 
 
 
 
 
 
 
6388682
9f6641f
 
 
6388682
9f6641f
 
 
 
6388682
9f6641f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# type: ignore
from __future__ import annotations

from gradio import ChatMessage
from transformers.agents import ReactCodeAgent, agent_types
from typing import Generator

def pull_message(step_log: dict):
    """Translate one agent step log into zero or more ``ChatMessage`` objects.

    The keys ``"rationale"``, ``"tool_call"``, ``"observation"`` and
    ``"error"`` are inspected in that order. A key that is absent or holds a
    falsy value produces no message.

    Args:
        step_log: Mapping describing a single agent step. When ``"tool_call"``
            is present and truthy it must be a mapping that provides
            ``"tool_name"`` and ``"tool_arguments"``.

    Yields:
        ``ChatMessage`` instances with ``role="assistant"``:

        * the rationale, verbatim;
        * the tool arguments, titled with the tool name and wrapped in a
          ``py`` code fence when the tool name is ``"code interpreter"``;
        * the observation, wrapped in a plain code fence;
        * the error converted with ``str()``, titled ``"💥 Error"``.

    Raises:
        KeyError: If a truthy ``"tool_call"`` entry lacks ``"tool_name"`` or
            ``"tool_arguments"``.
    """
    rationale = step_log.get("rationale")
    if rationale:
        yield ChatMessage(role="assistant", content=rationale)

    tool_call = step_log.get("tool_call")
    if tool_call:
        tool_name = tool_call["tool_name"]
        content = tool_call["tool_arguments"]
        if tool_name == "code interpreter":
            content = f"```py\n{content}\n```"
        yield ChatMessage(
            role="assistant",
            metadata={"title": f"🛠️ Used tool {tool_name}"},
            # Tool arguments are not guaranteed to be text (e.g. a dict of
            # keyword arguments); coerce so the message content is always a
            # string, consistent with the error branch below.
            content=str(content),
        )

    observation = step_log.get("observation")
    if observation:
        yield ChatMessage(
            role="assistant", content=f"```\n{observation}\n```"
        )

    error = step_log.get("error")
    if error:
        yield ChatMessage(
            role="assistant",
            content=str(error),
            metadata={"title": "💥 Error"},
        )

def stream_from_transformers_agent(
    agent: ReactCodeAgent, prompt: str
) -> Generator[ChatMessage, None, ChatMessage | None]:
    """Run ``agent`` on ``prompt`` and stream its progress as ``ChatMessage`` objects.

    ``agent.run(prompt, stream=True)`` is iterated to exhaustion. Every item
    that is a ``dict`` is expanded through ``pull_message`` and the resulting
    messages are yielded immediately. The last item produced by the run is
    then treated as the final output and reported with one closing message.

    Args:
        agent: Agent whose ``run`` method is called with ``stream=True``.
        prompt: Task text forwarded to ``agent.run``.

    Yields:
        The messages for each step, followed by one final message:

        * ``AgentText``: ``"**Final answer:**"`` plus the text in a code fence;
        * ``AgentImage``: a ``{"path", "mime_type": "image/png"}`` file dict;
        * ``AgentAudio``: a ``{"path", "mime_type": "audio/wav"}`` file dict;
        * anything else: ``str()`` of the final output.

    Returns:
        ``None`` for the three ``agent_types`` cases. For any other final
        output, the same fallback message that was just yielded, so callers
        that read the generator's return value (e.g. via ``yield from``)
        keep receiving it.
    """
    final_output = None
    for final_output in agent.run(prompt, stream=True):
        if isinstance(final_output, dict):
            yield from pull_message(final_output)

    if isinstance(final_output, agent_types.AgentText):
        yield ChatMessage(
            role="assistant",
            content=f"**Final answer:**\n```\n{final_output.to_string()}\n```",  # type: ignore
        )
    elif isinstance(final_output, agent_types.AgentImage):
        yield ChatMessage(
            role="assistant",
            content={"path": final_output.to_string(), "mime_type": "image/png"},  # type: ignore
        )
    elif isinstance(final_output, agent_types.AgentAudio):
        yield ChatMessage(
            role="assistant",
            content={"path": final_output.to_string(), "mime_type": "audio/wav"},  # type: ignore
        )
    else:
        # A bare `return <message>` inside a generator only surfaces through
        # StopIteration.value, so consumers iterating with a plain `for` loop
        # never saw this final answer. Yield it so it is streamed like every
        # other message, and coerce to text because the raw output may be any
        # Python object.
        fallback = ChatMessage(role="assistant", content=str(final_output))
        yield fallback
        return fallback
    return None