Spaces:
Sleeping
Sleeping
import gradio as gr | |
from huggingface_hub import login | |
from smolagents import HfApiModel, Tool, CodeAgent | |
import os | |
import sys | |
import json | |
if './lib' not in sys.path : | |
sys.path.append('./lib') | |
from ingestion_chroma import retrieve_info_from_db | |
############################################################################################ | |
################################### TOOLS ################################################## | |
############################################################################################ | |
def search_key(d, target_key):
    """Return the first value stored under ``target_key`` in a nested structure.

    The structure is walked depth-first in insertion order. A key is checked
    before its value is descended into, so an outer match wins over a match
    nested inside its own value. Dictionaries and lists are followed at any
    nesting depth (including lists directly inside lists and a list passed as
    the top-level argument); other value types are not inspected.

    :param d: The (possibly nested) dictionary or list to search.
    :param target_key: The key to look for.
    :return: ``str()`` of the first matching value, or a fixed
        "Indicator not found" help message when the key is absent.
    """
    # Sentinel distinguishes "no match" from a legitimately stored None value.
    missing = object()

    def iter_matches(node):
        # Lazily yield every value stored under target_key, in pre-order.
        if isinstance(node, dict):
            for key, value in node.items():
                if key == target_key:
                    yield value
                yield from iter_matches(value)
        elif isinstance(node, list):
            for item in node:
                yield from iter_matches(item)

    # Only the first match is needed, so stop walking as soon as it is found.
    first_match = next(iter_matches(d), missing)
    if first_match is missing:
        return "Indicator not found. Try globals indicators in this list : ['ESRS E4', 'ESRS 2 MDR', 'ESRS S2', 'ESRS E2', 'ESRS S4', 'ESRS E5', 'ESRS 2', 'ESRS E1', 'ESRS S3', 'ESRS S1', 'ESRS G1', 'ESRS E3']"
    return str(first_match)
############################################################################################ | |
class Chroma_retrieverTool(Tool):
    """Agent tool that looks up passages in the document database.

    The query is forwarded to ``retrieve_info_from_db`` and the returned
    passages are formatted into a single numbered text block for the agent.
    """

    name = "request"
    description = "Using semantic similarity, retrieve the text from the knowledge base that has the embedding closest to the query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to execute must be semantically close to the text to search. Use the affirmative form rather than a question.",
        },
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the passages retrieved for ``query`` as one formatted string.

        :param query: Free-text query passed unchanged to the retriever.
        :return: A header line followed by one ``===== Text i =====`` section
            per retrieved passage, each section on its own lines.
        """
        assert isinstance(query, str), "The request needs to be a string."
        query_results = retrieve_info_from_db(query)
        # Only the first entry of 'documents' is used; presumably it holds the
        # passages for the single query sent above (verify against the
        # retriever implementation).
        documents = query_results['documents'][0]
        sections = [
            f"===== Text {i} =====\n" + document
            for i, document in enumerate(documents)
        ]
        # Join with a newline so the next section header is never glued to the
        # end of the previous passage when that passage lacks a trailing "\n".
        return "\nRetrieval texts : \n" + "\n".join(sections)
############################################################################################ | |
class ESRS_info_tool(Tool):
    """Agent tool that returns the stored description of an ESRS indicator.

    Descriptions are read from ``./data/dico_esrs.json`` on every call and
    looked up with ``search_key``.
    """

    name = "find_ESRS"
    description = "Find ESRS description to help you to find what indicators the user want to analyze"
    inputs = {
        "indicator": {
            "type": "string",
            "description": "The indicator name with format for example like following 'ESRS EX' or 'EX'. return the description of the indicator demanded.",
        },
    }
    output_type = "string"

    def forward(self, indicator: str) -> str:
        """Look up ``indicator`` in the ESRS dictionary file.

        :param indicator: Indicator name used as the key to search for.
        :return: The string produced by ``search_key``: the matching entry,
            or its "Indicator not found" help message.
        """
        assert isinstance(indicator, str), "The request needs to be a string."
        # JSON text is UTF-8; pass the encoding explicitly so that non-ASCII
        # descriptions do not depend on the platform's default locale.
        with open('./data/dico_esrs.json', encoding='utf-8') as json_data:
            dico_esrs = json.load(json_data)
        return search_key(dico_esrs, indicator)
############################################################################################ | |
############################################################################################ | |
############################################################################################ | |
def respond(message,
            history: list[tuple[str, str]],
            system_message,
            max_tokens,
            temperature,
            top_p,):
    """Answer one chat message by running the module-level agent.

    A fixed instruction prefix is prepended to the user's message, the agent
    is run once on the combined prompt, and its result is yielded as the only
    item of this generator.

    ``history``, ``system_message``, ``max_tokens``, ``temperature`` and
    ``top_p`` are accepted to match the chat interface's callback signature
    but are not used here.
    """
    prompt_prefix = """You are an expert in environmental and corporate social responsibility. You must respond to requests using the query function in the document database.
User's question : """
    full_prompt = prompt_prefix + message
    yield agent.run(full_prompt)
############################################################################################ | |
# Authenticate with the Hugging Face Hub using the token stored in the
# HF_TOKEN_all environment variable.
hf_token = os.environ.get("HF_TOKEN_all")
login(hf_token)

# Tools exposed to the agent: ESRS indicator lookup and document retrieval.
get_ESRS_info_tool = ESRS_info_tool()
retriever_tool = Chroma_retrieverTool()

# Hosted model that drives the agent.
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")

# Code-writing agent limited to 10 steps; the generated code may import only
# the extra modules listed below.
agent = CodeAgent(
    tools=[get_ESRS_info_tool, retriever_tool],
    model=model,
    max_steps=10,
    max_print_outputs_length=16000,
    additional_authorized_imports=['pandas', 'matplotlib', 'datetime'],
)
""" | |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
""" | |
# Chat UI wired to `respond`. The extra widgets below are passed to `respond`
# as additional positional arguments after the message and the history.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly Chatbot.",
            label="System message",
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()