import json
import logging
import os

import requests
import streamlit as st
from dotenv import load_dotenv
from google import genai
from google.genai import types

# Load variables from a local .env file (if any) into the process environment.
load_dotenv()

# Browser-tab title, icon and layout for the Streamlit page.
st.set_page_config(
    page_title="Flaix - Financial Assistant",
    page_icon="💰",  # money-bag emoji
    layout="centered",
)

# Gemini client, authenticated with the key kept in Streamlit secrets.
api_key = st.secrets["GOOGLE_API_KEY"]
client = genai.Client(api_key=api_key)

# Credentials and base URL for the Indian stock market data API.
INDIAN_API_KEY = st.secrets["FINANCE_KEY"]
INDIAN_API_BASE_URL = "https://stock.indianapi.in"


# Registry of the market-data calls the assistant may make, keyed by the
# function name the orchestrator model is asked to return.
#
# Keys of each entry:
#   endpoint        - path suffix appended to INDIAN_API_BASE_URL
#   required_params - caller-facing parameter names that must be supplied
#   param_mapping   - caller-facing name -> query-string name expected by the API
#   default_values  - query-string parameters filled in when the caller omits them
#   optional_params - informational only; no code visible in this file reads it
#   description     - human-readable summary of the call
API_ENDPOINTS = {
    "get_stock_details": {
        "endpoint": "/stock",
        "required_params": ["stock_name"],
        "param_mapping": {"stock_name": "name"},
        "description": "Get details for a specific stock",
    },
    "get_trending_stocks": {
        "endpoint": "/trending",
        "required_params": [],
        "param_mapping": {},
        "description": "Get trending stocks in the market",
    },
    "get_market_news": {
        "endpoint": "/news",
        "required_params": [],
        "param_mapping": {},
        "description": "Get latest stock market news",
    },
    "get_mutual_funds": {
        "endpoint": "/mutual_funds",
        "required_params": [],
        "param_mapping": {},
        "description": "Get mutual funds data",
    },
    "get_ipo_data": {
        "endpoint": "/ipo",
        "required_params": [],
        "param_mapping": {},
        "description": "Get IPO data",
    },
    "get_bse_most_active": {
        "endpoint": "/BSE_most_active",
        "required_params": [],
        "param_mapping": {},
        "description": "Get BSE most active stocks",
    },
    "get_nse_most_active": {
        "endpoint": "/NSE_most_active",
        "required_params": [],
        "param_mapping": {},
        "description": "Get NSE most active stocks",
    },
    "get_historical_data": {
        "endpoint": "/historical_data",
        "required_params": ["stock_name"],
        "optional_params": ["period"],
        "default_values": {"period": "1m", "filter": "default"},
        "param_mapping": {},
        "description": "Get historical data for a stock",
    },
}


def call_indian_api(endpoint, params=None, timeout=15):
    """
    Generic function to call the Indian Stock Market API.

    Args:
        endpoint: API endpoint suffix (e.g., '/stock', '/trending'); it is
            appended to INDIAN_API_BASE_URL.
        params: Optional query-string parameters for the API call.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout an unresponsive server would block the app indefinitely.

    Returns:
        The decoded JSON response from the API. If the request fails or the
        body is not valid JSON, a dict of the form {"error": "<message>"}.
    """
    url = f"{INDIAN_API_BASE_URL}{endpoint}"
    headers = {"X-Api-Key": INDIAN_API_KEY}

    try:
        response = requests.get(url, headers=headers, params=params, timeout=timeout)
        # The body is returned whatever the HTTP status, so an error payload
        # sent by the API still reaches the caller.
        return response.json()
    except (requests.RequestException, ValueError) as e:
        # RequestException: connection problems, timeouts, invalid requests.
        # ValueError: the body could not be decoded as JSON.
        return {"error": str(e)}


def call_api_by_name(api_name, **kwargs):
    """
    Call an API by its name from the API_ENDPOINTS dictionary.

    Args:
        api_name: Name of the API to call (key in API_ENDPOINTS).
        **kwargs: Parameters to pass to the API, using the caller-facing
            names; they are translated through the entry's "param_mapping".
            A value of None is treated as "not supplied".

    Returns:
        JSON response from the API, or {"error": "<message>"} when the API
        name is unknown or a required parameter is missing.
    """
    if api_name not in API_ENDPOINTS:
        return {"error": f"Unknown API: {api_name}"}

    api_info = API_ENDPOINTS[api_name]

    # requests leaves None-valued parameters out of the query string, so a
    # None (e.g. a JSON null from the model) must not satisfy the required
    # check nor prevent a default from being applied.
    supplied = {name: value for name, value in kwargs.items() if value is not None}

    for param in api_info.get("required_params", []):
        if param not in supplied:
            return {"error": f"Missing required parameter: {param}"}

    # Translate caller-facing names into the names the API expects;
    # unmapped names pass through unchanged.
    param_mapping = api_info.get("param_mapping", {})
    mapped_params = {
        param_mapping.get(name, name): value for name, value in supplied.items()
    }

    # Defaults only fill gaps; explicitly supplied values win.
    for name, value in api_info.get("default_values", {}).items():
        mapped_params.setdefault(name, value)

    return call_indian_api(api_info["endpoint"], mapped_params)


