import functools
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.Chem import Descriptors
from rdkit.Chem import Draw
import selfies as sf
from rdkit.Chem import RDConfig
import os
import sys
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer


def get_largest_ring_size(mol):
    """Return the number of atoms in the biggest ring of ``mol``.

    Ring membership is read from ``mol.GetRingInfo().AtomRings()``; a molecule
    with no rings yields 0.
    """
    ring_sizes = [len(ring) for ring in mol.GetRingInfo().AtomRings()]
    return max(ring_sizes, default=0)

def plogp(smile):
    """Return the penalized logP of a SMILES string.

    penalized logP = ``Descriptors.MolLogP`` - SA score (``sascorer``)
    - ring penalty, where the ring penalty is ``max(largest_ring_size - 6, 0)``.

    Args:
        smile: SMILES string to score; may be empty or ``None``.

    Returns:
        The penalized logP, or ``-100`` when ``smile`` is empty or RDKit
        cannot parse it.
    """
    if not smile:
        return -100
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return -100

    log_p = Descriptors.MolLogP(mol)
    sas_score = sascorer.calculateScore(mol)
    # Only rings with more than six atoms are penalized.
    cycle_score = max(get_largest_ring_size(mol) - 6, 0)
    # Do not gate on the truthiness of these values: an acyclic molecule has a
    # largest ring size of 0 and logP can legitimately be 0.0; both are valid
    # molecules and must not be reported as the -100 failure sentinel.
    return log_p - sas_score - cycle_score
    
def sf_decode(selfies):
    """Decode a SELFIES string into SMILES.

    Returns the decoder's SMILES output, or an empty string when the
    ``selfies`` decoder raises ``DecoderError`` for the given input.
    """
    try:
        smiles = sf.decoder(selfies)
    except sf.DecoderError:
        # Undecodable input: hand back '' so callers can treat it as "no molecule".
        smiles = ''
    return smiles
    
def sim(input_smile, output_smile):
    """Return the Tanimoto similarity between two molecules given as SMILES.

    Each molecule is fingerprinted with ``AllChem.GetMorganFingerprint`` at
    radius 2, and the two fingerprints are compared with
    ``DataStructs.TanimotoSimilarity``.

    Returns None if either SMILES is empty/None or fails to parse.
    """
    if not (input_smile and output_smile):
        return None

    pair = [Chem.MolFromSmiles(text) for text in (input_smile, output_smile)]
    if not all(pair):
        return None

    fp_in, fp_out = (AllChem.GetMorganFingerprint(m, 2) for m in pair)
    return DataStructs.TanimotoSimilarity(fp_in, fp_out)


_GEN_MODEL_NAME = "zjunlp/MolGen-large"


@functools.lru_cache(maxsize=1)
def _load_gen_model():
    """Load the generation tokenizer/model once and reuse them across requests."""
    tokenizer = AutoTokenizer.from_pretrained(_GEN_MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(_GEN_MODEL_NAME)
    return tokenizer, model


def gen_process(gen_input):
    """Generate molecules from a SELFIES prompt with MolGen-large.

    Args:
        gen_input: SELFIES string used as the encoder input.

    Returns:
        A tuple ``(selfies_text, image)``: the four beam-search outputs as
        newline-separated SELFIES, and an RDKit grid image of the decoded
        molecules (4 per row, 200x200 px each).
    """
    # Loading the checkpoint is expensive, so do it once per process instead of
    # on every request.
    tokenizer, model = _load_gen_model()

    sf_input = tokenizer(gen_input, return_tensors="pt")

    # Beam search with 5 beams, returning the 4 best sequences.
    molecules = model.generate(
        input_ids=sf_input["input_ids"],
        attention_mask=sf_input["attention_mask"],
        max_length=15,
        min_length=5,
        num_return_sequences=4,
        num_beams=5,
    )

    # Remove spaces from the decoded text (presumably inserted between tokens)
    # so each output is a contiguous SELFIES string.
    gen_output = [
        tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True).replace(" ", "")
        for g in molecules
    ]

    # sf_decode returns '' instead of raising on undecodable SELFIES, so a single
    # bad beam does not abort the whole request.
    mols = [Chem.MolFromSmiles(sf_decode(s)) for s in gen_output]

    gen_output_image = Draw.MolsToGridImage(
        mols,
        molsPerRow=4,
        subImgSize=(200, 200),
        legends=['' for _ in mols],
    )

    return "\n".join(gen_output), gen_output_image
 
_OPT_MODEL_NAME = "zjunlp/MolGen-large-opt"


