import streamlit as st
from transformers import AutoTokenizer, EsmModel
import torch
import json

def embed(aa_seq: str, tokenizer, model) -> list:
    """Return the model's last hidden state for one amino-acid sequence.

    Args:
        aa_seq: The sequence to embed.
        tokenizer: Callable invoked as ``tokenizer(aa_seq, return_tensors="pt")``
            whose result is unpacked as keyword arguments into ``model``.
        model: Callable whose result exposes a ``last_hidden_state`` tensor.

    Returns:
        ``last_hidden_state`` converted to nested Python lists, which makes
        it JSON-serialisable. (Presumably shaped batch x tokens x hidden for
        the ESM checkpoints used here -- confirm against the model docs.)
    """
    inputs = tokenizer(aa_seq, return_tensors="pt")
    # Inference only: skip autograd bookkeeping so no computation graph is
    # built and kept alive for the duration of the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)

    return outputs.last_hidden_state.detach().numpy().tolist()

# Selecting the model checkpoint to load
model_name = st.selectbox(
    'Choose a model',
    ["facebook/esm2_t6_8M_UR50D", "facebook/esm2_t48_15B_UR50D"])

# Uploading the AA sequences file
uploaded_file = st.file_uploader("Upload JSON with AA sequences", type='json')
if uploaded_file is not None:
    try:
        data = json.load(uploaded_file)
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        # A malformed upload would otherwise surface as a raw traceback.
        # Report it and end this script run; `data` stays bound only when a
        # file was uploaded and parsed successfully.
        st.error(f'Could not parse the uploaded file as JSON: {err}')
        st.stop()

def embed_upload_file(upload_dict_dania, tokenizer, model):
    """Embed every uploaded sequence and offer the result as a JSON download.

    Args:
        upload_dict_dania: Mapping of an identifier to a list of sequences,
            e.g. ``{'uid1': ['aa', 'aan']}``.
        tokenizer: Forwarded unchanged to ``embed``.
        model: Forwarded unchanged to ``embed``.

    The downloadable JSON maps each identifier to a ``{sequence: embedding}``
    dict, e.g. ``{'uid1': {'aa': [[[0.1298, ...]]], 'aan': [[[0.1298, ...]]]}}``.
    Progress is shown with a text placeholder and a progress bar.
    """
    if not isinstance(upload_dict_dania, dict):
        # Valid JSON is not necessarily an object (it may be a list, a
        # number, ...); anything else would fail on `.items()` below.
        st.error('Expected a JSON object mapping identifiers to lists of sequences.')
        return

    output = {}
    total = len(upload_dict_dania)

    # Placeholders that are updated in place while the loop runs.
    latest_iteration = st.empty()
    bar = st.progress(0)

    for done, (uid, seqs) in enumerate(upload_dict_dania.items(), start=1):
        latest_iteration.text(f'Iteration {uid}')
        output[uid] = {seq: embed(seq, tokenizer, model) for seq in seqs}
        # st.progress takes an integer percentage in [0, 100]. Scale by the
        # number of entries so the bar reaches 100 exactly on the last one,
        # however many entries there are.
        bar.progress(100 * done // total)

    json_data = json.dumps(output)

    st.download_button(
        label="Download JSON file",
        data=json_data,
        file_name="esm-2 last hidden states.json",
        mime='application/json'
    )

    
# Streamlit re-executes the whole script on every interaction; the embedding
# run happens only on the run triggered by this button.
if st.button('Get embedding'):
    if uploaded_file is None:
        # Without an upload `data` was never assigned; tell the user instead
        # of failing with a NameError.
        st.warning('Please upload a JSON file with AA sequences first.')
    else:
        st.write('You selected model:', model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = EsmModel.from_pretrained(model_name)
        embed_upload_file(data, tokenizer, model)