# Prompt text placed before the user's query when asking the model to route it.
_ORCHESTRATOR_PROMPT_HEAD = """
You are an orchestrator for a financial assistant specialized in Indian markets. Your job is to analyze user queries and determine if they need real-time market data.

IMPORTANT: Be very precise in your analysis. Only return TRUE for "needs_api" when the query EXPLICITLY asks for current market data, stock prices, or listings.

Examples where needs_api should be TRUE:
- "Show me the most active stocks on NSE today" → get_nse_most_active
- "What is the current price of Reliance?" → get_stock_details with stock_name="Reliance"
- "Tell me about trending stocks" → get_trending_stocks
- "What are the latest IPOs?" → get_ipo_data

Examples where needs_api should be FALSE:
- "What is compound interest?"
- "How should I start investing?"
- "What are the tax benefits of PPF?"
- "Explain mutual funds to me"

Available API functions:
- get_stock_details(stock_name): Get details for a specific stock
- get_trending_stocks(): Get trending stocks in the market
- get_market_news(): Get latest stock market news
- get_mutual_funds(): Get mutual funds data
- get_ipo_data(): Get IPO data
- get_bse_most_active(): Get BSE most active stocks
- get_nse_most_active(): Get NSE most active stocks
- get_historical_data(stock_name, period="1m"): Get historical data for a stock

User query: """

# Prompt text placed after the user's query: the reply format the parser expects.
_ORCHESTRATOR_PROMPT_TAIL = """

Respond in JSON format with the following structure:
{
    "needs_api": true/false,
    "function": "function_name_if_needed",
    "params": {
        "param1": "value1",
        "param2": "value2"
    }
}
"""


def _extract_json_text(text):
    """Return the JSON payload of a model reply, unwrapping a Markdown code fence if present."""
    if "```json" in text:
        return text.split("```json")[1].split("```")[0].strip()
    if "```" in text:
        return text.split("```")[1].strip()
    return text.strip()


def orchestrator(query):
    """
    Determine if the query requires market data and which API to call.

    The query is embedded in a routing prompt and sent to the Gemini model;
    the model's reply is parsed as JSON.

    Args:
        query: The user's question, as typed.

    Returns:
        A dict with the key "needs_api" and, when the model supplied them,
        "function" (an API name) and "params" (always a dict). If the reply
        cannot be parsed as a JSON object, {"needs_api": False} is returned
        so the caller answers without market data.
    """
    orchestrator_prompt = _ORCHESTRATOR_PROMPT_HEAD + query + _ORCHESTRATOR_PROMPT_TAIL

    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=orchestrator_prompt)],
        ),
    ]

    # Low temperature: routing should be as repeatable as possible.
    generate_content_config = types.GenerateContentConfig(
        temperature=0.2,
        top_p=0.95,
        top_k=40,
        max_output_tokens=500,
        response_mime_type="text/plain",
    )

    response = client.models.generate_content(
        model="gemini-1.5-flash",
        contents=contents,
        config=generate_content_config,
    )

    try:
        # response.text may be None when the model returned no text.
        decision = json.loads(_extract_json_text(response.text or ""))
        if not isinstance(decision, dict):
            raise ValueError("orchestrator reply is not a JSON object")
    except Exception as e:
        # Best-effort boundary: any parsing problem falls back to "no API".
        logging.getLogger(__name__).warning(
            "Error parsing orchestrator response: %s", e
        )
        return {"needs_api": False}

    # Callers unpack "params" with **, so it must be a dict even when the
    # model sent null or another JSON type.
    if not isinstance(decision.get("params"), dict):
        decision["params"] = {}

    return decision


# Persona and responsibilities given to the model at the start of every
# conversation. A plain string: the text has no placeholders to interpolate.
SYSTEM_PROMPT = """You are Flaix, a helpful and knowledgeable financial assistant designed specifically for Indian users. Your purpose is to improve financial literacy and provide guidance on investments in the Indian market.

Key responsibilities:
1. Explain financial concepts in simple, easy-to-understand language
2. Provide information about different investment options available in India (stocks, mutual funds, bonds, PPF, FDs, etc.)
3. Help users understand investment risks and returns
4. Explain tax implications of different investments in the Indian context
5. Guide users on how to start investing based on their goals and risk tolerance
6. Answer questions about market trends and financial news in India
"""


