import os
import json
import chainlit as cl

from dotenv import load_dotenv

from langchain_community.tools.ddg_search import DuckDuckGoSearchRun
from langchain_community.tools.reddit_search.tool import RedditSearchRun
from langchain_community.utilities.reddit_search import RedditSearchAPIWrapper
from langchain_openai import ChatOpenAI
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.messages import FunctionMessage, HumanMessage
from langchain.schema.runnable.config import RunnableConfig

from langgraph.prebuilt import ToolExecutor
from langgraph.prebuilt import ToolInvocation
from langgraph.graph import StateGraph, END

from searches import GoodReadsSearch
from utils import AgentState
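# AgentState comes from the local utils module (not shown here); in a typical
# LangGraph setup of this kind it is a TypedDict whose `messages` field
# accumulates BaseMessage objects across node calls (assumption based on
# common usage, not on this repo's utils.py).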

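# Agent node: send the accumulated conversation to the model and append its reply.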
async def call_model(state: AgentState, config: RunnableConfig):
  messages = state["messages"]
  response = await model.ainvoke(messages, config)
  return {"messages" : [response]}

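# Tool node: execute the function call requested by the model's last message
# and wrap the result in a FunctionMessage.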
def call_tool(state):
  last_message = state["messages"][-1]

  action = ToolInvocation(
      tool=last_message.additional_kwargs["function_call"]["name"],
      tool_input=json.loads(
          last_message.additional_kwargs["function_call"]["arguments"]
      )
  )

  response = tool_executor.invoke(action)

  function_message = FunctionMessage(content=str(response), name=action.tool)

  return {"messages" : [function_message]}

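# Router: keep looping through the tool node while the model keeps requesting
# function calls; otherwise finish.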
def should_continue(state):
  last_message = state["messages"][-1]

  if "function_call" not in last_message.additional_kwargs:
    return "end"

  return "continue"

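# Load API credentials from the environment (.env via python-dotenv).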
load_dotenv()

REDDIT_CLIENT_ID = os.environ["REDDIT_CLIENT_ID"]
REDDIT_CLIENT_SECRET = os.environ["REDDIT_CLIENT_SECRET"]
REDDIT_USER_AGENT = os.environ["REDDIT_USER_AGENT"]
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]

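# The agent's tools: DuckDuckGo web search, Reddit search, and a custom
# Goodreads search (defined in searches.py).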
tool_belt = [
    DuckDuckGoSearchRun(),
    RedditSearchRun(
        api_wrapper=RedditSearchAPIWrapper(
            reddit_client_id=REDDIT_CLIENT_ID,
            reddit_client_secret=REDDIT_CLIENT_SECRET,
            reddit_user_agent=REDDIT_USER_AGENT,
        )
    ),
    GoodReadsSearch()
]

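# Wrap the tools for execution and bind their OpenAI function schemas to the
# model so it can request tool calls.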
tool_executor = ToolExecutor(tool_belt)
model = ChatOpenAI(model="gpt-4o-mini", temperature=0, streaming=True)
functions = [convert_to_openai_function(t) for t in tool_belt]
model = model.bind_functions(functions)

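# Build the graph: the agent node decides, the action node runs tools, and
# control loops back to the agent until no further function call is requested.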
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("action", call_tool)
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
    "agent",
    should_continue,
    {
        "continue" : "action",
        "end" : END
    }
)
workflow.add_edge("action", "agent")

app = workflow.compile()

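# Chainlit handlers: store the agent in the session on chat start, then stream
# tokens from the LangGraph event stream for each user message.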
@cl.on_chat_start
async def start_chat():
    """
    """
    cl.user_session.set("agent", app)

@cl.on_message  
async def main(message: cl.Message):
    """
    """
    agent = cl.user_session.get("agent")
    inputs = {"messages" : [HumanMessage(content=str(message.content))]}
    cb = cl.LangchainCallbackHandler(stream_final_answer=True)
    config = RunnableConfig(callbacks=[cb])

    msg = cl.Message(content="")   
    await msg.send()    

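    # Relay token chunks from the chat model's stream into the Chainlit message.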
    async for event in agent.astream_events(inputs, config=config, version="v1"):
        kind = event["event"]
        if kind == "on_chat_model_stream":
            await msg.stream_token(event["data"]["chunk"].content)

    await msg.update()