Samuel Thomas commited on
Commit
e327755
·
1 Parent(s): 27d6c77

correct agent

Browse files
Files changed (1) hide show
  1. app.py +12 -18
app.py CHANGED
@@ -1,22 +1,16 @@
1
  import gradio as gr
2
  import langgraph as lg
3
- from typing import TypedDict, Annotated
4
  from huggingface_hub import InferenceClient, login, list_models
5
  from langgraph.prebuilt import ToolNode, tools_condition
6
  from langgraph.graph.message import add_messages
7
- from langgraph.tools import Tool
8
- from langgraph.retrievers import BM25Retriever
9
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFacePipeline
10
- from langgraph.graph.message import add_messages
11
- from langgraph.prebuilt import ToolNode, tools_condition
12
- from langgraph.graph import START, StateGraph
13
- from langgraph.tools.dsl import TextDocument
14
- from langgraph.tools.dsl.query import Query
15
- from langgraph.tools.dsl.answer import Answer
16
- from langgraph.tools.dsl.answer_format import TextAnswer
17
  import os
18
- from langgraph.graph import START, StateGraph
19
- from langchain.tools import Tool
20
  """
21
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
22
  """
@@ -57,13 +51,12 @@ def get_hub_stats(author: str) -> str:
57
  return f"Error fetching models for {author}: {str(e)}"
58
 
59
  # Initialize the tool
60
- hub_stats_tool = Tool(
61
- name="get_hub_stats",
62
  func=get_hub_stats,
63
- description="Fetches the most downloaded model from a specific author on the Hugging Face Hub."
 
64
  )
65
 
66
-
67
  def predict(message, history):
68
  # Convert Gradio history to LangChain message format
69
  history_langchain_format = []
@@ -92,7 +85,7 @@ chat_with_tools = model.bind_tools(tools)
92
 
93
  # Generate the AgentState and Agent graph
94
  class AgentState(TypedDict):
95
- messages: Annotated[list[AnyMessage], add_messages]
96
 
97
  def assistant(state: AgentState):
98
  return {
@@ -113,6 +106,7 @@ builder.add_conditional_edges(
113
  # If the latest message requires a tool, route to tools
114
  # Otherwise, provide a direct response
115
  tools_condition,
 
116
  )
117
  builder.add_edge("tools", "assistant")
118
  alfred = builder.compile()
 
1
  import gradio as gr
2
  import langgraph as lg
3
+ from typing import TypedDict, Annotated, Sequence
4
  from huggingface_hub import InferenceClient, login, list_models
5
  from langgraph.prebuilt import ToolNode, tools_condition
6
  from langgraph.graph.message import add_messages
7
+ from langchain.tools import Tool # Updated import
8
+ from langchain_core.messages import HumanMessage, AIMessage, BaseMessage # Added message types
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
10
+ from langgraph.graph import START,END, StateGraph
 
 
 
 
 
 
11
  import os
12
+
13
+
14
  """
15
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
  """
 
51
  return f"Error fetching models for {author}: {str(e)}"
52
 
53
  # Initialize the tool
54
+ hub_stats_tool = Tool.from_function( # Use proper Tool initialization
 
55
  func=get_hub_stats,
56
+ name="get_hub_stats",
57
+ description="Fetches popular models from Hugging Face Hub"
58
  )
59
 
 
60
  def predict(message, history):
61
  # Convert Gradio history to LangChain message format
62
  history_langchain_format = []
 
85
 
86
  # Generate the AgentState and Agent graph
87
  class AgentState(TypedDict):
88
+ messages: Annotated[list[BaseMessage], add_messages]
89
 
90
  def assistant(state: AgentState):
91
  return {
 
106
  # If the latest message requires a tool, route to tools
107
  # Otherwise, provide a direct response
108
  tools_condition,
109
+ {"tools": "tools", "end": END}
110
  )
111
  builder.add_edge("tools", "assistant")
112
  alfred = builder.compile()