ofermend commited on
Commit
4792c87
·
1 Parent(s): 1770a97
Files changed (4) hide show
  1. agent.py +34 -5
  2. app.py +4 -1
  3. requirements.txt +1 -1
  4. st_app.py +26 -11
agent.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import pandas as pd
3
  import requests
 
4
  from pydantic import Field, BaseModel
5
 
6
  from omegaconf import OmegaConf
@@ -11,7 +12,6 @@ from vectara_agentic.tools import ToolsFactory, VectaraToolFactory
11
  from dotenv import load_dotenv
12
  load_dotenv(override=True)
13
 
14
-
15
  tickers = {
16
  "C": "Citigroup",
17
  "COF": "Capital One",
@@ -51,9 +51,10 @@ def create_assistant_tools(cfg):
51
  return years
52
 
53
  # Tool to get the income statement for a given company and year using the FMP API
 
54
  def fmp_income_statement(
55
- ticker: str = Field(description="the ticker symbol of the company."),
56
- year: int = Field(description="the year for which to get the income statement."),
57
  ) -> str:
58
  """
59
  Get the income statement for a given company and year using the FMP (https://financialmodelingprep.com) API.
@@ -80,7 +81,11 @@ def create_assistant_tools(cfg):
80
 
81
  class QueryTranscriptsArgs(BaseModel):
82
  query: str = Field(..., description="The user query, always in the form of a question", examples=["what are the risks reported?", "who are the competitors?"])
83
- year: int | str = Field(..., description=f"The year this query relates to. An integer between {min(years)} and {max(years)} or a string specifying a condition on the year (example: '>2020').")
 
 
 
 
84
  ticker: str = Field(..., description=f"The company ticker this query relates to. Must be a valid ticket symbol from the list {list(tickers.keys())}.")
85
 
86
  vec_factory = VectaraToolFactory(vectara_api_key=cfg.api_key,
@@ -99,6 +104,27 @@ def create_assistant_tools(cfg):
99
  summary_num_results = 10,
100
  vectara_summarizer = summarizer,
101
  include_citations = True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  )
103
 
104
  tools_factory = ToolsFactory()
@@ -111,7 +137,7 @@ def create_assistant_tools(cfg):
111
  ]
112
  ] +
113
  tools_factory.financial_tools() +
114
- [ask_transcripts]
115
  )
116
 
117
  def initialize_agent(_cfg, agent_progress_callback=None):
@@ -129,12 +155,15 @@ def initialize_agent(_cfg, agent_progress_callback=None):
129
  - When querying a tool for a numeric value or KPI, use a concise and non-ambiguous description of what you are looking for.
130
  - If you calculate a metric, make sure you have all the necessary information to complete the calculation. Don't guess.
131
  """
 
 
132
 
133
  agent = Agent(
134
  tools=create_assistant_tools(_cfg),
135
  topic="Financial data, annual reports and 10-K filings",
136
  custom_instructions=financial_bot_instructions,
137
  agent_progress_callback=agent_progress_callback,
 
138
  )
139
  agent.report()
140
  return agent
 
1
  import os
2
  import pandas as pd
3
  import requests
4
+ from functools import lru_cache
5
  from pydantic import Field, BaseModel
6
 
7
  from omegaconf import OmegaConf
 
12
  from dotenv import load_dotenv
13
  load_dotenv(override=True)
14
 
 
15
  tickers = {
16
  "C": "Citigroup",
17
  "COF": "Capital One",
 
51
  return years
52
 
53
  # Tool to get the income statement for a given company and year using the FMP API
54
+ @lru_cache(maxsize=128)
55
  def fmp_income_statement(
56
+ ticker: str = Field(description="the ticker symbol of the company.", examples=["AAPL", "GOOG", "AMZN"]),
57
+ year: int = Field(description="the year for which to get the income statement.", examples=[2020, 2021, 2022]),
58
  ) -> str:
59
  """
60
  Get the income statement for a given company and year using the FMP (https://financialmodelingprep.com) API.
 
81
 
82
  class QueryTranscriptsArgs(BaseModel):
83
  query: str = Field(..., description="The user query, always in the form of a question", examples=["what are the risks reported?", "who are the competitors?"])
84
+ year: int | str = Field(
85
+ default=None,
86
+ description=f"The year this query relates to. An integer between {min(years)} and {max(years)} or a string specifying a condition on the year",
87
+ examples=[2020, '>2021', '<2023', '>=2021', '<=2023', '[2021, 2023]', '[2021, 2023)']
88
+ )
89
  ticker: str = Field(..., description=f"The company ticker this query relates to. Must be a valid ticket symbol from the list {list(tickers.keys())}.")
90
 
91
  vec_factory = VectaraToolFactory(vectara_api_key=cfg.api_key,
 
104
  summary_num_results = 10,
105
  vectara_summarizer = summarizer,
106
  include_citations = True,
107
+ verbose=True,
108
+ )
109
+
110
+ class SearchTranscriptsArgs(BaseModel):
111
+ query: str = Field(..., description="The user query, always in the form of a question", examples=["what are the risks reported?", "who are the competitors?"])
112
+ top_k: int = Field(..., description="The number of results to return.")
113
+ year: int | str = Field(
114
+ default=None,
115
+ description=f"The year this query relates to. An integer between {min(years)} and {max(years)} or a string specifying a condition on the year",
116
+ examples=[2020, '>2021', '<2023', '>=2021', '<=2023', '[2021, 2023]', '[2021, 2023)']
117
+ )
118
+ ticker: str = Field(..., description=f"The company ticker this query relates to. Must be a valid ticket symbol from the list {list(tickers.keys())}.")
119
+ search_transcripts = vec_factory.create_search_tool(
120
+ tool_name = "search_transcripts",
121
+ tool_description = """
122
+ Given a company name and year, and a user query, retrieves the most relevant text from analyst call transcripts about the company related to the user query.
123
+ """,
124
+ tool_args_schema = QueryTranscriptsArgs,
125
+ reranker = "multilingual_reranker_v1", rerank_k = 100,
126
+ lambda_val = 0.005,
127
+ verbose=True
128
  )
129
 
130
  tools_factory = ToolsFactory()
 
137
  ]
138
  ] +
139
  tools_factory.financial_tools() +
140
+ [ask_transcripts, search_transcripts]
141
  )
142
 
143
  def initialize_agent(_cfg, agent_progress_callback=None):
 
155
  - When querying a tool for a numeric value or KPI, use a concise and non-ambiguous description of what you are looking for.
156
  - If you calculate a metric, make sure you have all the necessary information to complete the calculation. Don't guess.
157
  """
158
+ def query_logging(query: str, response: str):
159
+ print(f"Logging query={query}, response={response}")
160
 
161
  agent = Agent(
162
  tools=create_assistant_tools(_cfg),
163
  topic="Financial data, annual reports and 10-K filings",
164
  custom_instructions=financial_bot_instructions,
165
  agent_progress_callback=agent_progress_callback,
166
+ query_logging_callback=query_logging,
167
  )
168
  agent.report()
169
  return agent
app.py CHANGED
@@ -12,7 +12,10 @@ if 'device_id' not in st.session_state:
12
  if "feedback_key" not in st.session_state:
13
  st.session_state.feedback_key = 0
14
 
 
 
 
15
  if __name__ == "__main__":
16
  st.set_page_config(page_title="Financial Assistant", layout="wide")
17
  nest_asyncio.apply()
18
- asyncio.run(launch_bot())
 
12
  if "feedback_key" not in st.session_state:
13
  st.session_state.feedback_key = 0
14
 
15
+ async def main():
16
+ await launch_bot()
17
+
18
  if __name__ == "__main__":
19
  st.set_page_config(page_title="Financial Assistant", layout="wide")
20
  nest_asyncio.apply()
21
+ asyncio.run(main())
requirements.txt CHANGED
@@ -6,4 +6,4 @@ streamlit_feedback==0.1.3
6
  uuid==1.30
7
  langdetect==1.0.9
8
  langcodes==3.4.0
9
- vectara-agentic==0.1.22
 
6
  uuid==1.30
7
  langdetect==1.0.9
8
  langcodes==3.4.0
9
+ vectara-agentic==0.1.27
st_app.py CHANGED
@@ -13,15 +13,6 @@ from agent import initialize_agent, get_agent_config
13
 
14
  initial_prompt = "How can I help you today?"
15
 
16
- def show_example_questions():
17
- if len(st.session_state.example_messages) > 0 and st.session_state.first_turn:
18
- selected_example = pills("Queries to Try:", st.session_state.example_messages, index=None)
19
- if selected_example:
20
- st.session_state.ex_prompt = selected_example
21
- st.session_state.first_turn = False
22
- return True
23
- return False
24
-
25
  def format_log_msg(log_msg: str):
26
  max_log_msg_size = 500
27
  return log_msg if len(log_msg) <= max_log_msg_size else log_msg[:max_log_msg_size]+'...'
@@ -37,7 +28,13 @@ def agent_progress_callback(status_type: AgentStatusType, msg: str):
37
  latest_message = f"Calling tool {tool_name}..."
38
  elif status_type == AgentStatusType.TOOL_OUTPUT:
39
  latest_message = "Analyzing tool output..."
 
 
 
 
 
40
  else:
 
41
  return
42
 
43
  st.session_state.status.update(label=latest_message)
@@ -46,6 +43,16 @@ def agent_progress_callback(status_type: AgentStatusType, msg: str):
46
  for log_msg in st.session_state.log_messages:
47
  st.markdown(format_log_msg(log_msg), unsafe_allow_html=True)
48
 
 
 
 
 
 
 
 
 
 
 
49
  @st.dialog(title="Agent logs", width='large')
50
  def show_modal():
51
  for log_msg in st.session_state.log_messages:
@@ -132,8 +139,16 @@ async def launch_bot():
132
  if st.session_state.prompt:
133
  with st.chat_message("assistant", avatar='🤖'):
134
  st.session_state.status = st.status('Processing...', expanded=False)
135
- res = st.session_state.agent.chat(st.session_state.prompt)
136
- res = escape_dollars_outside_latex(res)
 
 
 
 
 
 
 
 
137
  message = {"role": "assistant", "content": res, "avatar": '🤖'}
138
  st.session_state.messages.append(message)
139
  st.markdown(res)
 
13
 
14
  initial_prompt = "How can I help you today?"
15
 
 
 
 
 
 
 
 
 
 
16
  def format_log_msg(log_msg: str):
17
  max_log_msg_size = 500
18
  return log_msg if len(log_msg) <= max_log_msg_size else log_msg[:max_log_msg_size]+'...'
 
28
  latest_message = f"Calling tool {tool_name}..."
29
  elif status_type == AgentStatusType.TOOL_OUTPUT:
30
  latest_message = "Analyzing tool output..."
31
+ elif status_type == AgentStatusType.AGENT_UPDATE:
32
+ if "Thought:" in msg:
33
+ latest_message = "Thinking..."
34
+ else:
35
+ latest_message = "Updating agent..."
36
  else:
37
+ print(f"callback with {status_type} and {msg}")
38
  return
39
 
40
  st.session_state.status.update(label=latest_message)
 
43
  for log_msg in st.session_state.log_messages:
44
  st.markdown(format_log_msg(log_msg), unsafe_allow_html=True)
45
 
46
+
47
+ def show_example_questions():
48
+ if len(st.session_state.example_messages) > 0 and st.session_state.first_turn:
49
+ selected_example = pills("Queries to Try:", st.session_state.example_messages, index=None)
50
+ if selected_example:
51
+ st.session_state.ex_prompt = selected_example
52
+ st.session_state.first_turn = False
53
+ return True
54
+ return False
55
+
56
  @st.dialog(title="Agent logs", width='large')
57
  def show_modal():
58
  for log_msg in st.session_state.log_messages:
 
139
  if st.session_state.prompt:
140
  with st.chat_message("assistant", avatar='🤖'):
141
  st.session_state.status = st.status('Processing...', expanded=False)
142
+ response = st.session_state.agent.chat(st.session_state.prompt)
143
+ res = escape_dollars_outside_latex(response.response)
144
+
145
+ #response = await st.session_state.agent.achat(st.session_state.prompt)
146
+ #res = escape_dollars_outside_latex(response.response)
147
+
148
+ #res = await st.session_state.agent.astream_chat(st.session_state.prompt)
149
+ #response = ''.join([token async for token in res.async_response_gen()])
150
+ #res = escape_dollars_outside_latex(response)
151
+
152
  message = {"role": "assistant", "content": res, "avatar": '🤖'}
153
  st.session_state.messages.append(message)
154
  st.markdown(res)