File size: 3,810 Bytes
8176958
1f7a3b4
2ce0b48
 
 
 
8176958
2ce0b48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8176958
2ce0b48
 
 
 
 
 
 
 
 
 
 
 
 
8176958
2ce0b48
 
 
 
8176958
2ce0b48
 
 
 
 
8176958
2ce0b48
8176958
2ce0b48
 
 
 
 
 
 
 
 
 
 
 
8176958
 
2ce0b48
 
 
 
 
 
8176958
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import json
import sys

import gradio as gr
from smolagents import CodeAgent, HfApiModel, Tool

if './lib' not in sys.path :
    sys.path.append('./lib')
from ingestion_chroma import retrieve_info_from_db

############################################################################################
################################### TOOLS ##################################################
############################################################################################

def find_key(data, target_key):
    """Return the value stored under ``target_key`` anywhere inside ``data``.

    The structure is walked depth-first, following dictionary insertion
    order, and the first match wins. Nested dictionaries and lists are
    both explored.

    Args:
        data: Arbitrarily nested structure (typically parsed JSON).
        target_key: Dictionary key to look for.

    Returns:
        The value associated with the first occurrence of ``target_key``,
        or the string ``"Indicator not found"`` when the key is absent or
        ``data`` contains no dictionary at all.
    """
    # Private sentinel: distinguishes "nothing found in this branch" from a
    # legitimate stored value such as None or the not-found message itself.
    missing = object()

    def _search(node):
        if isinstance(node, dict):
            for key, value in node.items():
                if key == target_key:
                    return value
                found = _search(value)
                if found is not missing:
                    return found
        elif isinstance(node, list):
            for item in node:
                found = _search(item)
                if found is not missing:
                    return found
        # Keep scanning siblings in the caller instead of aborting the search.
        return missing

    result = _search(data)
    return "Indicator not found" if result is missing else result

############################################################################################

class Chroma_retrieverTool(Tool):
    """Agent tool that looks up knowledge-base texts for a query.

    The lookup itself is delegated to ``retrieve_info_from_db``; this class
    only formats the returned documents into a single string.
    """

    name = "request"
    description = "Using semantic similarity, retrieve the text from the knowledge base that has the embedding closest to the query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to execute must be semantically close to the text to search. Use the affirmative form rather than a question.",
        },
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the retrieved texts for ``query`` as one formatted string.

        Args:
            query: Text to search for in the knowledge base.

        Returns:
            A header line followed by every retrieved text, each preceded
            by a ``===== Text <index> =====`` marker.

        Raises:
            AssertionError: If ``query`` is not a string.
        """
        assert isinstance(query, str), "The request needs to be a string."

        # Only the first entry of 'documents' is read, i.e. the results
        # belonging to the single query that was submitted.
        retrieved_texts = retrieve_info_from_db(query)['documents'][0]
        sections = [
            f"===== Text {index} =====\n" + text
            for index, text in enumerate(retrieved_texts)
        ]

        return "\nRetrieval texts : \n" + "".join(sections)
        
############################################################################################

class ESRS_info_tool(Tool):
    """Agent tool that returns the stored description of an ESRS indicator.

    Descriptions are read from ``./data/dico_esrs.json`` on every call and
    searched with ``find_key``.
    """

    name = "find_ESRS"
    description = "Find ESRS description to help you to find what indicators the user want"
    inputs = {
        "indicator": {
            "type": "string",
            "description": "The indicator name. return the description of the indicator demanded.",
        },
    }
    output_type = "string"

    def forward(self, indicator: str) -> str:
        """Look up ``indicator`` in the ESRS dictionary file.

        Args:
            indicator: Name of the indicator, used as the key to search for.

        Returns:
            The matching entry as a string. Entries that are not plain
            strings (for example nested sections) are serialised to JSON
            so the result always matches the declared ``"string"`` output
            type. ``"Indicator not found"`` is returned when nothing matches.

        Raises:
            AssertionError: If ``indicator`` is not a string.
            OSError: If the dictionary file cannot be opened.
            json.JSONDecodeError: If the dictionary file is not valid JSON.
        """
        assert isinstance(indicator, str), "The request needs to be a string."

        # Explicit encoding: the platform default is locale-dependent and can
        # garble or reject non-ASCII characters in the JSON file.
        with open('./data/dico_esrs.json', encoding="utf-8") as json_data:
            dico_esrs = json.load(json_data)

        result = find_key(dico_esrs, indicator)
        if isinstance(result, str):
            return result

        # A key may map to a nested structure; hand it back as readable text.
        return json.dumps(result, ensure_ascii=False, indent=2)
        
############################################################################################
############################################################################################
############################################################################################

# Language model handed to the agent (Qwen2.5-Coder-32B-Instruct checkpoint).
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")

# One instance of each custom tool defined above.
retriever_tool = Chroma_retrieverTool()
get_ESRS_info_tool = ESRS_info_tool()

# Agent wired with both tools, capped at 10 steps, and allowed to import
# pandas, matplotlib and datetime in the code it runs.
agent = CodeAgent(
    tools=[get_ESRS_info_tool, retriever_tool],
    model=model,
    max_steps=10,
    max_print_outputs_length=16000,
    additional_authorized_imports=["pandas", "matplotlib", "datetime"],
)


def respond(message, history=None):
    """Answer a chat message by running the agent on it.

    Args:
        message: Text typed by the user in the chat box.
        history: Previous conversation turns. ``gr.ChatInterface`` passes
            this as a second positional argument; it is accepted so the
            call succeeds but is not forwarded to the agent.

    Yields:
        The value returned by ``agent.run`` for the assembled prompt.
    """
    system_prompt_added = """You are an expert in environmental and corporate social responsibility. You must respond to requests using the query function in the document database. 
User's question : """
    # Question used when the user submits nothing usable, so the agent is
    # never run on an empty prompt.
    default_question = """Find all informations about the ESRS E1–5: Energy consumption from fossil sources in Sartorius documents."""

    has_text = isinstance(message, str) and message.strip()
    question = message if has_text else default_question

    agent_output = agent.run(system_prompt_added + question)

    yield agent_output


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Chat UI whose replies are produced by `respond`.
demo = gr.ChatInterface(respond)


if __name__ == "__main__":
    # Start the web app only when this file is executed as a script.
    demo.launch()