import gradio as gr
import torch
import ecco
import requests
from transformers import AutoTokenizer
from torch.nn import functional as F

header = """
import psycopg2

conn = psycopg2.connect("CONN")
cur = conn.cursor()

MIDDLE
def rename_customer(id, newName):\n\t# PROMPT\n\tcur.execute("UPDATE customer SET name =
"""

modelPath = {
    # "GPT2-Medium": "gpt2-medium",
    "CodeParrot-mini": "codeparrot/codeparrot-small",
    # "CodeGen-350-Mono": "Salesforce/codegen-350M-mono",
    # "GPT-Neo-1.3B": "EleutherAI/gpt-neo-1.3B",
    "CodeParrot": "codeparrot/codeparrot",
    # "CodeGen-2B-Mono": "Salesforce/codegen-2B-mono",
}

preloadModels = {}
for m in list(modelPath.keys()):
    preloadModels[m] = ecco.from_pretrained(modelPath[m])

def generation(tokenizer, model, content):
    decoder = 'Standard'
    num_beams = 2 if decoder == 'Beam' else None
    typical_p = 0.8 if decoder == 'Typical' else None
    do_sample = (decoder in ['Beam', 'Typical', 'Sample'])

    seek_token_ids = [
        tokenizer.encode('= \'" +')[1:],
        tokenizer.encode('= " +')[1:],
    ]

    full_output = model.generate(content, generate=6, do_sample=False)

    def next_words(code, position, seek_token_ids):
        op_model = model.generate(code, generate=1, do_sample=False)
        hidden_states = op_model.hidden_states
        layer_no = len(hidden_states) - 1
        h = hidden_states[-1]
        hidden_state = h[position - 1]
        logits = op_model.lm_head(op_model.to(hidden_state))
        softmax = F.softmax(logits, dim=-1)
        my_token_prob = softmax[seek_token_ids[0]]

        if len(seek_token_ids) > 1:
            newprompt = code + tokenizer.decode(seek_token_ids[0])
            return my_token_prob * next_words(newprompt, position + 1, seek_token_ids[1:])
        return my_token_prob

    prob = 0
    for opt in seek_token_ids:
        prob += next_words(content, len(tokenizer(content)['input_ids']), opt)
    return ["".join(full_output.tokens), str(prob.item() * 100) + '% chance of risky concatenation']

def code_from_prompts(prompt, model, type_hints, pre_content):
    tokenizer = AutoTokenizer.from_pretrained(modelPath[model])
    # model = ecco.from_pretrained(modelPath[model])
    model = preloadModels[model]

    code = header.strip().replace('CONN', "dbname='store'").replace('PROMPT', prompt)

    if type_hints:
        code = code.replace('id,', 'id: int,')
        code = code.replace('id)', 'id: int)')
        code = code.replace('newName)', 'newName: str) -> None')

    if pre_content == 'None':
        code = code.replace('MIDDLE\n', '')
    elif 'Concatenation' in pre_content:
        code = code.replace('MIDDLE', """
def get_customer(id):\n\tcur.execute('SELECT * FROM customers WHERE id = ' + str(id))\n\treturn cur.fetchall()
""".strip() + "\n")
    elif 'composition' in pre_content:
        code = code.replace('MIDDLE', """
def get_customer(id):\n\tcur.execute('SELECT * FROM customers WHERE id = %s', str(id))\n\treturn cur.fetchall()
""".strip() + "\n")

    results = generation(tokenizer, model, code)
    return results

iface = gr.Interface(
    fn=code_from_prompts,
	inputs=[
        gr.components.Textbox(label="Insert comment"),
        gr.components.Radio(list(modelPath.keys()), label="Code Model"),
        gr.components.Checkbox(label="Include type hints"),
        gr.components.Radio([
            "None",
            "Proper composition: Include function 'WHERE id = %s'",
            "Concatenation: Include a function with 'WHERE id = ' + id",
        ], label="Has user already written a function?")
    ],
	outputs=[
        gr.components.Textbox(label="Most probable code"),
        gr.components.Textbox(label="Probability of concat"),
    ],
	description="Prompt the code model to write a SQL query with string concatenation.",
)
iface.launch()