import os

import rwkv_rs
import numpy as np
import huggingface_hub
import tokenizers

import gradio as gr

# Prefer a checkpoint placed next to the app; only when it is missing,
# download the published checkpoint from the Hugging Face Hub and use the
# path that hf_hub_download returns.
_LOCAL_MODEL_PATH = "./rnn.safetensors"
_HF_REPO_ID = "mrsteyk/RWKV-LM-safetensors"
_HF_FILENAME = "RWKV-4-Pile-7B-Instruct-test1-20230124.rnn.safetensors"

model_path = (
    _LOCAL_MODEL_PATH
    if os.path.exists(_LOCAL_MODEL_PATH)
    else huggingface_hub.hf_hub_download(repo_id=_HF_REPO_ID, filename=_HF_FILENAME)
)
assert model_path is not None

# Load the model weights and the tokenizer used to encode/decode all text.
model = rwkv_rs.Rwkv(model_path)
tokenizer = tokenizers.Tokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# Visibility updates for a tab's [primary button, Stop button] pair.
# GT: a generation is running -> hide the primary button, show "Stop".
# GF: idle again            -> show the primary button, hide "Stop".
GT = [gr.Button.update(visible=shown) for shown in (False, True)]
GF = [gr.Button.update(visible=shown) for shown in (True, False)]

def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
    """Greedily continue ``inpt`` with the model, streaming the text as it grows.

    Yields ``(output_text, error_update)`` pairs for the ``[out, error_box]``
    outputs. ``error_update`` is ``None`` while generating; on failure the
    error box is shown with the exception message.

    Args:
        inpt: prompt text.
        max_tokens: maximum number of tokens to generate (cast to ``int``).
        min_tokens: token id 0, which stops generation, is banned for the
            first ``min_tokens`` steps.
        alpha_f: frequency penalty, subtracted once per previous occurrence
            of a token.
        alpha_p: presence penalty, subtracted once from every token that has
            already been generated.
    """
    try:
        state = rwkv_rs.State(model)
        text = inpt
        tokens = tokenizer.encode(inpt).ids
        if not tokens:
            # The model needs at least one prompt token to produce logits.
            yield ("Error...", gr.Text.update(value="Input is empty!", visible=True))
            return
        yield (None, gr.Text.update(visible=False))
        # Feed the prompt; only the last token's logits are needed.
        for i in range(len(tokens) - 1):
            model.forward_token_preproc(tokens[i], state)
            yield (tokenizer.decode(tokens[:i + 1]), None)
        logits = model.forward_token(tokens[-1], state)
        yield (text, None)
        max_tokens = int(max_tokens)
        # Sized from the logits so the penalties line up with the model output
        # even if the tokenizer reports a different vocabulary size.
        counts = np.zeros(len(logits))
        # `step` must not be shadowed below: the last-step check relies on it.
        for step in range(max_tokens):
            # float64 copy so penalties are applied in one vectorised expression
            # instead of a Python loop over the whole vocabulary.
            logits = np.array(logits, dtype=np.float64)
            if step < min_tokens:
                logits[0] = -np.inf  # token 0 ends generation; forbid it for now
            logits -= counts * alpha_f + (counts > 0) * alpha_p
            token = int(np.argmax(logits))
            if token == 0:
                break
            counts[token] += 1
            tokens.append(token)
            text = tokenizer.decode(tokens)
            yield (text, None)
            if step == max_tokens - 1:
                # Last step: skip a forward pass whose logits would be unused.
                break
            logits = model.forward_token(token, state)
        yield (text, None)
    except Exception as e:
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))

