# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/dr.ipynb.

# %% auto 0
__all__ = ['get_gpu_memory', 'color_for_percentage', 'create_bar', 'gpu_memory_status', 'check_compatibility', 'get_UMAP_prjs',
           'get_PCA_prjs', 'get_TSNE_prjs', 'cluster_score']

# %% ../nbs/dr.ipynb 2
import subprocess
def get_gpu_memory(device = 0):
    total_memory = subprocess.check_output(["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits", "--id=" + str(device)])
    total_memory = int(total_memory.decode().split('\n')[0])
    used_memory = subprocess.check_output(["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits",  "--id=" + str(device)])
    used_memory = int(used_memory.decode().split('\n')[0])

    percentage = round((used_memory / total_memory) * 100)
    return used_memory, total_memory, percentage

def color_for_percentage(percentage):
    if percentage < 20:
        return "\033[90m"  # Gray
    elif percentage < 40:
        return "\033[94m"  # Blue
    elif percentage < 60:
        return "\033[92m"  # Green
    elif percentage < 80:
        return "\033[93m"  # Orange
    else:
        return "\033[91m"  # Red
        
def create_bar(percentage, color_code, length=20):
    filled_length = int(length * percentage // 100)
    bar = "█" * filled_length + "-" * (length - filled_length)
    return color_code + bar + "\033[0m"  # Apply color and reset after bar

def gpu_memory_status(device=0):
    used, total, percentage = get_gpu_memory(device)
    color_code = color_for_percentage(percentage)
    bar = create_bar(percentage, color_code)
    print(f"GPU | Used mem: {used}")
    print(f"GPU | Used mem: {total}")
    print(f"GPU | Memory Usage: [{bar}] {color_code}{percentage}%\033[0m")

# %% ../nbs/dr.ipynb 4
import umap
import cudf
import cuml
import pandas as pd
import numpy as np
from fastcore.all import *
from .imports import *
from .load import TSArtifact

# %% ../nbs/dr.ipynb 5
def check_compatibility(dr_ar:TSArtifact, enc_ar:TSArtifact):
    "Function to check that the artifact used by the encoder model and the artifact that is \
    going to be passed through the DR are compatible"
    try:
        # Check that both artifacts have the same variables
        chk_vars = dr_ar.metadata['TS']['vars'] == enc_ar.metadata['TS']['vars']
        # Check that both artifacts have the same freq
        chk_freq = dr_ar.metadata['TS']['freq'] == enc_ar.metadata['TS']['freq']
        # Check that the DR artifact is not normalized (non-normalized data does not have the 'normalization' key)
        chk_norm = dr_ar.metadata['TS'].get('normalization') is None
        # Check that the DR artifact has no missing values
        chk_miss = dr_ar.metadata['TS']['has_missing_values'] == "False"
        # Check all logical vars.
        if chk_vars and chk_freq and chk_norm and chk_miss:
            print("Artifacts are compatible.")
        else:
            raise Exception
    except Exception as e:
        print("Artifacts are not compatible.")
        raise e
    return None
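
# Illustrative usage sketch (the artifact names are hypothetical, not from the original
# code): both arguments are TSArtifact objects loaded elsewhere; the call prints whether
# they are compatible and re-raises otherwise.
#
#   check_compatibility(dr_ar=dr_artifact, enc_ar=encoder_artifact)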

# %% ../nbs/dr.ipynb 7
# TODO: comment out this import once the 4_seconds case has been debugged
import hashlib

# %% ../nbs/dr.ipynb 8
import warnings
import sys
from numba.core.errors import NumbaPerformanceWarning
@delegates(cuml.UMAP)
def get_UMAP_prjs(
    input_data, 
    cpu=True, 
    print_flag = False, 
    check_memory_usage = True,
    **kwargs
):
    "Compute the projections of `input_data` using UMAP, with a configuration contained in `**kwargs`."
    if print_flag: 
        print("--> get_UMAP_prjs")
        print("kwargs: ", kwargs)
        sys.stdout.flush()
        # Checksum of the input data (debugging aid to verify identical inputs across runs)
        checksum = hashlib.md5(input_data.tobytes()).hexdigest()
        print("input data checksum ", checksum)
        
    if check_memory_usage: gpu_memory_status()
    
    warnings.filterwarnings("ignore", category=NumbaPerformanceWarning) # silence NumbaPerformanceWarning
    
    if cpu:
        if print_flag:
            print("-- umap.UMAP --", cpu)
            sys.stdout.flush()
        reducer = umap.UMAP(**kwargs)
    else:
        if print_flag:
            print("-- cuml.UMAP --", cpu)
            sys.stdout.flush()
        # Cast `random_state` to an unsigned integer for cuml.UMAP
        if 'random_state' in kwargs:
            kwargs['random_state'] = np.uint64(kwargs['random_state'])
        reducer = cuml.UMAP(**kwargs)
    
    if print_flag: 
        print("------- reducer --------")
        print(reducer)
        print(reducer.get_params())
        print("------- reducer --------")
        sys.stdout.flush()
    
    projections = reducer.fit_transform(input_data)
    
    if check_memory_usage: gpu_memory_status()
    if print_flag: 
        checksum = hashlib.md5(projections.tobytes()).hexdigest()
        print("prjs checksum ", checksum)
        print("get_UMAP_prjs -->")
        sys.stdout.flush()
    return projections
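
# Illustrative usage sketch (the array and the UMAP parameters are assumptions, not from
# the original code): project a (n_samples, n_features) matrix down to 2D on the CPU,
# skipping the GPU memory report.
#
#   X = np.random.rand(1000, 64).astype(np.float32)
#   prjs = get_UMAP_prjs(X, cpu=True, check_memory_usage=False,
#                        n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
#   prjs.shape  # (1000, 2)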

# %% ../nbs/dr.ipynb 13
@delegates(cuml.PCA)
def get_PCA_prjs(X, cpu=False, **kwargs):
    r"""
    Computes PCA projections of X
    """
    if cpu:
        raise NotImplementedError
    else:
        reducer = cuml.PCA(**kwargs)
    projections = reducer.fit_transform(X)
    return projections
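
# Illustrative usage sketch (parameters are assumptions, not from the original code):
# reduce a feature matrix `X` to its first two principal components on the GPU
# (the CPU path raises NotImplementedError).
#
#   prjs_pca = get_PCA_prjs(X, cpu=False, n_components=2)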

# %% ../nbs/dr.ipynb 15
@delegates(cuml.TSNE)
def get_TSNE_prjs(X, cpu=False, **kwargs):
    r"""
    Computes TSNE projections of X
    """
    if cpu:
        raise NotImplementedError
    else:
        reducer = cuml.TSNE(**kwargs)
    projections = reducer.fit_transform(X)
    return projections
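
# Illustrative usage sketch (parameters are assumptions, not from the original code):
# compute a 2D t-SNE embedding of `X` on the GPU; `perplexity` is a standard t-SNE
# hyperparameter also exposed by cuml.TSNE.
#
#   prjs_tsne = get_TSNE_prjs(X, cpu=False, n_components=2, perplexity=30.0)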

# %% ../nbs/dr.ipynb 18
from sklearn.metrics import silhouette_score
def cluster_score(prjs, clusters_labels, print_flag=False):
    "Compute the silhouette score of the projections `prjs` for the given cluster labels."
    score = silhouette_score(prjs, clusters_labels)
    if print_flag: print("Silhouette_score:", score)
    return score
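
# Illustrative usage sketch (the clusterer and labels are hypothetical, not from the
# original code): score how well a set of cluster labels separates the projections;
# silhouette_score ranges from -1 (poor) to 1 (well-separated clusters).
#
#   labels = some_clusterer.fit_predict(prjs)  # any sklearn-style clusterer
#   cluster_score(prjs, labels, print_flag=True)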