File size: 4,490 Bytes
4d5a5ba
b2a1c79
f878b1b
e49d5ca
 
2ce0b48
f878b1b
 
2ce0b48
 
 
4d5a5ba
2ce0b48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d5a5ba
2ce0b48
 
 
 
 
 
 
 
 
 
 
 
 
4d5a5ba
2ce0b48
 
 
 
4d5a5ba
2ce0b48
 
 
 
 
4d5a5ba
178fcdd
 
 
 
 
 
564da1a
 
 
 
 
 
 
4b1ac20
b2a1c79
 
4d5a5ba
2ce0b48
 
 
 
 
 
 
 
 
 
 
 
4d5a5ba
 
 
 
 
 
 
178fcdd
 
 
 
 
 
 
 
 
 
 
 
4d5a5ba
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gradio as gr
from huggingface_hub import login
from smolagents import HfApiModel, Tool, CodeAgent

import os
import sys
import json

if './lib' not in sys.path :
    sys.path.append('./lib')
from ingestion_chroma import retrieve_info_from_db

############################################################################################
################################### TOOLS ##################################################
############################################################################################

def find_key(data, target_key):
    """Return the value stored under ``target_key`` anywhere in a nested structure.

    The search is depth-first and pre-order. At each dict level, keys are
    visited in insertion order, and a key's value is searched before the next
    sibling key is examined. Lists and tuples are searched element by element,
    so keys nested inside JSON arrays are found too.

    Args:
        data: JSON-like data (dicts, lists/tuples, scalars).
        target_key: The key to look for.

    Returns:
        The value of the first matching key (it may itself be ``None`` or a
        nested dict/list), or the string ``"Indicator not found"`` when no key
        matches.
    """
    # Private sentinel so a stored value of None (JSON null) stays
    # distinguishable from "nothing found" during the recursion.
    missing = object()

    def _search(node):
        # Depth-first search; returns ``missing`` when no key matches.
        if isinstance(node, dict):
            for key, value in node.items():
                if key == target_key:
                    return value
                found = _search(value)
                if found is not missing:
                    return found
        elif isinstance(node, (list, tuple)):
            for item in node:
                found = _search(item)
                if found is not missing:
                    return found
        return missing

    result = _search(data)
    # The user-facing "not found" message is produced only at the top level.
    # Returning it from recursive calls would make a failed subtree look like
    # a hit and abort the search of sibling keys.
    return "Indicator not found" if result is missing else result

############################################################################################

class Chroma_retrieverTool(Tool):
    """smolagents tool that queries the knowledge base by semantic similarity.

    The agent calls it as ``request(query=...)``. The retrieved documents are
    concatenated into one string, each introduced by a numbered header.
    """

    name = "request"
    description = "Using semantic similarity, retrieve the text from the knowledge base that has the embedding closest to the query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to execute must be semantically close to the text to search. Use the affirmative form rather than a question.",
        },
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Retrieve the knowledge-base texts closest to ``query``.

        Args:
            query: Text used for the similarity search.

        Returns:
            A "Retrieval texts" header followed by one numbered
            "===== Text i =====" section per retrieved document, each section
            ending with a newline.

        Raises:
            TypeError: If ``query`` is not a string.
        """
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if not isinstance(query, str):
            raise TypeError("The request needs to be a string.")

        query_results = retrieve_info_from_db(query)
        # Presumably a Chroma query result where ``documents`` holds one list of
        # texts per query and a single query was issued. TODO: confirm against
        # lib/ingestion_chroma.
        documents = query_results['documents'][0]
        # End every document with a newline so the next header starts on its
        # own line instead of being glued to the previous text.
        sections = [
            f"===== Text {i} =====\n{document}\n"
            for i, document in enumerate(documents)
        ]
        return "\nRetrieval texts : \n" + "".join(sections)
        
############################################################################################

class ESRS_info_tool(Tool):
    """smolagents tool that looks up an ESRS indicator in ``./data/dico_esrs.json``."""

    name = "find_ESRS"
    description = "Find ESRS description to help you to find what indicators the user want"
    inputs = {
        "indicator": {
            "type": "string",
            "description": "The indicator name. return the description of the indicator demanded.",
        },
    }
    output_type = "string"

    def forward(self, indicator: str) -> str:
        """Return the entry stored under ``indicator`` in the ESRS JSON file.

        Args:
            indicator: Key searched for anywhere in the (possibly nested) JSON.

        Returns:
            Whatever ``find_key`` yields for that key: the stored value, or
            ``"Indicator not found"``. NOTE(review): the stored value may be a
            dict or list rather than a string, despite ``output_type``.

        Raises:
            TypeError: If ``indicator`` is not a string.
            FileNotFoundError: If the JSON file is missing. The path is
                relative to the current working directory.
        """
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if not isinstance(indicator, str):
            raise TypeError("The request needs to be a string.")

        # JSON is UTF-8 by specification. Relying on the platform default
        # encoding (e.g. cp1252 on Windows) would mis-decode or fail on
        # non-ASCII characters.
        with open('./data/dico_esrs.json', encoding='utf-8') as json_data:
            dico_esrs = json.load(json_data)

        return find_key(dico_esrs, indicator)
        
############################################################################################
############################################################################################
############################################################################################

def respond(message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,):
    """Gradio ChatInterface callback: answer ``message`` with the module-level agent.

    Only ``message`` is used. ``history`` and the extra UI inputs
    (``system_message``, ``max_tokens``, ``temperature``, ``top_p``) are
    accepted because ``gr.ChatInterface`` passes them, but they are ignored
    here. NOTE(review): the UI's system-message box and sliders therefore have
    no effect on the answer.

    Yields:
        The return value of ``agent.run``, exactly once.
    """
    # Fixed instruction text placed in front of every user question.
    prompt_prefix = """You are an expert in environmental and corporate social responsibility. You must respond to requests using the query function in the document database. 
User's question : """
    yield agent.run(prompt_prefix + message)
    
############################################################################################
# Authenticate with the Hugging Face Hub only when a token is configured.
# With token=None, huggingface_hub's login() falls back to an interactive
# prompt, which cannot be answered in a headless deployment. Without a token,
# the library's own credential lookup is used instead.
hf_token = os.getenv("HF_TOKEN_all")
if hf_token:
    login(token=hf_token)
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")

retriever_tool = Chroma_retrieverTool()
get_ESRS_info_tool = ESRS_info_tool()
# Code-writing agent with access to both tools. Generated code may also
# import the listed extra modules.
agent = CodeAgent(
    tools=[
        get_ESRS_info_tool,
        retriever_tool,
    ],
    model=model,
    max_steps=10,
    max_print_outputs_length=16000,
    additional_authorized_imports=['pandas', 'matplotlib', 'datetime'],
)


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()