# InvestGenie / app.py
# Page header: last commit 36edb30 "Update app.py" by zayeem00 (verified); file size 5.51 kB.
import logging
import os

import finnhub
import gradio as gr
import numpy as np
import pandas as pd
import yfinance as yf
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
# Setup logging: configure the root logger once for the whole app.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Initialize Finnhub client (for company profiles and basic financials).
# The API key comes from the environment instead of being hard-coded: a key committed
# to source is readable by anyone who can see the file, and the previously committed
# key should be rotated.
FINNHUB_API_KEY = os.environ.get("FINNHUB_API_KEY", "")
if not FINNHUB_API_KEY:
    logging.warning("FINNHUB_API_KEY is not set; Finnhub requests will likely be rejected.")
finnhub_client = finnhub.Client(api_key=FINNHUB_API_KEY)
# Define Shariah compliance rules
def is_shariah_compliant(row, max_debt_to_asset=0.33, max_interest_share=0.05):
    """Return True if *row* passes the Shariah screening rules, else False.

    The screen fails closed. A row is treated as non-compliant if it has a missing key,
    a missing/NaN value, or a non-positive total revenue, because compliance cannot be
    verified from incomplete data.

    Args:
        row: Mapping-like record (a dict, or a pandas Series as passed by
            ``DataFrame.apply(..., axis=1)``) with the keys ``debt_to_asset_ratio``,
            ``interest_income``, ``total_revenue``, ``alcohol_revenue``,
            ``gambling_revenue`` and ``tobacco_revenue``.
        max_debt_to_asset: Highest allowed debt-to-asset ratio (default 0.33).
        max_interest_share: Highest allowed interest_income / total_revenue
            (default 0.05).

    Returns:
        bool: True only when every rule is satisfied.
    """
    fields = ('debt_to_asset_ratio', 'interest_income', 'total_revenue',
              'alcohol_revenue', 'gambling_revenue', 'tobacco_revenue')
    try:
        values = {name: row[name] for name in fields}
        # NaN compares False against every threshold, so without this check a row with
        # missing data would silently pass all of the rules below.
        if any(pd.isna(value) for value in values.values()):
            return False
        # The interest ratio is meaningless for zero/negative revenue (0/0 is NaN).
        if values['total_revenue'] <= 0:
            return False
        # Check debt/asset ratio
        if values['debt_to_asset_ratio'] > max_debt_to_asset:
            return False
        # Check interest income as a share of total revenue
        if values['interest_income'] / values['total_revenue'] > max_interest_share:
            return False
        # Check prohibited business activities: any such revenue disqualifies the stock.
        if (values['alcohol_revenue'] > 0 or values['gambling_revenue'] > 0
                or values['tobacco_revenue'] > 0):
            return False
        return True
    except (KeyError, TypeError, ValueError) as e:
        # Missing keys or non-numeric values: log and treat as non-compliant.
        logging.error("Error in Shariah compliance check: %s", e)
        return False
# Index symbol used to derive the ticker list for each supported market code.
_MARKET_INDEX_SYMBOLS = {'US': '^GSPC', 'HK': '^HSI', 'IN': '^BSESN'}


def get_stock_data(market, tickers=None):
    """Fetch per-ticker screening data from Finnhub for a market.

    Args:
        market: Market code. 'US', 'HK' and 'IN' are supported when *tickers* is not
            given. The value is also stored in each row's ``market`` column.
        tickers: Optional iterable of symbols to fetch directly. When omitted, the
            symbol list is derived from the market's index via ``yf.Tickers(...).symbols``.

    Returns:
        pandas.DataFrame: One row per ticker fetched successfully, with columns
        stock_symbol, market, sector, debt_to_asset_ratio, interest_income,
        total_revenue, alcohol_revenue, gambling_revenue and tobacco_revenue. The frame
        is empty when the market is unsupported or nothing could be fetched.
    """
    if tickers is None:
        index_symbol = _MARKET_INDEX_SYMBOLS.get(market)
        if index_symbol is None:
            logging.warning(f"Market {market} is not supported")
            return pd.DataFrame()
        # NOTE(review): yf.Tickers(...).symbols appears to just echo the symbol string(s)
        # passed in, i.e. [index_symbol], not the index constituents. Pass real
        # constituent lists via `tickers` until a constituent data source is added.
        tickers = yf.Tickers(index_symbol).symbols

    rows = []
    for ticker in tickers:
        try:
            company_profile = finnhub_client.company_profile2(symbol=ticker)
            financials = finnhub_client.company_basic_financials(symbol=ticker, metric='all')
            metrics = financials['metric']
            rows.append({
                'stock_symbol': ticker,
                'market': market,
                'sector': company_profile.get('finnhubIndustry', 'Unknown'),
                # NOTE(review): confirm these metric keys exist in Finnhub's
                # basic-financials response; absent keys become NaN here.
                'debt_to_asset_ratio': metrics.get('debtToAssets', np.nan),
                'interest_income': metrics.get('interestIncome', np.nan),
                'total_revenue': metrics.get('totalRevenue', np.nan),
                'alcohol_revenue': 0,  # Placeholder, needs additional data source
                'gambling_revenue': 0,  # Placeholder, needs additional data source
                'tobacco_revenue': 0,  # Placeholder, needs additional data source
            })
        except Exception as e:
            # Best effort: one failing symbol (request error, missing 'metric' section, ...)
            # is logged and skipped rather than aborting the whole market fetch.
            logging.error(f"Failed to fetch data for {ticker}: {e}")
    return pd.DataFrame(rows)
