|
import numpy as np |
|
np.random.seed(0) |
|
import pickle |
|
from sklearn.compose import ColumnTransformer |
|
from sklearn.datasets import fetch_openml |
|
from sklearn.pipeline import Pipeline |
|
from sklearn.impute import SimpleImputer |
|
from sklearn.preprocessing import StandardScaler, OneHotEncoder |
|
from sklearn.linear_model import LogisticRegression |
|
from sklearn.model_selection import train_test_split |
|
|
|
from sklearn import tree |
|
|
|
from fastapi import FastAPI, HTTPException |
|
from fastapi.responses import HTMLResponse |
|
from pydantic import BaseModel |
|
from typing import List |
|
|
|
class InputData(BaseModel):
    """Request body schema carrying the feature values to score."""

    # Flat list of numeric feature values; the endpoint reshapes it into a
    # single-row array before calling the model.
    data: List[float]
|
|
|
|
|
# Application object; route handlers in this module register on it through
# its decorators (e.g. ``@app.post``).
app = FastAPI()
|
|
|
def build_model():
    """Load and return the object pickled in ``miarbol.pkl``.

    The path is relative, so the file is resolved against the current
    working directory of the process.

    Returns:
        Whatever object was serialized into the pickle file.

    Raises:
        OSError: Propagated from ``open`` when the file is missing or
            cannot be read.
    """
    # NOTE(security): unpickling executes code embedded in the file, so
    # ``miarbol.pkl`` must come from a trusted source.
    with open('miarbol.pkl', 'rb') as fid:
        return pickle.load(fid)
|
|
|
# Load the model once at import time so every request reuses the same object
# instead of re-reading the pickle file.
miarbol = build_model()
|
|
|
|
|
@app.post("/predict/")
async def predict(data: InputData):
    """Score one feature vector with the loaded model.

    Args:
        data: Request body whose ``data`` field holds the flat list of
            feature values for a single sample.

    Returns:
        A dict ``{"prediction": [...]}`` containing the rounded model output
        converted to a plain list.

    Raises:
        HTTPException: 422 when the feature list is empty or its length does
            not match the feature count the model reports; 500 when the
            model call itself fails.
    """
    print(f"Data: {data}")

    features = data.data

    # Validate outside the ``try`` below so these client errors are not
    # re-wrapped as a 500 by the catch-all handler.
    if not features:
        raise HTTPException(
            status_code=422,
            detail="'data' must contain at least one value",
        )

    # NOTE(review): presumably the pickled object is a fitted scikit-learn
    # estimator exposing ``n_features_in_``; the length check is skipped when
    # the attribute is absent.
    expected = getattr(miarbol, "n_features_in_", None)
    if expected is not None and len(features) != expected:
        raise HTTPException(
            status_code=422,
            detail=f"expected {expected} features, got {len(features)}",
        )

    try:
        # reshape(1, -1) turns the flat list into a single-row 2-D array.
        input_data = np.array(features).reshape(1, -1)
        prediction = miarbol.predict(input_data).round()
        result = prediction.tolist()
    except Exception as e:
        # Anything failing past validation is treated as a server-side error;
        # chain the cause so the original traceback is preserved.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"prediction": result}