|
import streamlit as st |
|
import torch |
|
from normflows import nflow |
|
import numpy as np |
|
import seaborn as sns |
|
import pandas as pd |
|
|
|
# Streamlit input widgets read by ``compute``:
#   - ``uploaded_file`` is passed as ``dataset=`` to ``nflow``;
#   - ``bw`` is passed as ``bw=`` to ``api.compile``.
uploaded_file = st.file_uploader(label="Choose original dataset")
bw = st.number_input(label='Scale', value=3.05)
|
|
|
|
|
|
|
def compute(iters=10000, n_samples=1000):
    """Train an ``nflow`` model on the uploaded dataset and generate samples.

    Reads the module-level widget values ``uploaded_file`` and ``bw``.
    Training progress is reported in the Streamlit page. A KDE plot of the
    model output for ``api.scaled`` is drawn with the scaled real points
    overlaid. Finally, standard-normal noise is pushed through
    ``api.model.sample`` to produce new rows.

    Args:
        iters: Number of training iterations passed to ``api.train``.
            Defaults to 10000, the value previously hard-coded.
        n_samples: Number of standard-normal noise rows fed to
            ``api.model.sample``. Defaults to 1000.

    Returns:
        The generated samples after ``api.scaler.inverse_transform``.
        Presumably this is the scale of the original dataset, assuming
        ``api.scaler`` produced ``api.scaled`` -- TODO confirm in normflows.
    """
    api = nflow(dim=8, latent=16, dataset=uploaded_file)
    api.compile(optim=torch.optim.ASGD, bw=bw, lr=0.0001, wd=None)

    # Use one placeholder that is overwritten each step. Calling st.write per
    # iteration would append `iters` separate elements to the page.
    loss_box = st.empty()
    my_bar = st.progress(0)

    # Each record yielded by api.train is indexed as [0] -> step, [1] -> loss.
    for record in api.train(iters=iters):
        loss_box.write('Loss:', record[1])
        # st.progress rejects floats above 1.0, so clamp defensively.
        my_bar.progress(min(record[0] / iters, 1.0))
    # Ensure the bar shows completion even if the steps are 0-based.
    my_bar.progress(1.0)

    # Pass the scaled real data through the trained model for the diagnostic
    # plot (exact semantics of model.sample on data: TODO confirm in
    # normflows). Using .cpu() keeps the conversion working on GPU models.
    fitted = (
        api.model.sample(torch.tensor(api.scaled).float())
        .detach()
        .cpu()
        .numpy()
    )

    g = sns.jointplot(
        x=fitted[:, 0],
        y=fitted[:, 1],
        kind='kde',
        cmap=sns.color_palette("Blues", as_cmap=True),
        fill=True,
        label='Gaussian KDE',
        levels=50,
    )

    # Overlay the scaled real points on the joint axes of the same figure.
    w = sns.scatterplot(
        x=api.scaled[:, 0],
        y=api.scaled[:, 1],
        ax=g.ax_joint,
        c='orange',
        marker='+',
        s=100,
        label='Real',
    )
    st.pyplot(w.get_figure())

    # Generate new rows from standard-normal noise with one column per
    # feature of api.scaled. torch.randn already returns a tensor, so there is
    # no torch.tensor(tensor) re-wrap (which copies and emits a UserWarning).
    noise = torch.randn(n_samples, api.scaled.shape[-1], dtype=torch.float32)
    samples = api.model.sample(noise).detach().cpu().numpy()

    return api.scaler.inverse_transform(samples)
|
|
|
|
|
|
|
# Run training and offer the generated data only after a dataset is uploaded.
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, including clicking the download button, so compute() retrains
# from scratch each time. Consider caching the result (e.g. st.session_state).
if uploaded_file is not None:
    samples = compute()
    # The third positional parameter of st.download_button is `file_name`,
    # not `mime`. Pass both by keyword so the file is not named "text/csv".
    st.download_button(
        'Download generated CSV',
        pd.DataFrame(samples).to_csv(),
        file_name='generated.csv',
        mime='text/csv',
    )