def predict_stock_performance(market, sector, investment_horizon, model_type):
    """Recommend up to five Shariah-compliant stocks in *sector*, ranked by predicted price.

    Args:
        market: Market code passed to ``get_stock_data`` ('US', 'HK' or 'IN').
        sector: Sector (Finnhub industry) to recommend from.
        investment_horizon: Horizon in years from the UI slider.
            NOTE(review): not used by the model yet.
        model_type: 'linear_regression' or 'random_forest'.

    Returns:
        str: One "SYMBOL: predicted_price" line per recommended stock, or a
        human-readable message explaining why no recommendation could be made.
    """
    # Validate the cheap input first so an invalid choice does not trigger network fetches.
    if model_type not in ('linear_regression', 'random_forest'):
        return "Invalid model type selected."

    # Fetch stock data for the selected market
    df = get_stock_data(market)
    if df.empty:
        return "No data available for the selected market."

    # Apply Shariah compliance filtering
    df['is_shariah_compliant'] = df.apply(is_shariah_compliant, axis=1)
    shariah_compliant_df = df[df['is_shariah_compliant']]
    if shariah_compliant_df.empty:
        return "No Shariah-compliant stocks found."

    # Check for required columns before training the model.
    # NOTE(review): get_stock_data in this file does not produce these feature columns
    # or 'future_price', so this path currently returns an error; a data source for
    # them is still needed.
    required_columns = ['price_to_earnings', 'dividend_yield', 'revenue_growth']
    for col in required_columns:
        if col not in shariah_compliant_df.columns:
            logging.error(f"Missing column: {col}")
            return f"Error: Missing required financial data ({col})."
    if 'future_price' not in shariah_compliant_df.columns:
        return "Error: Future price data not available."

    # scikit-learn's LinearRegression rejects NaN inputs, so train only on complete rows.
    training_df = shariah_compliant_df.dropna(subset=required_columns + ['future_price'])
    if training_df.empty:
        return "Error: Not enough complete financial data to train the model."

    X = training_df[required_columns]
    y = training_df['future_price']
    try:
        if model_type == 'linear_regression':
            model = LinearRegression()
        else:
            model = RandomForestRegressor()
        model.fit(X, y)
        # assign() builds a new frame, so no column is written into a filtered slice of df.
        ranked = training_df.assign(predicted_price=model.predict(X))
    except Exception as e:
        # UI boundary: report any modelling failure as a message instead of crashing.
        logging.error(f"Model training failed: {e}")
        return "Error in model training or prediction."

    # Rank and recommend the top Shariah-compliant stocks in the chosen sector.
    top_stocks = (
        ranked[ranked['sector'] == sector]
        .sort_values('predicted_price', ascending=False)
        .head(5)
    )
    if top_stocks.empty:
        return "No top stocks found in the selected sector."

    # Prepare the output for the Gradio interface
    return '\n'.join(
        f"{row['stock_symbol']}: {row['predicted_price']}" for _, row in top_stocks.iterrows()
    )
def update_sectors(market):
    """Refresh the Sector dropdown's choices for the newly selected *market*.

    Returns a ``gr.update`` so Gradio replaces the dropdown's *choices*. A plain list
    returned to a Dropdown output is treated as its value, so the options never
    changed. The update also clears any selection left over from the previous market.
    """
    df = get_stock_data(market)
    sectors = [] if df.empty else sorted(df['sector'].dropna().unique().tolist())
    return gr.update(choices=sectors, value=None)
# Create the Gradio interface: choose market, sector, horizon and model, then predict.
with gr.Blocks() as app:
    with gr.Row():
        market = gr.Dropdown(['US', 'HK', 'IN'], label='Market', interactive=True)
        # Choices are filled in by update_sectors whenever the market changes.
        sector = gr.Dropdown([], label='Sector', interactive=True)
        investment_horizon = gr.Slider(1, 10, step=1, label='Investment Horizon (years)')
        # Default selection so Predict works without first picking a model
        # (an unselected radio would send None, which the predictor rejects).
        model_type = gr.Radio(
            ['linear_regression', 'random_forest'],
            value='linear_regression',
            label='Model Type',
        )
    output = gr.Textbox(label='Recommended Shariah-Compliant Stocks')
    market.change(fn=update_sectors, inputs=[market], outputs=[sector])
    gr.Button('Predict').click(
        fn=predict_stock_performance,
        inputs=[market, sector, investment_horizon, model_type],
        outputs=[output],
    )

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    app.launch()