|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import pandas as pd |
|
import numpy as np |
|
from sklearn.preprocessing import StandardScaler, MinMaxScaler |
|
from sklearn.model_selection import train_test_split |
|
from torch.utils.data import DataLoader, TensorDataset |
|
import plotly.express as px |
|
import streamlit as st |
|
from es_class_nn import SimplePlusNN2 |
|
# Tag of the trained checkpoint; it is interpolated into the state-dict
# file name when the model is loaded later in this script.
version_name = 'CON_31'

# Page header: two equal-width columns, the logo goes in the right-hand one.
header_left, header_right = st.columns([6, 6])
with header_right:
    st.image(
        'logo_vidad.png',
        caption='https://www.continental.edu.pe/',
        width=300,
    )
|
|
|
st.title("Predicción de Abandono o Permanencia")

# The uploader below only accepts Excel workbooks (.xlsx) and the upload is
# parsed with pd.read_excel, so the instructions must name that format.
st.write("Cargue el archivo XLSX para visualizar el análisis de su contenido.")

uploaded_file = st.file_uploader("Cargar archivo: ", type='xlsx')

# Reference workbooks read from the working directory on every rerun.
# df_muestra is the column template that the uploaded data is aligned to.
# NOTE(review): cat_sel and df_categ are not referenced anywhere else in this
# script; confirm whether they are still needed before removing them.
cat_sel = pd.read_excel('df_cat_prior.xlsx')
df_categ = pd.read_excel('lista_categorias.xlsx')
df_muestra = pd.read_excel('df_muestra_carga.xlsx')
|
|
|
# Score thresholds: a row is labelled 1.0 above the upper bound, 0.0 below the
# lower bound, and left unlabelled (NaN) in between.
PRED_UPPER_THRESHOLD = 0.60
PRED_LOWER_THRESHOLD = 0.40


def _align_to_template(df_input, df_template):
    """Return *df_input* restricted to the template's columns, in template order.

    Columns absent from *df_input* are added, and every NaN is replaced by 0,
    which is what the former concat-with-template + fillna(0) produced for the
    template columns. Columns the template lacks are dropped so the feature
    width always equals the template width.
    """
    return df_input.reindex(columns=df_template.columns).fillna(0)


def _label_scores(scores, upper=PRED_UPPER_THRESHOLD, lower=PRED_LOWER_THRESHOLD):
    """Map one score per row to 1.0 (> upper), 0.0 (< lower) or NaN (otherwise).

    The output has the same length as *scores*, so labels stay aligned with
    their rows instead of uncertain rows being dropped.
    """
    scores = np.asarray(scores, dtype=np.float64)
    return np.where(scores > upper, 1.0, np.where(scores < lower, 0.0, np.nan))


if uploaded_file is not None:
    df_test = pd.read_excel(uploaded_file)

    # Report schema differences by name. Comparing the two Index objects
    # element-wise raises ValueError as soon as the column counts differ.
    test_cols = set(df_test.columns)
    template_cols = set(df_muestra.columns)
    missing_cols = [c for c in df_muestra.columns if c not in test_cols]
    extra_cols = [c for c in df_test.columns if c not in template_cols]
    if missing_cols:
        st.warning(f"Columnas faltantes (se completarán con 0): {missing_cols}")
    if extra_cols:
        st.warning(f"Columnas no reconocidas (se ignorarán): {extra_cols}")

    df = _align_to_template(df_test, df_muestra)
    st.write(df)

    X_test_tensor = torch.tensor(df.values.astype(np.float32), dtype=torch.float32)

    input_size = X_test_tensor.shape[1]
    num_classes = 2
    model = SimplePlusNN2(input_size, num_classes)

    data_path = ''
    dict_name = f'edusights_20240702_state_dict_{version_name}.pth'
    # map_location='cpu' lets a checkpoint saved from a GPU load on a CPU host.
    model.load_state_dict(torch.load(data_path + dict_name, map_location='cpu'))
    model.eval()

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(X_test_tensor)

    raw_scores = outputs.cpu().numpy()
    n_rows = len(df)
    # NOTE(review): the thresholding assumes the model returns exactly one
    # probability-like score per row (e.g. a sigmoid output), which is what the
    # previous flatten-and-assign implied; confirm against SimplePlusNN2.
    if raw_scores.size != n_rows:
        st.error(
            f"La salida del modelo tiene {raw_scores.size} valores para "
            f"{n_rows} filas; se esperaba un valor por fila."
        )
        st.stop()

    # Uncertain rows get NaN, which to_csv writes as an empty field.
    df['Pred'] = _label_scores(raw_scores.reshape(n_rows))
    st.write(df['Pred'])

    # DataFrame.to_csv returns a str when no path is given, so its `encoding`
    # argument has no effect there. Encode explicitly to deliver the
    # ISO-8859-1 file that was requested; unencodable characters become '?'.
    csv_out = df.to_csv().encode('iso-8859-1', errors='replace')

    st.download_button(
        label="Descargar CSV",
        data=csv_out,
        file_name='predicciones_carga.csv',
        mime='text/csv',
    )
|
|
|
# Page footer: two equal-width columns, the vendor logo sits in the left one.
footer_left, _footer_right = st.columns([6, 6])
with footer_left:
    st.image(
        'gdmklogo.png',
        caption='Powered by GestioDinámica 2024',
        width=100,
    )
|
|