# Seed the conversation on the first run of a session: the system prompt is
# stored as a hidden "user" turn, followed by the assistant's greeting.
if "messages" not in st.session_state:
    _system_turn = {"role": "user", "content": SYSTEM_PROMPT}
    _greeting_turn = {
        "role": "model",
        "content": "Hello! I am Flaix, your financial assistant. You can ask me about investments, financial planning, or any other financial topic.",
    }
    st.session_state.messages = [_system_turn, _greeting_turn]

# Page header.
st.title("Flaix - Your Financial Assistant")
st.markdown("Ask any questions about investing, financial planning, or the Indian financial market.")

# Replay the stored conversation. Stored roles map to Streamlit chat avatars;
# the seeded system prompt is never shown.
_AVATAR_FOR_ROLE = {"user": "user", "model": "assistant"}

for _entry in st.session_state.messages:
    _role = _entry["role"]
    _text = _entry["content"]
    if _role not in _AVATAR_FOR_ROLE:
        continue
    if _role == "user" and _text == SYSTEM_PROMPT:
        continue
    with st.chat_message(_AVATAR_FOR_ROLE[_role]):
        st.write(_text)


# Number of most recent earlier messages replayed to the model as context.
HISTORY_WINDOW = 4

# Handle one user turn: route the query, optionally fetch market data, then
# stream the model's answer into the page and record it in the session.
if prompt := st.chat_input("Ask me anything about finance or investing..."):
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("user"):
        st.write(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""

        try:
            # Ask the orchestrator whether live market data is needed.
            decision = orchestrator(prompt)

            api_context = ""
            if decision.get("needs_api", False):
                function_name = decision.get("function", "")
                # The model may send null or a non-object for "params";
                # unpacking anything but a dict with ** would raise.
                params = decision.get("params")
                if not isinstance(params, dict):
                    params = {}

                if function_name in API_ENDPOINTS:
                    message_placeholder.write("Fetching real-time market data...")
                    api_result = call_api_by_name(function_name, **params)
                    api_context = f"\nHere is the real-time market data from the Indian Stock Market API:\n{json.dumps(api_result, indent=2)}\n\nPlease use this data to provide an informative response to the user's query."

            # Attach the fetched data to the user's question, if any.
            user_query = prompt
            if api_context:
                user_query = f"{prompt}\n\n[SYSTEM NOTE: {api_context}]"

            # Earlier turns: everything between the seeded system prompt
            # (index 0) and the prompt appended above (last item). Only the
            # most recent HISTORY_WINDOW entries are replayed, so the context
            # follows the conversation instead of staying on its first turns.
            system_message = SYSTEM_PROMPT
            recent_messages = st.session_state.messages[1:-1][-HISTORY_WINDOW:]
            if recent_messages:
                history_lines = []
                for entry in recent_messages:
                    if entry["role"] == "user" and entry["content"] != SYSTEM_PROMPT:
                        history_lines.append(f"User: {entry['content']}\n")
                    elif entry["role"] == "model":
                        history_lines.append(f"Assistant: {entry['content']}\n")
                conversation_history = "".join(history_lines)
                system_message += f"\n\nPrevious conversation:\n{conversation_history}"

            # The system prompt is sent as a first user turn, acknowledged by
            # a canned model turn, followed by the actual question.
            contents = [
                types.Content(
                    role="user",
                    parts=[types.Part.from_text(text=system_message)],
                ),
                types.Content(
                    role="model",
                    parts=[
                        types.Part.from_text(text="I understand my role as Flaix, a financial assistant for Indian users. I'll provide helpful information about investing and financial planning in simple language.")
                    ],
                ),
                types.Content(
                    role="user",
                    parts=[types.Part.from_text(text=user_query)],
                ),
            ]

            generate_content_config = types.GenerateContentConfig(
                temperature=0.7,
                top_p=0.95,
                top_k=40,
                max_output_tokens=8192,
                response_mime_type="text/plain",
            )

            response_stream = client.models.generate_content_stream(
                model="gemini-1.5-flash",
                contents=contents,
                config=generate_content_config,
            )

            for chunk in response_stream:
                # A chunk may carry no text (attribute missing or None);
                # concatenating None to a str would raise.
                chunk_text = getattr(chunk, "text", None)
                if chunk_text:
                    full_response += chunk_text
                    # Trailing block character acts as a typing cursor.
                    message_placeholder.write(full_response + "▌")

            # Final render without the cursor.
            message_placeholder.write(full_response)

        except Exception as e:
            # Top-level boundary for the turn: show the error and store a
            # generic apology as the assistant's reply.
            st.error(f"Error: {e}")
            full_response = "I'm sorry, I encountered an error. Please try again later."
            message_placeholder.write(full_response)

    st.session_state.messages.append({"role": "model", "content": full_response})