import gradio as gr
from datasets import load_dataset
import argparse
import pandas as pd
from functools import partial
import subprocess

# Columns of the per-example header table, each paired with the datatype
# passed to gr.Dataframe. Keeping name and datatype side by side prevents the
# two derived lists from drifting out of sync.
_HEADER_COLUMNS = [
    ("__index_level_0__", "number"),
    ("problem", "str"),
    ("username", "str"),
    ("entrypoint", "str"),
    ("submitted_text", "str"),
    ("prompt", "str"),
    ("subset", "str"),
]
# These must stay lists: they are used for pandas column selection.
HEADERS = [name for name, _ in _HEADER_COLUMNS]
DATATYPES = [datatype for _, datatype in _HEADER_COLUMNS]

# Columns (and datatypes) of the second, single-column table.
_SUCCESS_COLUMNS = [("subset", "str")]
SUCCESS_HEADERS = [name for name, _ in _SUCCESS_COLUMNS]
SUCCESS_DATATYPES = [datatype for _, datatype in _SUCCESS_COLUMNS]

def capture_output(prompt, completion, prints):
    """Run prompt + completion + prints in a child interpreter and wrap its output.

    Args:
        prompt: Source text placed first in the program.
        completion: Source text placed after the prompt.
        prints: Source text appended last.

    Returns:
        A ``(stderr, stdout)`` pair: a ``gr.Textbox`` labelled "Code Errors"
        and a ``gr.Code`` labelled "Code Outputs", holding the child's decoded
        and stripped standard error and standard output respectively.
    """
    import sys  # function-scope import keeps this block self-contained

    # Only the first completion line gets a four-space indent; later lines are
    # kept as they are, with three spaces inserted before each line break.
    # NOTE(review): presumably completions arrive with their continuation
    # lines already indented -- confirm before changing this join.
    body = "    " + "   \n".join(completion.split("\n"))
    code = "\n".join([prompt, body, prints])

    # SECURITY: this executes the assembled text as Python code. Only use it
    # with datasets whose content is trusted.
    # Use the interpreter running this app rather than whatever "python"
    # resolves to on PATH (which may be absent or a different environment).
    interpreter = sys.executable or "python"
    outputs = subprocess.run(
        [interpreter, "-c", code],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    # errors="replace" keeps non-UTF-8 child output from raising here.
    err_text = outputs.stderr.decode("utf-8", errors="replace").strip()
    out_text = outputs.stdout.decode("utf-8", errors="replace").strip()
    stderr = gr.Textbox(err_text, label="Code Errors", type="text")
    stdout = gr.Code(out_text, label="Code Outputs", language="python")
    return stderr, stdout

def update_components(
    ds,
    slider,
    header_data,
    success_data,
    prompt,
    submitted_text,
    assertions,
):
    """Build fresh components showing the example at position ``slider``.

    Args:
        ds: DataFrame of examples, or a ``gr.State`` whose ``value`` is one.
        slider: Positional row index of the example to display.
        header_data: Not read; a new header table is built from the row.
        success_data: Not read; a new table is built from the row.
        prompt: Not read; rebuilt from the row's "prompt" column.
        submitted_text: Not read; rebuilt from the row's "submitted_text" column.
        assertions: Not read; rebuilt from the row's "assertions" column.

    Returns:
        A list of new components in the order
        ``[slider, header_data, success_data, prompt, submitted_text, assertions]``.
    """
    frame = ds.value if isinstance(ds, gr.State) else ds

    # Indexing with a list keeps the selection as a one-row DataFrame, which
    # is what the tables below take as their value.
    selected = frame.iloc[[slider]]

    def _one_row_table(columns, datatypes):
        # Read-only, single-row table restricted to the given columns.
        return gr.Dataframe(
            value=selected[columns],
            headers=columns,
            datatype=datatypes,
            row_count=1,
            col_count=(len(columns), "fixed"),
            column_widths=["60px"] * len(columns),
            interactive=False,
        )

    header_table = _one_row_table(HEADERS, DATATYPES)
    success_table = _one_row_table(SUCCESS_HEADERS, SUCCESS_DATATYPES)

    # Switch to the row itself (a Series) for scalar field access.
    record = selected.iloc[0]
    prompt_view = gr.Code(record["prompt"], language="python", label="Prompt")
    submitted_view = gr.Textbox(record["submitted_text"], type="text", label="Submitted Text")
    assertions_view = gr.Code(record["assertions"], language="python", label="Assertions")

    # Rebuild the slider so its upper bound follows the current frame's size.
    position = gr.Slider(
        0,
        len(frame) - 1,
        step=1,
        label="Problem ID (click and arrow keys to navigate):",
        value=slider,
    )
    return [
        position,
        header_table,
        success_table,
        prompt_view,
        submitted_view,
        assertions_view,
    ]

def filter_by(
    dataset_name,
    dataset_split,
    problem_box,
    student_box,
    slider,
    *components_to_update):
    """Reload the dataset, apply the selected filters, and show the first match.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        dataset_split: Split name passed to ``load_dataset``.
        problem_box: Value to match against the "problem" column, or ``None``
            to leave that column unfiltered.
        student_box: Value to match against the "username" column, or ``None``
            to leave that column unfiltered.
        slider: Current slider value; not read, the view restarts at row 0.
        *components_to_update: Remaining values forwarded to
            ``update_components``.

    Returns:
        A list whose first element is the filtered DataFrame, followed by the
        components returned by ``update_components`` for row 0.

    Raises:
        gr.Error: If no row matches the selected filters. Raising before
            returning means none of the outputs are replaced.
    """
    # Start from the full dataset each time so filters never stack on top of
    # a previously filtered frame.
    ds = load_dataset(dataset_name, split=dataset_split).to_pandas()

    if problem_box is not None:
        ds = ds[ds["problem"] == problem_box]

    if student_box is not None:
        ds = ds[ds["username"] == student_box]

    # With zero rows there is no row 0 to display; report that clearly rather
    # than failing with an out-of-bounds indexing error.
    if ds.empty:
        raise gr.Error("No examples match the selected filters.")

    return [ds, *update_components(ds, 0, *components_to_update)]

def next_example(ds, *components):
    """Move one example forward, stopping at the last row.

    Args:
        ds: Collection of examples; its length bounds the new position.
        *components: The current slider value followed by the remaining
            values forwarded to ``update_components``.

    Returns:
        The result of ``update_components`` for the new position.
    """
    current = components[0]
    last_index = len(ds) - 1
    if current < last_index:
        target = int(current) + 1
    else:
        # Already at (or past) the end: stay on the last row.
        target = last_index
    return update_components(ds, target, *components[1:])

def prev_example(ds, *components):
    """Move one example back, stopping at the first row.

    Args:
        ds: Collection of examples, passed through to ``update_components``.
        *components: The current slider value followed by the remaining
            values forwarded to ``update_components``.

    Returns:
        The result of ``update_components`` for the new position.
    """
    current = components[0]
    if current > 0:
        target = int(current) - 1
    else:
        # Already at the start: stay on row 0.
        target = 0
    return update_components(ds, target, *components[1:])

def _username_sort_key(username):
    """Order ``student<N>`` usernames numerically; other names sort after them."""
    try:
        return (0, int(username.replace("student", "")), username)
    except ValueError:
        # Not of the form student<N>: fall back to alphabetical order, placed
        # after all numeric usernames.
        return (1, 0, username)


def main(args):
    """Load the dataset, build the browsing UI, and launch it.

    Args:
        args: Parsed command-line arguments providing ``dataset`` and
            ``split`` (passed to ``load_dataset``) and ``share`` (passed to
            ``demo.launch``).
    """
    ds = load_dataset(args.dataset, split=args.split).to_pandas()
    student_usernames = sorted(set(ds["username"]), key=_username_sort_key)
    problem_names = sorted(set(ds["problem"]))

    with gr.Blocks(theme="gradio/monochrome") as demo:
        dataset = gr.State(ds)
        # slider for selecting problem id
        slider = gr.Slider(0, len(ds) - 1, step=1, label="Problem ID (click and arrow keys to navigate):")
        # display headers in dataframe for problem id
        header_data = gr.Dataframe(
            headers=HEADERS,
            datatype=DATATYPES,
            row_count=1,
            col_count=(len(HEADERS), "fixed"),
            column_widths=["60px"]*len(HEADERS),
            interactive=False,
        )
        success_data = gr.Dataframe(
            headers=SUCCESS_HEADERS,
            datatype=SUCCESS_DATATYPES,
            row_count=1,
            col_count=(len(SUCCESS_HEADERS), "fixed"),
            column_widths=["60px"]*len(SUCCESS_HEADERS),
            interactive=False,
        )

        # The string values below are placeholders shown until an example is
        # selected.
        prompt = gr.Code("__prompt__", language="python", label="Prompt")
        with gr.Row():
            prev_btn = gr.Button("Previous")
            next_btn = gr.Button("Next")
        submitted_text = gr.Textbox("__submitted_text__", type="text", label="Submitted Text")

        with gr.Row():
            assertions = gr.Code("__assertions__", language="python", label="Assertions")

        # updates
        # change example on slider change
        # The order of this list must match both the parameter order of
        # update_components (after ``ds``) and the order of its return value.
        components = [slider, header_data, success_data, prompt, submitted_text, assertions]
        slider.input(fn=update_components, inputs=[dataset, *components], outputs=components)

        prev_btn.click(fn=prev_example, inputs=[dataset, *components], outputs=components)
        next_btn.click(fn=next_example, inputs=[dataset, *components], outputs=components)

        # add filtering options
        gr.Markdown("**Filtering (reload to clear all filters)**\n")
        with gr.Row():
            with gr.Column():
                problem_box = gr.Dropdown(label="problem", choices=problem_names)
                student_box = gr.Dropdown(label="username", choices=student_usernames)
            filter_btn = gr.Button("Filter")

        # The dataset name and split are bound up front; the remaining
        # filter_by arguments come from the UI inputs listed here.
        filter_btn.click(fn=partial(filter_by, args.dataset, args.split), inputs=[problem_box, student_box, *components],
                         outputs=[dataset, *components])

    demo.launch(share=args.share)
    
if __name__ == "__main__":
    # Command-line entry point: pick the dataset/split to browse and whether
    # to pass share=True when launching the app.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--dataset",
        default="nuprl-staging/studenteval_tagged_prompts",
        type=str,
    )
    cli.add_argument("--split", default="test", type=str)
    cli.add_argument("--share", action="store_true")
    args = cli.parse_args()
    main(args)