import gradio as gr
from huggingface_hub import login
from smolagents import HfApiModel, Tool, CodeAgent
import os
import sys
import json
# Make the local ./lib package importable before importing from it.
if './lib' not in sys.path:
    sys.path.append('./lib')
from ingestion_chroma import retrieve_info_from_db
############################################################################################
################################### TOOLS ##################################################
############################################################################################
def find_key(data, target_key):
    """Recursively search a nested dict for target_key and return its value, or None if it is absent."""
    if isinstance(data, dict):
        for key, value in data.items():
            if key == target_key:
                return value
            result = find_key(value, target_key)
            if result is not None:
                return result
    return None
############################################################################################
class Chroma_retrieverTool(Tool):
    name = "request"
    description = "Using semantic similarity, retrieve from the knowledge base the texts whose embeddings are closest to the query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to run; it must be semantically close to the text being searched for. Use the affirmative form rather than a question.",
        },
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "The query needs to be a string."
        query_results = retrieve_info_from_db(query)
        # Results are grouped per query; a single query was sent, so take index 0.
        documents = query_results['documents'][0]
        return "\nRetrieved texts:\n" + "".join(
            f"===== Text {i} =====\n{doc}\n" for i, doc in enumerate(documents)
        )
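# Note: retrieve_info_from_db is assumed to return a Chroma-style query result,
# i.e. a dict whose 'documents' entry is a list of document lists (one inner list
# per query), which is why forward() indexes ['documents'][0].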
############################################################################################
class ESRS_info_tool(Tool):
    name = "find_ESRS"
    description = "Find the description of an ESRS indicator, to help identify which indicator the user is asking about."
    inputs = {
        "indicator": {
            "type": "string",
            "description": "The name of the indicator whose description should be returned.",
        },
    }
    output_type = "string"

    def forward(self, indicator: str) -> str:
        assert isinstance(indicator, str), "The indicator needs to be a string."
        with open('./data/dico_esrs.json') as json_data:
            dico_esrs = json.load(json_data)
        result = find_key(dico_esrs, indicator)
        return result if result is not None else "Indicator not found"
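# The expected shape of ./data/dico_esrs.json is a (possibly nested) mapping from
# indicator codes to descriptions, e.g. (hypothetical excerpt):
#   {"ESRS E1": {"E1-1": "Transition plan for climate change mitigation", ...}}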
############################################################################################
############################################################################################
############################################################################################
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    system_prompt_added = """You are an expert in environmental and corporate social responsibility. You must respond to requests by querying the document database with the available tools.
User's question: """
    agent_output = agent.run(system_prompt_added + message)
    yield agent_output
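# Note: the system_message, max_tokens, temperature and top_p values supplied by
# the ChatInterface controls below are not currently forwarded to the agent.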
############################################################################################
hf_token = os.getenv("HF_TOKEN_all")
login(hf_token)
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
retriever_tool = Chroma_retrieverTool()
get_ESRS_info_tool = ESRS_info_tool()
agent = CodeAgent(
    tools=[
        get_ESRS_info_tool,
        retriever_tool,
    ],
    model=model,
    max_steps=10,
    max_print_outputs_length=16000,
    additional_authorized_imports=['pandas', 'matplotlib', 'datetime'],
)
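# The CodeAgent answers by writing and running short Python snippets in which each
# tool is exposed as a function named after its `name` attribute, roughly like this
# (hypothetical agent-generated code):
#   desc = find_ESRS(indicator="E1-1")
#   passages = request(query="climate transition plan disclosures")
#   final_answer(passages)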
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
if __name__ == "__main__":
    demo.launch()