# alireza / app.py
# Last change: "Update app.py" by aromidvar (commit 458d562, verified)
import os
import logging
import numpy as np
import pandas as pd
import xgboost as xgb
import requests
from flask import Flask, request, jsonify
from flask_sqlalchemy import SQLAlchemy
import gradio as gr
# Initialize Flask application
app = Flask(__name__)
# Fall back to a local SQLite file when DATABASE_URL is unset. Without a
# fallback the URI would be None, which Flask-SQLAlchemy cannot use (newer
# versions raise at initialization).
app.config['SQLALCHEMY_DATABASE_URI'] = os.getenv('DATABASE_URL', 'sqlite:///economic_data.db')
db = SQLAlchemy(app)
# Set up logging: records at DEBUG and above go both to app.log and to the console.
_log_handlers = [
    logging.FileHandler("app.log"),  # persistent log file in the working directory
    logging.StreamHandler(),  # console output
]
logging.basicConfig(
    handlers=_log_handlers,
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)
# Database model for storing economic indicators
class EconomicData(db.Model):
    """One named economic indicator and its numeric value.

    Indicator names are not declared unique, so several rows may share a name.
    """

    id = db.Column(db.Integer, primary_key=True)
    indicator = db.Column(db.String(100))  # indicator name, e.g. 'GDP'; at most 100 chars
    value = db.Column(db.Float)  # numeric reading; no unit is stored alongside it

    def __repr__(self):
        # Readable form for logs and interactive debugging.
        return (
            f"{type(self).__name__}(id={self.id!r}, "
            f"indicator={self.indicator!r}, value={self.value!r})"
        )
def initialize_db():
    """Create any missing tables, then seed sample rows if the table is empty.

    Opens its own application context so it can be called at startup, outside
    of any request.
    """
    with app.app_context():
        db.create_all()
        if EconomicData.query.count() > 0:
            return  # already populated; leave existing rows alone
        create_dummy_data()
def create_dummy_data():
    """Insert a few sample indicator rows so a fresh database has data to show.

    Expects an active application context; the only caller in this file,
    ``initialize_db``, provides one. If the commit fails, the session is rolled
    back (so it stays usable) and the error is re-raised.
    """
    economic_records = [
        EconomicData(indicator='GDP', value=2.5),
        EconomicData(indicator='Inflation', value=1.2),
        EconomicData(indicator='Unemployment Rate', value=4.5),
    ]
    # add_all is the regular unit-of-work API; Session.bulk_save_objects is a
    # legacy feature as of SQLAlchemy 2.0.
    db.session.add_all(economic_records)
    try:
        db.session.commit()
    except Exception:
        db.session.rollback()
        raise
def _escape_like(term, escape_char='/'):
    """Return *term* with LIKE wildcards escaped so it matches literally.

    The escape character itself is doubled first; otherwise a literal escape
    character in the input could neutralise the escaping of what follows.
    """
    return (
        term.replace(escape_char, escape_char * 2)
        .replace('%', escape_char + '%')
        .replace('_', escape_char + '_')
    )


@app.route('/search', methods=['POST'])
def search():
    """Search economic indicators by case-insensitive substring.

    Expects a JSON body ``{"query": "<text>"}``; ``%`` and ``_`` in the query
    match literally. Responds with a JSON object mapping each matching
    indicator name to its value (rows sharing a name collapse to the last one).
    Returns 400 when the body is not JSON or ``query`` is missing or not a
    string, and 500 with the error text on unexpected failures.
    """
    logger.debug("Search endpoint hit.")
    # silent=True yields None instead of raising on a missing/invalid JSON body,
    # so client mistakes become a 400 rather than a 500.
    payload = request.get_json(silent=True)
    query = payload.get('query') if isinstance(payload, dict) else None
    if not isinstance(query, str):
        logger.warning("Rejected search request without a string 'query' field.")
        return jsonify({"error": "Request body must be JSON with a string 'query' field."}), 400
    try:
        pattern = f"%{_escape_like(query)}%"
        results = EconomicData.query.filter(EconomicData.indicator.ilike(pattern, escape='/')).all()
        logger.debug(f"Search results for query '{query}': {[(record.indicator, record.value) for record in results]}")
        return jsonify({record.indicator: record.value for record in results})
    except Exception as e:
        logger.exception(f"Error in search: {e}")
        return jsonify({"error": str(e)}), 500
# Upper bound on the forecast length, so one request cannot demand an
# arbitrarily large prediction array (the Gradio slider tops out at 30).
_MAX_PREDICTIONS = 365


def _parse_num_predictions(raw):
    """Return *raw* as an int in ``[0, _MAX_PREDICTIONS]``, or None if invalid.

    Integral floats such as ``5.0`` (which a UI slider may send) are accepted.
    Booleans are rejected even though ``bool`` subclasses ``int``.
    """
    if isinstance(raw, bool):
        return None
    if isinstance(raw, float) and raw.is_integer():
        raw = int(raw)
    if not isinstance(raw, int) or not 0 <= raw <= _MAX_PREDICTIONS:
        return None
    return raw


def _forecast(historical_values, num_predictions):
    """Fit XGBoost on (time index -> value) and predict the next steps.

    Returns a list of ``num_predictions`` plain Python floats, one per future
    time index ``len(history), len(history) + 1, ...``. Converting numpy
    floats to ``float`` is required for ``jsonify``.
    NOTE: with the default tree booster the model cannot extrapolate beyond the
    training indices, so every future step receives the same value.
    """
    history = np.asarray(historical_values, dtype=float)
    time_index = np.arange(len(history)).reshape(-1, 1)
    model = xgb.XGBRegressor()
    model.fit(time_index, history)
    future_index = np.arange(len(history), len(history) + num_predictions).reshape(-1, 1)
    return [float(value) for value in model.predict(future_index)]