@functools.lru_cache(maxsize=1)
def _load_opt_model():
    """Load the optimization tokenizer/model once and reuse them across requests."""
    tokenizer = AutoTokenizer.from_pretrained(_OPT_MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(_OPT_MODEL_NAME)
    return tokenizer, model


def _draw_opt_grid(mols):
    """Draw ``mols`` 4 per row at 200x200 px, or return None when there are none."""
    if not mols:
        return None
    return Draw.MolsToGridImage(
        mols,
        molsPerRow=4,
        subImgSize=(200, 200),
        legends=['' for _ in mols],
    )


def opt_process(opt_input):
    """Propose property-improving variants of a SELFIES molecule with MolGen-large-opt.

    Samples 10 sequences (top-k 30) and removes duplicates. It keeps only
    candidates whose penalized logP (``plogp``) exceeds the input's and whose
    similarity to the input (``sim``) is above 0.4.

    Args:
        opt_input: SELFIES string of the molecule to optimize.

    Returns:
        A 5-tuple ``(input_image, candidates, improvements, similarities,
        output_image)``:
        - a grid image of the input molecule;
        - newline-separated SELFIES of the kept candidates;
        - their plogP improvements as newline-separated strings;
        - their similarities as newline-separated strings;
        - a grid image of the kept candidates, or None when no candidate
          passes the filters.
    """
    # Loading the checkpoint is expensive, so do it once per process instead of
    # on every request.
    tokenizer, model = _load_opt_model()

    # Decode the input once and reuse it; sf_decode yields '' instead of raising
    # on invalid SELFIES.
    input_sm = sf_decode(opt_input)
    opt_input_img = _draw_opt_grid([Chem.MolFromSmiles(input_sm)])

    sf_input = tokenizer(opt_input, return_tensors="pt")
    molecules = model.generate(
        input_ids=sf_input["input_ids"],
        attention_mask=sf_input["attention_mask"],
        do_sample=True,
        max_length=100,
        min_length=5,
        top_k=30,
        top_p=1,
        num_return_sequences=10,
    )
    sampled = [
        tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True).replace(" ", "")
        for g in molecules
    ]
    # Drop duplicate samples. dict.fromkeys keeps first-seen order, whereas
    # set() iteration order of strings varies between runs (hash randomization).
    sf_output = list(dict.fromkeys(sampled))
    sm_output = [sf_decode(s) for s in sf_output]

    input_plogp = plogp(input_sm)
    data = pd.DataFrame({
        "candidates": sf_output,
        "smiles": sm_output,
        "improvement": [plogp(s) - input_plogp for s in sm_output],
        "sim": [sim(s, input_sm) for s in sm_output],
    })

    # Keep candidates that improve penalized logP while staying similar to the input.
    results = data[(data["improvement"] > 0) & (data["sim"] > 0.4)]
    opt_output = results["candidates"].tolist()
    opt_output_imp = [str(v) for v in results["improvement"].tolist()]
    opt_output_sim = [str(v) for v in results["sim"].tolist()]

    # Reuse the SMILES decoded above rather than running the SELFIES decoder again.
    mols = [Chem.MolFromSmiles(s) for s in results["smiles"].tolist()]
    opt_output_img = _draw_opt_grid(mols)

    return (
        opt_input_img,
        "\n".join(opt_output),
        "\n".join(opt_output_imp),
        "\n".join(opt_output_sim),
        opt_output_img,
    )

# Gradio UI: two tabs, each wiring a SELFIES text box to one of the pipelines
# defined above (gen_process for generation, opt_process for optimization).
# Components are laid out in the order they are created inside each `with`
# context, so statement order here determines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("# MolGen: Domain-Agnostic Molecular Generation with Self-feedback")

    with gr.Tabs():
        # Tab 1: generate molecules from a SELFIES prompt via gen_process.
        with gr.TabItem("Molecular Generation"):
            with gr.Row():
                # First column: SELFIES input and trigger button.
                with gr.Column():
                    gen_input = gr.Textbox(label="Input", lines=1, placeholder="SELFIES Input")
                    gen_button = gr.Button("Generate")

                # Second column: generated SELFIES (text) and their grid image.
                with gr.Column():
                    gen_output = gr.Textbox(label="Generation Results", lines=5, placeholder="")
                    gen_output_image = gr.Image(label="Visualization")

            # Clickable example inputs; with cache_examples=True Gradio runs
            # gen_process on them and caches the outputs.
            gr.Examples(
                examples=[["[C][=C][C][=C][C][=C][Ring1][=Branch1]"], 
                          ["[C]"]
                          ],
                inputs=[gen_input],
                outputs=[gen_output, gen_output_image],
                fn=gen_process,
                cache_examples=True,
            )

        # Tab 2: constrained penalized-logP optimization via opt_process.
        with gr.TabItem("Constrained Molecular Property Optimization"):
            with gr.Row():
                # First column: SELFIES input and trigger button.
                with gr.Column():
                    opt_input = gr.Textbox(label="Input", lines=1, placeholder="SELFIES Input")
                    opt_button = gr.Button("Optimize")

                # Second column: one component per element of opt_process's
                # 5-tuple return value, in the same order.
                with gr.Column():
                    opt_input_img = gr.Image(label="Input Visualization")
                    opt_output = gr.Textbox(label="Optimization Results", lines=3, placeholder="")
                    opt_output_imp = gr.Textbox(label="Optimization Property Improvements", lines=3, placeholder="")
                    opt_output_sim = gr.Textbox(label="Similarity", lines=3, placeholder="")
                    opt_output_img = gr.Image(label="Output Visualization")


            # Example inputs for the optimization tab; outputs cached as above.
            gr.Examples(
                examples=[["[C][C][=Branch1][C][=O][N][C][C][O][C][C][O][C][C][O][C][C][Ring1][N]"], 
                          ["[C][C][S][C][C][S][C][C][C][S][C][C][S][C][Ring1][=C]"],
                          ["[N][#C][C][C][C@@H1][C][C][C][C][C][C][C][C][C][C][C][Ring1][N][=O]"]
                          ],
                inputs=[opt_input],
                outputs=[opt_input_img, opt_output, opt_output_imp, opt_output_sim, opt_output_img],
                fn=opt_process,
                cache_examples=True,
            )

    # Button handlers. Gradio maps `outputs` positionally onto the function's
    # return values, so these lists must match each function's return order.
    gen_button.click(fn=gen_process, inputs=[gen_input], outputs=[gen_output, gen_output_image])
    opt_button.click(fn=opt_process, inputs=[opt_input], outputs=[opt_input_img, opt_output, opt_output_imp, opt_output_sim, opt_output_img])

# NOTE(review): launch() runs at module import time (no `if __name__ == "__main__":`
# guard), so importing this module starts the server. Confirm the deployment
# runs this file as a script before adding a guard.
demo.launch()