Spaces:
Running
Running
updated
Browse files
agent.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import os
|
2 |
import pandas as pd
|
3 |
import requests
|
|
|
4 |
from pydantic import Field, BaseModel
|
5 |
|
6 |
from omegaconf import OmegaConf
|
@@ -11,7 +12,6 @@ from vectara_agentic.tools import ToolsFactory, VectaraToolFactory
|
|
11 |
from dotenv import load_dotenv
|
12 |
load_dotenv(override=True)
|
13 |
|
14 |
-
|
15 |
tickers = {
|
16 |
"C": "Citigroup",
|
17 |
"COF": "Capital One",
|
@@ -51,9 +51,10 @@ def create_assistant_tools(cfg):
|
|
51 |
return years
|
52 |
|
53 |
# Tool to get the income statement for a given company and year using the FMP API
|
|
|
54 |
def fmp_income_statement(
|
55 |
-
ticker: str = Field(description="the ticker symbol of the company."),
|
56 |
-
year: int = Field(description="the year for which to get the income statement."),
|
57 |
) -> str:
|
58 |
"""
|
59 |
Get the income statement for a given company and year using the FMP (https://financialmodelingprep.com) API.
|
@@ -80,7 +81,11 @@ def create_assistant_tools(cfg):
|
|
80 |
|
81 |
class QueryTranscriptsArgs(BaseModel):
|
82 |
query: str = Field(..., description="The user query, always in the form of a question", examples=["what are the risks reported?", "who are the competitors?"])
|
83 |
-
year: int | str = Field(
|
|
|
|
|
|
|
|
|
84 |
ticker: str = Field(..., description=f"The company ticker this query relates to. Must be a valid ticket symbol from the list {list(tickers.keys())}.")
|
85 |
|
86 |
vec_factory = VectaraToolFactory(vectara_api_key=cfg.api_key,
|
@@ -99,6 +104,27 @@ def create_assistant_tools(cfg):
|
|
99 |
summary_num_results = 10,
|
100 |
vectara_summarizer = summarizer,
|
101 |
include_citations = True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
)
|
103 |
|
104 |
tools_factory = ToolsFactory()
|
@@ -111,7 +137,7 @@ def create_assistant_tools(cfg):
|
|
111 |
]
|
112 |
] +
|
113 |
tools_factory.financial_tools() +
|
114 |
-
[ask_transcripts]
|
115 |
)
|
116 |
|
117 |
def initialize_agent(_cfg, agent_progress_callback=None):
|
@@ -129,12 +155,15 @@ def initialize_agent(_cfg, agent_progress_callback=None):
|
|
129 |
- When querying a tool for a numeric value or KPI, use a concise and non-ambiguous description of what you are looking for.
|
130 |
- If you calculate a metric, make sure you have all the necessary information to complete the calculation. Don't guess.
|
131 |
"""
|
|
|
|
|
132 |
|
133 |
agent = Agent(
|
134 |
tools=create_assistant_tools(_cfg),
|
135 |
topic="Financial data, annual reports and 10-K filings",
|
136 |
custom_instructions=financial_bot_instructions,
|
137 |
agent_progress_callback=agent_progress_callback,
|
|
|
138 |
)
|
139 |
agent.report()
|
140 |
return agent
|
|
|
1 |
import os
|
2 |
import pandas as pd
|
3 |
import requests
|
4 |
+
from functools import lru_cache
|
5 |
from pydantic import Field, BaseModel
|
6 |
|
7 |
from omegaconf import OmegaConf
|
|
|
12 |
from dotenv import load_dotenv
|
13 |
load_dotenv(override=True)
|
14 |
|
|
|
15 |
tickers = {
|
16 |
"C": "Citigroup",
|
17 |
"COF": "Capital One",
|
|
|
51 |
return years
|
52 |
|
53 |
# Tool to get the income statement for a given company and year using the FMP API
|
54 |
+
@lru_cache(maxsize=128)
|
55 |
def fmp_income_statement(
|
56 |
+
ticker: str = Field(description="the ticker symbol of the company.", examples=["AAPL", "GOOG", "AMZN"]),
|
57 |
+
year: int = Field(description="the year for which to get the income statement.", examples=[2020, 2021, 2022]),
|
58 |
) -> str:
|
59 |
"""
|
60 |
Get the income statement for a given company and year using the FMP (https://financialmodelingprep.com) API.
|
|
|
81 |
|
82 |
class QueryTranscriptsArgs(BaseModel):
|
83 |
query: str = Field(..., description="The user query, always in the form of a question", examples=["what are the risks reported?", "who are the competitors?"])
|
84 |
+
year: int | str = Field(
|
85 |
+
default=None,
|
86 |
+
description=f"The year this query relates to. An integer between {min(years)} and {max(years)} or a string specifying a condition on the year",
|
87 |
+
examples=[2020, '>2021', '<2023', '>=2021', '<=2023', '[2021, 2023]', '[2021, 2023)']
|
88 |
+
)
|
89 |
ticker: str = Field(..., description=f"The company ticker this query relates to. Must be a valid ticket symbol from the list {list(tickers.keys())}.")
|
90 |
|
91 |
vec_factory = VectaraToolFactory(vectara_api_key=cfg.api_key,
|
|
|
104 |
summary_num_results = 10,
|
105 |
vectara_summarizer = summarizer,
|
106 |
include_citations = True,
|
107 |
+
verbose=True,
|
108 |
+
)
|
109 |
+
|
110 |
+
class SearchTranscriptsArgs(BaseModel):
|
111 |
+
query: str = Field(..., description="The user query, always in the form of a question", examples=["what are the risks reported?", "who are the competitors?"])
|
112 |
+
top_k: int = Field(..., description="The number of results to return.")
|
113 |
+
year: int | str = Field(
|
114 |
+
default=None,
|
115 |
+
description=f"The year this query relates to. An integer between {min(years)} and {max(years)} or a string specifying a condition on the year",
|
116 |
+
examples=[2020, '>2021', '<2023', '>=2021', '<=2023', '[2021, 2023]', '[2021, 2023)']
|
117 |
+
)
|
118 |
+
ticker: str = Field(..., description=f"The company ticker this query relates to. Must be a valid ticket symbol from the list {list(tickers.keys())}.")
|
119 |
+
search_transcripts = vec_factory.create_search_tool(
|
120 |
+
tool_name = "search_transcripts",
|
121 |
+
tool_description = """
|
122 |
+
Given a company name and year, and a user query, retrieves the most relevant text from analyst call transcripts about the company related to the user query.
|
123 |
+
""",
|
124 |
+
tool_args_schema = QueryTranscriptsArgs,
|
125 |
+
reranker = "multilingual_reranker_v1", rerank_k = 100,
|
126 |
+
lambda_val = 0.005,
|
127 |
+
verbose=True
|
128 |
)
|
129 |
|
130 |
tools_factory = ToolsFactory()
|
|
|
137 |
]
|
138 |
] +
|
139 |
tools_factory.financial_tools() +
|
140 |
+
[ask_transcripts, search_transcripts]
|
141 |
)
|
142 |
|
143 |
def initialize_agent(_cfg, agent_progress_callback=None):
|
|
|
155 |
- When querying a tool for a numeric value or KPI, use a concise and non-ambiguous description of what you are looking for.
|
156 |
- If you calculate a metric, make sure you have all the necessary information to complete the calculation. Don't guess.
|
157 |
"""
|
158 |
+
def query_logging(query: str, response: str):
|
159 |
+
print(f"Logging query={query}, response={response}")
|
160 |
|
161 |
agent = Agent(
|
162 |
tools=create_assistant_tools(_cfg),
|
163 |
topic="Financial data, annual reports and 10-K filings",
|
164 |
custom_instructions=financial_bot_instructions,
|
165 |
agent_progress_callback=agent_progress_callback,
|
166 |
+
query_logging_callback=query_logging,
|
167 |
)
|
168 |
agent.report()
|
169 |
return agent
|
app.py
CHANGED
@@ -12,7 +12,10 @@ if 'device_id' not in st.session_state:
|
|
12 |
if "feedback_key" not in st.session_state:
|
13 |
st.session_state.feedback_key = 0
|
14 |
|
|
|
|
|
|
|
15 |
if __name__ == "__main__":
|
16 |
st.set_page_config(page_title="Financial Assistant", layout="wide")
|
17 |
nest_asyncio.apply()
|
18 |
-
asyncio.run(
|
|
|
12 |
if "feedback_key" not in st.session_state:
|
13 |
st.session_state.feedback_key = 0
|
14 |
|
15 |
+
async def main():
|
16 |
+
await launch_bot()
|
17 |
+
|
18 |
if __name__ == "__main__":
|
19 |
st.set_page_config(page_title="Financial Assistant", layout="wide")
|
20 |
nest_asyncio.apply()
|
21 |
+
asyncio.run(main())
|
requirements.txt
CHANGED
@@ -6,4 +6,4 @@ streamlit_feedback==0.1.3
|
|
6 |
uuid==1.30
|
7 |
langdetect==1.0.9
|
8 |
langcodes==3.4.0
|
9 |
-
vectara-agentic==0.1.
|
|
|
6 |
uuid==1.30
|
7 |
langdetect==1.0.9
|
8 |
langcodes==3.4.0
|
9 |
+
vectara-agentic==0.1.27
|
st_app.py
CHANGED
@@ -13,15 +13,6 @@ from agent import initialize_agent, get_agent_config
|
|
13 |
|
14 |
initial_prompt = "How can I help you today?"
|
15 |
|
16 |
-
def show_example_questions():
|
17 |
-
if len(st.session_state.example_messages) > 0 and st.session_state.first_turn:
|
18 |
-
selected_example = pills("Queries to Try:", st.session_state.example_messages, index=None)
|
19 |
-
if selected_example:
|
20 |
-
st.session_state.ex_prompt = selected_example
|
21 |
-
st.session_state.first_turn = False
|
22 |
-
return True
|
23 |
-
return False
|
24 |
-
|
25 |
def format_log_msg(log_msg: str):
|
26 |
max_log_msg_size = 500
|
27 |
return log_msg if len(log_msg) <= max_log_msg_size else log_msg[:max_log_msg_size]+'...'
|
@@ -37,7 +28,13 @@ def agent_progress_callback(status_type: AgentStatusType, msg: str):
|
|
37 |
latest_message = f"Calling tool {tool_name}..."
|
38 |
elif status_type == AgentStatusType.TOOL_OUTPUT:
|
39 |
latest_message = "Analyzing tool output..."
|
|
|
|
|
|
|
|
|
|
|
40 |
else:
|
|
|
41 |
return
|
42 |
|
43 |
st.session_state.status.update(label=latest_message)
|
@@ -46,6 +43,16 @@ def agent_progress_callback(status_type: AgentStatusType, msg: str):
|
|
46 |
for log_msg in st.session_state.log_messages:
|
47 |
st.markdown(format_log_msg(log_msg), unsafe_allow_html=True)
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
@st.dialog(title="Agent logs", width='large')
|
50 |
def show_modal():
|
51 |
for log_msg in st.session_state.log_messages:
|
@@ -132,8 +139,16 @@ async def launch_bot():
|
|
132 |
if st.session_state.prompt:
|
133 |
with st.chat_message("assistant", avatar='🤖'):
|
134 |
st.session_state.status = st.status('Processing...', expanded=False)
|
135 |
-
|
136 |
-
res = escape_dollars_outside_latex(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
message = {"role": "assistant", "content": res, "avatar": '🤖'}
|
138 |
st.session_state.messages.append(message)
|
139 |
st.markdown(res)
|
|
|
13 |
|
14 |
initial_prompt = "How can I help you today?"
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
def format_log_msg(log_msg: str):
|
17 |
max_log_msg_size = 500
|
18 |
return log_msg if len(log_msg) <= max_log_msg_size else log_msg[:max_log_msg_size]+'...'
|
|
|
28 |
latest_message = f"Calling tool {tool_name}..."
|
29 |
elif status_type == AgentStatusType.TOOL_OUTPUT:
|
30 |
latest_message = "Analyzing tool output..."
|
31 |
+
elif status_type == AgentStatusType.AGENT_UPDATE:
|
32 |
+
if "Thought:" in msg:
|
33 |
+
latest_message = "Thinking..."
|
34 |
+
else:
|
35 |
+
latest_message = "Updating agent..."
|
36 |
else:
|
37 |
+
print(f"callback with {status_type} and {msg}")
|
38 |
return
|
39 |
|
40 |
st.session_state.status.update(label=latest_message)
|
|
|
43 |
for log_msg in st.session_state.log_messages:
|
44 |
st.markdown(format_log_msg(log_msg), unsafe_allow_html=True)
|
45 |
|
46 |
+
|
47 |
+
def show_example_questions():
|
48 |
+
if len(st.session_state.example_messages) > 0 and st.session_state.first_turn:
|
49 |
+
selected_example = pills("Queries to Try:", st.session_state.example_messages, index=None)
|
50 |
+
if selected_example:
|
51 |
+
st.session_state.ex_prompt = selected_example
|
52 |
+
st.session_state.first_turn = False
|
53 |
+
return True
|
54 |
+
return False
|
55 |
+
|
56 |
@st.dialog(title="Agent logs", width='large')
|
57 |
def show_modal():
|
58 |
for log_msg in st.session_state.log_messages:
|
|
|
139 |
if st.session_state.prompt:
|
140 |
with st.chat_message("assistant", avatar='🤖'):
|
141 |
st.session_state.status = st.status('Processing...', expanded=False)
|
142 |
+
response = st.session_state.agent.chat(st.session_state.prompt)
|
143 |
+
res = escape_dollars_outside_latex(response.response)
|
144 |
+
|
145 |
+
#response = await st.session_state.agent.achat(st.session_state.prompt)
|
146 |
+
#res = escape_dollars_outside_latex(response.response)
|
147 |
+
|
148 |
+
#res = await st.session_state.agent.astream_chat(st.session_state.prompt)
|
149 |
+
#response = ''.join([token async for token in res.async_response_gen()])
|
150 |
+
#res = escape_dollars_outside_latex(response)
|
151 |
+
|
152 |
message = {"role": "assistant", "content": res, "avatar": '🤖'}
|
153 |
st.session_state.messages.append(message)
|
154 |
st.markdown(res)
|