"""Gradio app that shows gaze-location heatmaps from the
``GazeLocation/stimuli_heatmaps`` dataset next to their stimulus image.

Heatmap column names are parsed as ``{method}_{arch}_{x}_{y}``: the first
underscore-separated token is the method, and the tokens between it and the
last two form the model architecture. What the two trailing tokens mean is
not visible from this file.
"""
import gradio as gr
from datasets import load_dataset

# Number of leading dataset columns that are skipped when listing heatmap
# columns. NOTE(review): presumably metadata such as "image" -- confirm
# against the dataset card.
N_META_COLUMNS = 5

# Model / method displayed when the app starts; all UI defaults derive from
# these so the dropdowns, textbox and output image agree.
DEFAULT_ARCH = 'resnet50'
DEFAULT_METHOD = 'gradcam'


# +
def _parse_column(column):
    """Split a heatmap column name into its ``(method, arch)`` pair."""
    parts = column.split('_')
    return parts[0], '_'.join(parts[1:-2])


def get_methods_and_arch(dataset):
    """Return the unique methods and architectures found in *dataset*.

    Args:
        dataset: a ``datasets.Dataset`` whose columns after the first
            ``N_META_COLUMNS`` are heatmap columns.

    Returns:
        ``(methods, archs)``: two sorted lists of unique strings. Sorting keeps
        the dropdown order stable across runs (``set`` iteration order is not).
    """
    pairs = [_parse_column(col) for col in dataset.column_names[N_META_COLUMNS:]]
    methods = sorted({method for method, _ in pairs})
    archs = sorted({arch for _, arch in pairs})
    return methods, archs


def get_columns(arch, method):
    """Return the heatmap column for *arch* / *method* in the global ``dataset``.

    Matching is exact on the parsed ``(method, arch)`` pair, using the same
    parsing as ``get_methods_and_arch``. A plain substring test would also
    match names that merely contain the pair (e.g. a method ``xgradcam`` when
    asking for ``gradcam``, or an arch ``resnet50_v2`` when asking for
    ``resnet50``).

    Returns:
        The first matching column name, or ``None`` if the dataset has no
        column for this combination.
    """
    for col in dataset.column_names[N_META_COLUMNS:]:
        if _parse_column(col) == (method, arch):
            return col
    return None


def button_fn(arch, method):
    """Switch to the *arch* / *method* heatmap and reset to the default row.

    Returns:
        Values for ``[column_textbox, hf_slider, image_input, image_output]``:
        the selected column name, ``index_default``, and that row's stimulus
        image and heatmap.

    Raises:
        ValueError: if the dataset has no column for this combination.
    """
    column_heatmap = get_columns(arch, method)
    if column_heatmap is None:
        raise ValueError(f'No heatmap column for method {method!r} and model {arch!r}')
    example = dataset[index_default]  # fetch (and decode) the row only once
    return column_heatmap, index_default, example["image"], example[column_heatmap]


def func_slider(index, column_textbox):
    """Return the stimulus image and the *column_textbox* heatmap of row *index*."""
    # The slider value may not be a Python int; Dataset indexing requires one.
    example = dataset[int(index)]
    return example['image'], example[column_textbox]
# -


dataset = load_dataset("GazeLocation/stimuli_heatmaps", split='train')
METHODS, ARCHS = get_methods_and_arch(dataset)
index_default = 0
# When True, launch with share=True and debug=True instead of the defaults.
DEMO = False

if __name__ == '__main__':
    default_column = get_columns(DEFAULT_ARCH, DEFAULT_METHOD)
    default_example = dataset[index_default]

    demo = gr.Blocks()
    with demo:
        gr.Markdown("# Heatmap Gaze Location")
        with gr.Row():
            dropdown_arch = gr.Dropdown(choices=ARCHS, value=DEFAULT_ARCH, label='Model')
            dropdown_method = gr.Dropdown(choices=METHODS, value=DEFAULT_METHOD, label='Method')
        with gr.Row():
            # A Gradio button's caption is its ``value``.
            button = gr.Button(value='Update Heatmap Model - Method')
        with gr.Row():
            hf_slider = gr.Slider(minimum=0, maximum=len(dataset) - 1, step=1)
        with gr.Row():
            # Holds the active heatmap column; func_slider reads it, so it must
            # start on the same column as the initial output image.
            column_textbox = gr.Textbox(label='column name', value=default_column)
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image", value=default_example["image"])
            with gr.Column():
                image_output = gr.Image(label="Output", value=default_example[default_column])

        button.click(
            fn=button_fn,
            inputs=[dropdown_arch, dropdown_method],
            outputs=[column_textbox, hf_slider, image_input, image_output],
        )
        hf_slider.change(
            func_slider,
            inputs=[hf_slider, column_textbox],
            outputs=[image_input, image_output],
        )

    if DEMO:
        demo.launch(share=True, debug=True)
    else:
        demo.launch()