import gradio as gr
from typing import TypedDict, Annotated
from huggingface_hub import login, list_models
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph.message import add_messages
from langchain_core.tools import Tool
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langgraph.graph import START, END, StateGraph
import os


"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
login(token=HUGGINGFACEHUB_API_TOKEN, add_to_git_credential=True)

llm = HuggingFaceEndpoint(
    #repo_id="HuggingFaceH4/zephyr-7b-beta",
    #repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    #repo_id="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
    #repo_id="deepseek-ai/DeepSeek-Coder-V2-Instruct",
    repo_id="migtissera/Trinity-2-Codestral-22B",
    task="text-generation",
    max_new_tokens=512,
    do_sample=False,
    repetition_penalty=1.03,
    timeout=240,
)
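
# Optional smoke test of the raw endpoint (assumes the token is valid and this
# repo_id is currently served by the Inference API):
# print(llm.invoke("Reply with a one-sentence greeting."))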

model = ChatHuggingFace(llm=llm, verbose=True)

def get_hub_stats(author: str) -> str:
    """
    You are a helpful chatbot for programmers and data scientists with access to the Hugging Face Hub. 
    Users will want to know the most popular models from Hugging Face. This tool will enable
    you to fetch the most downloaded model from a specific author on the Hugging Face Hub.
    """
    try:
        # List models from the specified author, sorted by downloads
        models = list(list_models(author=author, sort="downloads", direction=-1, limit=1))

        if models:
            top_model = models[0]
            return f"The most downloaded model by {author} is {top_model.id} with {top_model.downloads:,} downloads."
        else:
            return f"No models found for author {author}."
    except Exception as e:
        return f"Error fetching models for {author}: {str(e)}"

# Wrap the function as a LangChain Tool so the agent can call it
hub_stats_tool = Tool.from_function(
    func=get_hub_stats,
    name="get_hub_stats",
    description="Fetches the most downloaded model for a given author on the Hugging Face Hub."
)
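
# The wrapped tool takes a single string argument, so it can be exercised
# directly, e.g. (illustrative): hub_stats_tool.invoke("google")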

def predict(message, history):
    # Convert Gradio history to LangChain message format
    history_langchain_format = []
    for msg in history:
        if msg['role'] == "user":
            history_langchain_format.append(HumanMessage(content=msg['content']))
        elif msg['role'] == "assistant":
            history_langchain_format.append(AIMessage(content=msg['content']))
    
    # Add new user message
    history_langchain_format.append(HumanMessage(content=message))
    
    # Invoke Alfred agent with full message history
    response = alfred.invoke(
        input={"messages": history_langchain_format},
        config={"recursion_limit": 100}
    )
    
    # Extract final assistant message
    return response["messages"][-1].content
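
# Note: with `type="messages"`, Gradio passes `history` as OpenAI-style dicts,
# e.g. [{"role": "user", "content": "hi"}], which the loop above converts into
# LangChain message objects.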


# Set up the agent: bind the tools to the chat model
tools = [hub_stats_tool]
chat_with_tools = model.bind_tools(tools)

# Define the agent state and build the agent graph
class AgentState(TypedDict):
    messages: Annotated[list[BaseMessage], add_messages]

def assistant(state: AgentState):
    return {
        "messages": [chat_with_tools.invoke(state["messages"])],
    }
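
# The `add_messages` reducer on AgentState appends each node's returned
# messages to the history instead of overwriting it, so every pass through
# assistant -> tools -> assistant sees the full conversation.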

## The graph
builder = StateGraph(AgentState)

# Define nodes: these do the work
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))

# Define edges: these determine how the control flow moves
builder.add_edge(START, "assistant")
builder.add_conditional_edges(
    "assistant",
    # If the latest message contains tool calls, route to the tools node;
    # otherwise, end the turn with a direct response.
    tools_condition,
    # tools_condition returns either "tools" or END ("__end__"), which must
    # match the path-map keys.
    {"tools": "tools", END: END}
)
builder.add_edge("tools", "assistant")
alfred = builder.compile()
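
# Quick end-to-end check of the compiled graph (illustrative query; hits the
# Hugging Face Hub, so it needs network access):
# result = alfred.invoke({"messages": [HumanMessage(content="What is the most downloaded model by google?")]})
# print(result["messages"][-1].content)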


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    predict,
    type="messages"
)


if __name__ == "__main__":
    demo.launch()