@app.route('/predict', methods=['POST'])
def predict():
    """Forecast the next ``num_predictions`` values of a small built-in series.

    Expects a JSON body ``{"num_predictions": <int>}`` and responds with
    ``{"predictions": [float, ...]}``. Returns 400 when the count is missing or
    is not an integer in ``[0, _MAX_PREDICTIONS]``, and 500 with the error text
    on unexpected failures.
    """
    logger.debug("Predict endpoint hit.")
    data = request.get_json(silent=True)
    raw_count = data.get('num_predictions') if isinstance(data, dict) else None
    num_predictions = _parse_num_predictions(raw_count)
    if num_predictions is None:
        logger.warning(f"Rejected predict request with num_predictions={raw_count!r}.")
        return jsonify({"error": f"'num_predictions' must be an integer between 0 and {_MAX_PREDICTIONS}."}), 400
    try:
        # Hard-coded sample history; it is not read from the database.
        historical_data = pd.Series([100, 101, 102, 104, 107])
        # Skip model fitting entirely when nothing is requested.
        predictions = _forecast(historical_data.values, num_predictions) if num_predictions else []
        logger.debug(f"Predictions: {predictions}")
        return jsonify({"predictions": predictions})
    except Exception as e:
        logger.exception(f"Error in prediction: {e}")
        return jsonify({"error": str(e)}), 500
@app.route('/economic_data', methods=['GET'])
def get_economic_data():
    """Return every stored indicator as a JSON object mapping name to value.

    Rows that share an indicator name collapse to whichever is read last.
    Any failure is logged and reported as a 500 carrying the error text.
    """
    logger.debug("Retrieve economic data endpoint hit.")
    try:
        payload = {}
        for row in EconomicData.query.all():
            payload[row.indicator] = row.value
        return jsonify(payload)
    except Exception as exc:
        logger.error(f"Error in retrieving economic data: {exc}")
        return jsonify({"error": str(exc)}), 500
def gradio_ui(api_base_url='http://localhost:5000', request_timeout=10):
    """Build and launch the Gradio front end for the Flask API.

    Each tab forwards its input to the matching Flask endpoint under
    *api_base_url* and displays the decoded JSON response. Network errors,
    timeouts and non-JSON responses are shown as ``{"error": ...}`` rather than
    raising inside the Gradio callback. ``demo.launch`` blocks the calling
    thread while the UI server runs.

    Args:
        api_base_url: Base URL of the Flask API; a trailing slash is tolerated.
        request_timeout: Seconds to wait for each API call before giving up.
    """
    logger.debug("Launching Gradio UI.")
    base_url = api_base_url.rstrip('/')

    def _call_api(method, path, **kwargs):
        # Shared timeout and error handling for every tab. ValueError covers
        # JSON decode failures from response.json() across requests versions.
        try:
            response = requests.request(method, f"{base_url}{path}", timeout=request_timeout, **kwargs)
            return response.json()
        except (requests.RequestException, ValueError) as exc:
            logger.error(f"API call {method} {path} failed: {exc}")
            return {"error": str(exc)}

    def call_search(query):
        return _call_api('POST', '/search', json={'query': query})

    def call_predict(num_predictions):
        # The slider may deliver a float (e.g. 5.0); the API expects an integer count.
        count = int(num_predictions) if num_predictions is not None else None
        return _call_api('POST', '/predict', json={'num_predictions': count})

    def call_economic_data():
        return _call_api('GET', '/economic_data')

    with gr.Blocks() as demo:
        gr.Markdown("# Economic DataHub")
        gr.Markdown("Welcome to the interactive Economic DataHub. Explore economic indicators, make predictions, and access detailed economic data.")
        with gr.Tab("Search Economic Indicators"):
            search_input = gr.Textbox(label="Search for Economic Terms:", placeholder="Enter an economic term...")
            search_button = gr.Button("Search")
            search_output = gr.Textbox(label="Results", interactive=False)
            search_button.click(call_search, inputs=search_input, outputs=search_output)
        with gr.Tab("Time Series Prediction"):
            num_predictions = gr.Slider(minimum=1, maximum=30, step=1, label="Number of Days to Predict")
            predict_button = gr.Button("Predict")
            prediction_output = gr.Textbox(label="Predictions", interactive=False)
            predict_button.click(call_predict, inputs=num_predictions, outputs=prediction_output)
        with gr.Tab("Retrieve Economic Data"):
            retrieve_button = gr.Button("Get Economic Data")
            data_output = gr.Textbox(label="Economic Data", interactive=False)
            retrieve_button.click(call_economic_data, outputs=data_output)
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
if __name__ == '__main__':
    import threading

    initialize_db()  # Ensure the database is initialized before starting the app
    # app.run() blocks forever, so the API runs in a background thread;
    # otherwise the Gradio UI below would never start. Werkzeug's reloader only
    # works in the main thread, so it is disabled.
    # The Werkzeug debugger permits arbitrary code execution, so it is enabled
    # only on explicit request (FLASK_DEBUG=1), never by default on 0.0.0.0.
    flask_debug = os.getenv('FLASK_DEBUG') == '1'
    api_thread = threading.Thread(
        target=app.run,
        kwargs={'debug': flask_debug, 'host': '0.0.0.0', 'port': 5000, 'use_reloader': False},
        daemon=True,  # let the process exit once the Gradio UI shuts down
        name='flask-api',
    )
    api_thread.start()
    # Launch Gradio in the main thread; it blocks while the UI server runs.
    gradio_ui()