import gradio as gr
from huggingface_hub import login
from smolagents import HfApiModel, Tool, CodeAgent
import os
import sys
import json
if './lib' not in sys.path:
    sys.path.append('./lib')
from ingestion_chroma import retrieve_info_from_db
############################################################################################
################################### TOOLS ##################################################
############################################################################################
def search_key(d, target_key):
    """
    Search for a key in a nested dictionary.

    :param d: The dictionary to search in.
    :param target_key: The key to look for.
    :return: The first value found for the key, as a string, or a fallback message if the key is absent.
    """
    results = []

    def recursive_search(d):
        if isinstance(d, dict):
            for key, value in d.items():
                if key == target_key:
                    results.append(value)
                if isinstance(value, dict):
                    recursive_search(value)
                elif isinstance(value, list):
                    for item in value:
                        if isinstance(item, dict):
                            recursive_search(item)

    recursive_search(d)

    if len(results) > 0:
        return str(results[0])
    else:
        return "Indicator not found. Try global indicators from this list: ['ESRS E4', 'ESRS 2 MDR', 'ESRS S2', 'ESRS E2', 'ESRS S4', 'ESRS E5', 'ESRS 2', 'ESRS E1', 'ESRS S3', 'ESRS S1', 'ESRS G1', 'ESRS E3']"
############################################################################################
class Chroma_retrieverTool(Tool):
    name = "request"
    description = "Using semantic similarity, retrieve the text from the knowledge base whose embedding is closest to the query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to execute; it must be semantically close to the text being searched for. Use the affirmative form rather than a question.",
        },
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "The query needs to be a string."
        query_results = retrieve_info_from_db(query)
        documents = query_results['documents'][0]
        str_result = "\nRetrieved texts:\n" + "".join(
            f"===== Text {i} =====\n" + documents[i] for i in range(len(documents))
        )
        return str_result
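
# Note: retrieve_info_from_db (from ingestion_chroma) is assumed to return a Chroma-style query
# result, i.e. a dict whose 'documents' entry is a list of lists of matching text chunks
# (one inner list per query), which is how forward() indexes it above.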
############################################################################################
class ESRS_info_tool(Tool):
    name = "find_ESRS"
    description = "Find the description of an ESRS indicator to help determine which indicators the user wants to analyze."
    inputs = {
        "indicator": {
            "type": "string",
            "description": "The indicator name, formatted for example as 'ESRS EX' or 'EX'. Returns the description of the requested indicator.",
        },
    }
    output_type = "string"

    def forward(self, indicator: str) -> str:
        assert isinstance(indicator, str), "The indicator needs to be a string."
        with open('./data/dico_esrs.json') as json_data:
            dico_esrs = json.load(json_data)
        result = search_key(dico_esrs, indicator)
        return result
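
# Illustrative usage (hypothetical JSON content): if ./data/dico_esrs.json contains
# {"ESRS E1": "Climate change"}, then ESRS_info_tool().forward("ESRS E1") returns "Climate change".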
############################################################################################
############################################################################################
############################################################################################
def respond(message,
            history: list[tuple[str, str]],
            system_message,
            max_tokens,
            temperature,
            top_p):
    # The UI generation settings (system_message, max_tokens, temperature, top_p) are not
    # forwarded to the agent; the CodeAgent below runs its own reasoning loop.
    system_prompt_added = """You are an expert in environmental and corporate social responsibility. You must respond to requests using the query function in the document database.
User's question: """
    agent_output = agent.run(system_prompt_added + message)
    yield agent_output
############################################################################################
hf_token = os.getenv("HF_TOKEN_all")
login(hf_token)
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
retriever_tool = Chroma_retrieverTool()
get_ESRS_info_tool = ESRS_info_tool()
agent = CodeAgent(
    tools=[
        get_ESRS_info_tool,
        retriever_tool,
    ],
    model=model,
    max_steps=10,
    max_print_outputs_length=16000,
    additional_authorized_imports=['pandas', 'matplotlib', 'datetime'],
)
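
# Minimal sketch for exercising the agent outside the Gradio UI (the question is illustrative):
# print(agent.run("What does the ESRS E1 indicator cover?"))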
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
if __name__ == "__main__":
    demo.launch()