ositamiles commited on
Commit
e13ae45
·
verified ·
1 Parent(s): 0e0fcbe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -0
app.py CHANGED
@@ -1,7 +1,85 @@
1
  from fastapi import FastAPI
 
 
 
 
 
 
2
 
 
 
 
 
 
 
 
 
3
  app = FastAPI()
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  @app.get("/")
6
  def greet_json():
7
  return {"Hello": "World!"}
 
1
  from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import numpy as np
4
+ import pandas as pd
5
+ import tensorflow as tf
6
+ from sklearn.preprocessing import OneHotEncoder, StandardScaler
7
+ import joblib
8
 
9
+ # Load the trained model
10
+ model = tf.keras.models.load_model('trained_game_price_model.h5')
11
+
12
+ # Load pre-trained OneHotEncoder and StandardScaler
13
+ ohe = joblib.load('ohe.pkl')
14
+ scaler = joblib.load('scaler.pkl')
15
+
16
+ # FastAPI app
17
  app = FastAPI()
18
 
19
+ # Pydantic model for input validation
20
+ class GameDetails(BaseModel):
21
+ genre: str
22
+ targetPlatform: str
23
+ gamePlays: int
24
+ competitorPricing: float
25
+ currencyFluctuations: float
26
+
27
+ # Function to preprocess the input data
28
+ def preprocess_input(data, ohe, scaler):
29
+ # Convert input into DataFrame for processing
30
+ input_data = pd.DataFrame([data], columns=['genre', 'targetPlatform', 'gamePlays', 'competitorPricing', 'currencyFluctuations'])
31
+
32
+ # Apply OneHotEncoder for categorical features
33
+ input_data_transformed = ohe.transform(input_data[['genre', 'targetPlatform']])
34
+
35
+ # Ensure numerical features are 2D
36
+ numerical_features = input_data[['gamePlays', 'competitorPricing', 'currencyFluctuations']].values.reshape(1, -1)
37
+
38
+ # Merge with numerical features
39
+ input_data = np.hstack((input_data_transformed.toarray(), numerical_features))
40
+
41
+ # Scale the features
42
+ input_data_scaled = scaler.transform(input_data)
43
+
44
+ return input_data_scaled
45
+
46
+ # Function to make a prediction
47
+ def make_prediction(input_data):
48
+ # Preprocess the data for the model
49
+ input_data_scaled = preprocess_input(input_data, ohe, scaler)
50
+
51
+ # Make prediction
52
+ prediction = model.predict(input_data_scaled)
53
+
54
+ return prediction[0][0]
55
+
56
+ # API endpoint for price prediction
57
+ @app.post("/predict_price/")
58
+ def predict_price(game_details: GameDetails):
59
+ # Prepare input data for prediction
60
+ input_data = {
61
+ 'genre': game_details.genre,
62
+ 'targetPlatform': game_details.targetPlatform,
63
+ 'gamePlays': game_details.gamePlays,
64
+ 'competitorPricing': game_details.competitorPricing,
65
+ 'currencyFluctuations': game_details.currencyFluctuations
66
+ }
67
+
68
+ # Make prediction
69
+ predicted_price = make_prediction(input_data)
70
+
71
+ # Return the predicted price
72
+ return {
73
+ "predicted_price": f"${predicted_price:.2f}",
74
+ "input_details": {
75
+ "genre": game_details.genre,
76
+ "platform": game_details.targetPlatform,
77
+ "game_plays": game_details.gamePlays,
78
+ "competitor_pricing": game_details.competitorPricing,
79
+ "currency_fluctuations": game_details.currencyFluctuations
80
+ }
81
+ }
82
+
83
  @app.get("/")
84
  def greet_json():
85
  return {"Hello": "World!"}