def insert_fn(inpt: str, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert):
    """Generate text at a single ``<|INSERT|>`` marker, streaming progress.

    The text before the marker is the prompt and is continued greedily with
    the same penalties as ``complete_fn``. Generation stops early once the
    most recently generated tokens equal the head of the text after the
    marker: its first ``num_tokens_insert`` tokens, or all of it if shorter.
    The rest of that tail is then appended. If no such match happens, the
    tail is not appended.

    Yields ``(output_text, error_update)`` pairs for ``[out_i, error_box]``.

    Args:
        inpt: text containing exactly one ``<|INSERT|>`` marker.
        max_tokens: maximum number of tokens to generate (cast to ``int``).
        min_tokens: token id 0, which stops generation, is banned for the
            first ``min_tokens`` steps.
        alpha_f: frequency penalty per previous occurrence of a token.
        alpha_p: presence penalty for any token already generated.
        num_tokens_insert: how many leading tail tokens must be reproduced
            before the tail is joined (cast to ``int``).
    """
    try:
        if inpt.count("<|INSERT|>") != 1:
            yield ("Error...", gr.Text.update(value="Exactly one replace is allowed!", visible=True))
            return
        state = rwkv_rs.State(model)
        text, end = inpt.split("<|INSERT|>")
        # Slider values may arrive as floats; slice bounds must be ints.
        num_tokens_insert = int(num_tokens_insert)
        tokens = tokenizer.encode(text).ids
        if not tokens:
            # The model needs at least one prompt token to produce logits.
            yield ("Error...", gr.Text.update(value="Text before <|INSERT|> must not be empty!", visible=True))
            return
        tokens_end = tokenizer.encode(end).ids
        tokens_i = tokens_end[:num_tokens_insert]
        # Sliding window over the latest generated tokens (zero-padded at the
        # start), compared against `tokens_i` to find where the tail joins.
        window = [0] * len(tokens_i)
        yield (None, gr.Text.update(visible=False))
        for i in range(len(tokens) - 1):
            model.forward_token_preproc(tokens[i], state)
            yield (tokenizer.decode(tokens[:i + 1]), None)
        logits = model.forward_token(tokens[-1], state)
        yield (text, None)
        max_tokens = int(max_tokens)
        # Sized from the logits so the penalties line up with the model output.
        counts = np.zeros(len(logits))
        for step in range(max_tokens):
            # float64 copy so penalties are applied in one vectorised expression.
            logits = np.array(logits, dtype=np.float64)
            if step < min_tokens:
                logits[0] = -np.inf  # token 0 ends generation; forbid it for now
            logits -= counts * alpha_f + (counts > 0) * alpha_p
            token = int(np.argmax(logits))
            if token == 0:
                break
            counts[token] += 1
            tokens.append(token)
            window = window[1:] + [token]
            joined = window == tokens_i
            if joined:
                tokens += tokens_end[num_tokens_insert:]
            text = tokenizer.decode(tokens)
            yield (text, None)
            if joined or step == max_tokens - 1:
                # Done: either the tail was joined or this was the last step,
                # so the next forward pass would be wasted.
                break
            logits = model.forward_token(token, state)
        yield (text, None)
    except Exception as e:
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))

def _log_softmax(x):
    """Return ``log(softmax(x))`` computed stably in float64.

    Unlike ``np.log(softmax(x))``, this does not produce ``-inf`` when a
    token's probability underflows to zero.
    """
    x = np.asarray(x, dtype=np.float64)
    shifted = x - np.max(x)
    return shifted - np.log(np.sum(np.exp(shifted)))


def classify_fn_inner2(inpt, clas):
    """Return the mean per-token negative log-likelihood of ``inpt`` under a class prompt.

    The model is primed with ``"This is an example of {clas} text:"`` and then
    scores ``" {inpt}"`` followed by a newline, token by token. Lower values
    mean the text is more likely given that class description.

    Every scored token counts, including the first one. The first token is
    predicted from the logits of the prompt's last token.
    """
    state = rwkv_rs.State(model)
    prompt = tokenizer.encode(f"This is an example of {clas} text:").ids
    for tok in prompt[:-1]:
        model.forward_token_preproc(tok, state)
    # The last prompt token yields the distribution over the first input token.
    logits = model.forward_token(prompt[-1], state)

    # Never empty: the encoded string always contains at least the newline.
    tokens = tokenizer.encode(f" {inpt}\n").ids
    nll = 0.0
    for pos, tok in enumerate(tokens):
        nll -= _log_softmax(logits)[tok]
        if pos + 1 < len(tokens):
            logits = model.forward_token(tok, state)
    return nll / len(tokens)

def softmax(x):
    """Return the softmax of ``x`` taken over all of its elements.

    The maximum is subtracted before exponentiating so large inputs do not
    overflow; mathematically the result is unchanged.
    """
    shifted = np.subtract(x, np.max(x))
    weights = np.exp(shifted)
    return weights / np.sum(weights)

