# Blackhards / app.py
# Last commit: 2ed0106 by ositamiles ("Update app.py", verified)
from fastapi import FastAPI
from pydantic import BaseModel
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.preprocessing import OneHotEncoder, StandardScaler
import joblib
# --- Serialized artifacts, loaded once when the module is imported ----------
# make_prediction feeds the output of preprocess_input (built with `ohe` and
# `scaler`) straight into `model`, so these three files need to be mutually
# compatible (presumably produced by the same training run).
# NOTE(review): joblib.load unpickles its input; only load .pkl artifacts
# that come from a trusted source.
model = tf.keras.models.load_model('trained_game_price_model.h5')
ohe, scaler = joblib.load('ohe.pkl'), joblib.load('scaler.pkl')

# FastAPI application instance; route handlers attach to it via decorators.
app = FastAPI()
# Pydantic model for input validation
class GameDetails(BaseModel):
    # Request body accepted by POST /predict_price/.
    # Kept docstring-free on purpose: Pydantic copies a class docstring into
    # the JSON schema "description", which would alter the OpenAPI document.
    # Field names are the JSON keys clients send, so renaming them (e.g. to
    # snake_case) would break existing callers.
    genre: str  # categorical; one-hot encoded in preprocess_input
    targetPlatform: str  # categorical; one-hot encoded in preprocess_input
    gamePlays: int  # numeric feature, appended after the one-hot columns
    competitorPricing: float  # numeric feature
    currencyFluctuations: float  # numeric feature
# Function to preprocess the input data
def preprocess_input(data, ohe, scaler):
    """Turn one request record into the scaled feature row the model consumes.

    The two categorical columns are encoded with ``ohe``; the three numeric
    columns are appended after the encoded columns; the combined row is then
    passed through ``scaler``.

    Args:
        data: Mapping with the keys ``genre``, ``targetPlatform``,
            ``gamePlays``, ``competitorPricing`` and ``currencyFluctuations``.
            A missing key becomes a NaN column (pandas behavior).
        ohe: Fitted encoder exposing ``transform`` (e.g. sklearn
            ``OneHotEncoder``). Its result may be a sparse matrix or a dense
            array; both are accepted.
        scaler: Fitted transformer exposing ``transform`` (e.g. sklearn
            ``StandardScaler``), presumably fit on the same
            [encoded categoricals | numerics] layout built here.

    Returns:
        Whatever ``scaler.transform`` returns for the single-row float matrix
        (a 2-D array with one row for sklearn scalers).
    """
    categorical_cols = ['genre', 'targetPlatform']
    numerical_cols = ['gamePlays', 'competitorPricing', 'currencyFluctuations']

    # Wrapping the record in a list gives a DataFrame with exactly one row.
    frame = pd.DataFrame([data], columns=categorical_cols + numerical_cols)

    encoded = ohe.transform(frame[categorical_cols])
    # OneHotEncoder returns a scipy sparse matrix by default but a dense
    # ndarray when fitted with sparse_output=False (sparse=False before
    # sklearn 1.2). Only the sparse form has .toarray(), so calling it
    # unconditionally crashes on a dense encoder.
    if hasattr(encoded, 'toarray'):
        encoded = encoded.toarray()

    # Selecting a list of columns already yields a 2-D (rows, 3) array, so no
    # reshape is needed.
    numerical = frame[numerical_cols].to_numpy(dtype=float)

    features = np.hstack((np.asarray(encoded, dtype=float), numerical))
    return scaler.transform(features)
# Function to make a prediction
def make_prediction(input_data):
    """Preprocess one request record and return the model's scalar output.

    Relies on the module-level ``model``, ``ohe`` and ``scaler`` objects that
    are loaded when this module is imported.

    Args:
        input_data: Mapping accepted by ``preprocess_input``.

    Returns:
        The first output value of the first (only) row, as a Python ``float``.
    """
    features = preprocess_input(input_data, ohe, scaler)
    # Newer Keras versions print a progress bar on every predict() call by
    # default (verbose="auto"); silence it so each request doesn't spam logs.
    prediction = model.predict(features, verbose=0)
    # Convert the NumPy scalar to a native float so callers and JSON encoders
    # receive a plain Python type.
    return float(prediction[0][0])
# API endpoint for price prediction
@app.post("/predict_price/")
def predict_price(game_details: GameDetails):
    # Handler for POST /predict_price/: FastAPI parses and validates the JSON
    # body into GameDetails, the model produces a price, and the response
    # echoes the inputs back under snake_case keys.
    # Comments rather than a docstring: FastAPI would publish a docstring as
    # this route's OpenAPI description.
    # NOTE(review): if `ohe` was fitted with handle_unknown='error' (sklearn's
    # default), an unseen genre/platform raises inside make_prediction and the
    # client gets a 500; consider mapping that to a 4xx response.
    feature_names = (
        'genre',
        'targetPlatform',
        'gamePlays',
        'competitorPricing',
        'currencyFluctuations',
    )
    features = {name: getattr(game_details, name) for name in feature_names}

    price = make_prediction(features)

    echoed_inputs = {
        "genre": game_details.genre,
        "platform": game_details.targetPlatform,
        "game_plays": game_details.gamePlays,
        "competitor_pricing": game_details.competitorPricing,
        "currency_fluctuations": game_details.currencyFluctuations,
    }
    return {
        "predicted_price": f"${price:.2f}",
        "input_details": echoed_inputs,
    }
@app.get("/")
def greet_json():
    # GET / returns a fixed JSON greeting. Comments rather than a docstring,
    # so the route's OpenAPI description stays unchanged.
    greeting = {"Hello": "Blackhards♠️♣️!"}
    return greeting