import streamlit as st
import pandas as pd
import sentencepiece

# 모델 준비하기
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import torch
import os
from tqdm import tqdm

# Optional theme settings, currently disabled. This looks like a Streamlit
# config.toml [theme] section and would normally live in .streamlit/config.toml.
# [theme]
# base="dark"
# primaryColor="purple"

# Page title ("Korean Standard Industrial Classification auto-coding service").
st.header('한국표준산업분류 자동코딩 서비스')

# Cache the loaded artifacts so Streamlit reruns do not reload them.
# NOTE(review): experimental_memo caches by value (copies the return value);
# for a large model object experimental_singleton / cache_resource is usually
# the intended cache. Confirm against the installed Streamlit version.
@st.experimental_memo(max_entries=20)
def md_loading():
    """Load the tokenizer, fine-tuned classifier and code lookup tables.

    Everything runs on CPU.

    Returns:
        tuple: (tokenizer, model, label_tbl, loc_tbl, device), where
            tokenizer: XLM-RoBERTa tokenizer for 'xlm-roberta-base'.
            model: XLM-RoBERTa sequence classifier with 493 labels, with the
                weights from ./en_ko_4mix_proto.bin loaded and moved to device.
            label_tbl: array loaded from ./label_table.npy (class index -> code).
            loc_tbl: DataFrame read from ./kisc_table.csv (code -> name table).
            device: the torch.device the model was moved to (CPU).
    """
    device = torch.device("cpu")

    base_name = 'xlm-roberta-base'
    tokenizer = XLMRobertaTokenizer.from_pretrained(base_name)
    model = XLMRobertaForSequenceClassification.from_pretrained(base_name, num_labels=493)

    # Overwrite the freshly initialised classifier with the fine-tuned weights,
    # mapping any saved tensors onto the CPU device.
    weights_path = os.path.join('./', 'en_ko_4mix_proto.bin')
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)

    label_tbl = np.load('./label_table.npy')
    loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')

    print('ready')
    return tokenizer, model, label_tbl, loc_tbl, device

# Load the model, tokenizer and lookup tables (memoised by md_loading's decorator).
tokenizer, model, label_tbl, loc_tbl, device = md_loading()


# Dataset preparation
max_len = 64    # token length TVT_Dataset pads/truncates every input to

class TVT_Dataset(Dataset):
    """Map-style dataset that tokenizes the five text columns of each row.

    Items are (input_ids, attention_mask) pairs taken from the first (only)
    row of the tokenizer output, padded/truncated to the module-level
    ``max_len``. Tokenization uses the module-level ``tokenizer``.
    Rows are looked up by index label (``.loc``), so the frame is expected to
    have a 0..n-1 index.
    """

    # Columns joined, in this order, into the model input text.
    _TEXT_COLUMNS = ['CMPNY_NM', 'MAJ_ACT', 'WORK_TYPE', 'POSITION', 'DEPT_NM']

    def __init__(self, df):
        self.df_data = df

    def __len__(self):
        return len(self.df_data)

    def __getitem__(self, index):
        fields = self.df_data.loc[index, self._TEXT_COLUMNS]
        # The fields are joined with ' <s> ' so the model sees one sequence.
        text = ' <s> '.join(fields.to_list())
        encoding = tokenizer(
            text,
            add_special_tokens=True,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return encoding['input_ids'][0], encoding['attention_mask'][0]



# Free-text input boxes. The comprehension creates the widgets in label order,
# which is also their on-screen order.
business, business_work, work_department, work_position, what_do_i = [
    st.text_input(label)
    for label in ('사업체명', '사업체 하는일', '근무부서', '직책', '내가 하는 일')
]

# Data preparation.
# Model input columns, in the order TVT_Dataset joins them. The confirm handler
# fills them from the widgets: business -> CMPNY_NM, business_work -> MAJ_ACT,
# what_do_i -> WORK_TYPE, work_position -> POSITION, work_department -> DEPT_NM.
input_col_type = ['CMPNY_NM', 'MAJ_ACT', 'WORK_TYPE', 'POSITION', 'DEPT_NM']


def preprocess_dataset(dataset):
    """Return a cleaned copy of `dataset` restricted to the model input columns.

    The index is reset to 0..n-1 because TVT_Dataset looks rows up by label
    with ``.loc``. Missing values (NaN/None) are replaced with '' so the
    columns can be string-joined. The caller's DataFrame is not modified.

    Args:
        dataset: DataFrame containing at least the columns in input_col_type.

    Returns:
        A new DataFrame with exactly the input_col_type columns, in that order.

    Raises:
        KeyError: if any column in input_col_type is missing from `dataset`.
    """
    cleaned = dataset.reset_index(drop=True)
    # fillna returns a new frame, so its result must be the value returned.
    return cleaned[input_col_type].fillna('')


def _build_input_frame():
    """Collect the five text inputs into a one-row DataFrame keyed by input_col_type.

    Widget -> column mapping: business -> CMPNY_NM, business_work -> MAJ_ACT,
    what_do_i -> WORK_TYPE, work_position -> POSITION, work_department -> DEPT_NM.
    """
    values = [str(business), str(business_work), str(what_do_i),
              str(work_position), str(work_department)]
    return pd.DataFrame({col: [val] for col, val in zip(input_col_type, values)})


def _predict_probs(frame, batch_size=48):
    """Run the classifier over every row of `frame` and return softmax probabilities.

    Returns a tensor with one row per input row and one column per model class.
    """
    loader = DataLoader(TVT_Dataset(frame), batch_size=batch_size, shuffle=False)
    model.eval()
    batch_logits = []
    # Inference only: skip autograd bookkeeping for the whole loop.
    with torch.no_grad():
        for input_ids, attention_mask in tqdm(loader):
            outputs = model(input_ids.to(device), token_type_ids=None,
                            attention_mask=attention_mask.to(device))
            # Keep every batch; keeping only the last batch's logits would drop rows.
            batch_logits.append(outputs.logits.cpu())
    return torch.softmax(torch.cat(batch_logits, dim=0), dim=1)


def _label_name(code):
    """Return loc_tbl's '항목명' for industry `code`, or '' when the code is not listed."""
    names = loc_tbl.loc[loc_tbl['코드'] == code, '항목명']
    return names.iloc[0] if not names.empty else ''


def _topk_table(row_probs, k=10):
    """Build the ranked result table for one row of class probabilities.

    Class indices are mapped to codes through label_tbl and to names through
    loc_tbl. Probabilities are shown as percentages. At most `k` rows are
    returned, numbered from 1 in the 'NO' index.
    """
    top = torch.topk(row_probs, min(k, row_probs.numel()))
    codes = label_tbl[top.indices.numpy()]
    table = pd.DataFrame({
        'NO': range(1, len(codes) + 1),
        '세분류 코드': codes,
        '세분류 명칭': [_label_name(code) for code in codes],
        '확률': top.values.numpy() * 100,
    })
    return table.set_index('NO')


# When the confirm ('확인') button is clicked, classify the entered text.
if st.button('확인'):
    test_dataset = preprocess_dataset(_build_input_frame())

    # Console diagnostics.
    print(len(test_dataset))
    print(test_dataset)
    print('base_data_loader 사용 시점점')

    probs = _predict_probs(test_dataset)
    # Sanity check, console only: a softmax row should sum to ~1.0.
    print(float(probs[0].sum()))

    # The form always produces exactly one row, so rank the classes of row 0.
    ans_topk_df = _topk_table(probs[0])
    st.write(ans_topk_df.style.bar(subset='확률', align='left', color='blue'))