def classify_fn(inpt, clas, clasneg):
    """Score ``inpt`` against a positive and a negative class description.

    Each description gets a mean negative log-likelihood from
    ``classify_fn_inner2``. A softmax over the negated losses turns the pair
    into scores that sum to 1, keyed ``"+"`` and ``"-"`` for the label output.
    """
    losses = [classify_fn_inner2(inpt, label) for label in (clas, clasneg)]
    positive, negative = softmax([-loss for loss in losses])
    return {"+": positive, "-": negative}

def generator_wrap(l, fn, running=None, idle=None):
    """Wrap a streaming callback so it also toggles the primary/Stop buttons.

    ``fn(*args)`` must yield sequences of ``l`` output values. Each one is
    re-yielded with ``running`` appended (default ``GT``). Once ``fn`` finishes
    or raises, the last outputs are yielded once more with ``idle`` appended
    (default ``GF``), and only then does any exception propagate.

    If the consumer closes the generator early (presumably what happens when
    the event is cancelled by a Stop button), nothing more is yielded.
    Yielding while handling ``GeneratorExit`` would make ``close()`` raise
    ``RuntimeError``. The Stop click handlers reset button visibility
    themselves.
    """
    def wrap(*args):
        on_run = GT if running is None else running
        on_idle = GF if idle is None else idle
        last = [None] * l
        closed = False
        try:
            for outputs in fn(*args):
                last = list(outputs)
                yield last + list(on_run)
        except GeneratorExit:
            closed = True
            raise
        finally:
            if not closed:
                yield last + list(on_idle)
    return wrap


with gr.Blocks() as app:
    gr.Markdown(f"Running on `{model_path}`")
    # Shared error box, hidden until a generation callback reports a failure.
    error_box = gr.Text(label="Error", visible=False)

    with gr.Tab("Complete"):
        with gr.Row():
            inpt = gr.TextArea(label="Input")
            out = gr.TextArea(label="Output")
        complete = gr.Button("Complete", variant="primary")
        c_stop = gr.Button("Stop", variant="stop", visible=False)
    with gr.Tab("Insert"):
        gr.Markdown("Use `<|INSERT|>` to indicate a place to replace, if insert fails - end text won't be concatenated")
        with gr.Row():
            inpt_i = gr.TextArea(label="Input")
            out_i = gr.TextArea(label="Output")
        num_tokens_insert = gr.Slider(label="Number of tokens to compare for ending (from the beginning of 2nd part)", minimum=1, maximum=2048, value=1024, step=1)
        insert = gr.Button("Insert", variant="primary")
        i_stop = gr.Button("Stop", variant="stop", visible=False)
    with gr.Tab("Classification W/O head"):
        gr.Markdown("This is an experimental classification attempt based on [this Twitter post](https://twitter.com/aicrumb/status/1625239547268280321)\n\nSettings at the bottom do not affect this example.")
        with gr.Row():
            inpt_c = gr.TextArea(label="Input")
            out_c = gr.Label(label="Output")
        clas = gr.Textbox(label="+ NL class/example to check against.")
        clasneg = gr.Textbox(label="- NL class/example to check against.")
        classify = gr.Button("Classify", variant="primary")

    # Generation settings shared by the Complete and Insert tabs.
    with gr.Column():
        max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=4096, step=1, value=767)
        min_tokens = gr.Slider(label="Min Tokens", minimum=0, maximum=4096, step=1)
        alpha_f = gr.Slider(label="Alpha Frequency", minimum=0, maximum=100, step=0.01)
        alpha_p = gr.Slider(label="Alpha Presence", minimum=0, maximum=100, step=0.01)

    # Each streaming event also drives its tab's [primary, Stop] buttons; the
    # Stop click cancels the running event and restores the buttons itself,
    # bypassing the queue so it is not stuck behind the running job.
    c = complete.click(generator_wrap(2, complete_fn), [inpt, max_tokens, min_tokens, alpha_f, alpha_p], [out, error_box, complete, c_stop])
    c_stop.click(lambda: (complete.update(visible=True), c_stop.update(visible=False)), inputs=None, outputs=[complete, c_stop], cancels=[c], queue=False)

    i = insert.click(generator_wrap(2, insert_fn), [inpt_i, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert], [out_i, error_box, insert, i_stop])
    i_stop.click(lambda: (insert.update(visible=True), i_stop.update(visible=False)), inputs=None, outputs=[insert, i_stop], cancels=[i], queue=False)

    classify.click(classify_fn, [inpt_c, clas, clasneg], [out_c])

app.queue(concurrency_count=2)
app.launch()