PuristanLabs1 commited on
Commit
1cc5f25
·
verified ·
1 Parent(s): 75d9e33

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import gradio as gr
6
+
7
+ # Update this to the correct path of your folder
8
+ task_folder = "./ARC_AGI" # Use absolute path if necessary
9
+
10
+ def plot_arc_problem(task_file_path, data_type="train"):
11
+ """
12
+ Visualize input-output pairs for a given ARC task file.
13
+ Args:
14
+ task_file_path (str): Path to the ARC task JSON file.
15
+ data_type (str): 'train' or 'test' to visualize respective examples.
16
+ Returns:
17
+ matplotlib.figure.Figure or str: The plotted figure or an error message.
18
+ """
19
+ try:
20
+ # Load the JSON file
21
+ with open(task_file_path, 'r') as f:
22
+ task = json.load(f)
23
+
24
+ # Get the data type (train or test) pairs
25
+ pairs = task.get(data_type, [])
26
+ if not pairs: # Check if the section exists and has data
27
+ return f"No '{data_type}' data found in the selected file."
28
+
29
+ # Create a figure with subplots for each pair
30
+ fig, axes = plt.subplots(len(pairs), 2, figsize=(10, 5 * len(pairs)))
31
+ fig.suptitle(f"ARC Task: {os.path.basename(task_file_path)} ({data_type.capitalize()})", fontsize=16)
32
+
33
+ # Handle case where there is only one pair
34
+ if len(pairs) == 1:
35
+ axes = [axes]
36
+
37
+ for idx, pair in enumerate(pairs):
38
+ input_grid = np.array(pair['input'])
39
+ output_grid = np.array(pair['output'])
40
+
41
+ # Plot input grid
42
+ axes[idx][0].imshow(input_grid, cmap='tab20', interpolation='none')
43
+ axes[idx][0].set_title(f"Input Pair {idx + 1} ({data_type.capitalize()})")
44
+ axes[idx][0].axis('off')
45
+
46
+ # Plot output grid
47
+ axes[idx][1].imshow(output_grid, cmap='tab20', interpolation='none')
48
+ axes[idx][1].set_title(f"Output Pair {idx + 1} ({data_type.capitalize()})")
49
+ axes[idx][1].axis('off')
50
+
51
+ plt.tight_layout(rect=[0, 0, 1, 0.96])
52
+ return fig
53
+ except Exception as e:
54
+ print("Error in plot_arc_problem:", e)
55
+ return f"An error occurred while plotting the {data_type} data: {e}"
56
+
57
+ def visualize_task(file_name, data_type):
58
+ """
59
+ Load and visualize the ARC task for a given file and data type.
60
+ Args:
61
+ file_name (str): Name of the JSON file in the folder.
62
+ data_type (str): 'train' or 'test'.
63
+ Returns:
64
+ matplotlib.figure.Figure or str: Figure of the visualized task or an error message.
65
+ """
66
+ try:
67
+ print(f"Selected file: {file_name}, Data type: {data_type}") # Debugging
68
+ task_file_path = os.path.join(task_folder, file_name)
69
+ result = plot_arc_problem(task_file_path, data_type)
70
+ return result
71
+ except Exception as e:
72
+ print("Error in visualize_task:", e)
73
+ return f"An error occurred while visualizing the task: {e}"
74
+
75
+ # Gradio Interface
76
+ task_files = [f for f in os.listdir(task_folder) if f.endswith('.json')]
77
+
78
+ interface = gr.Interface(
79
+ fn=visualize_task,
80
+ inputs=[
81
+ gr.Dropdown(choices=task_files, label="Select ARC Task File"),
82
+ gr.Radio(choices=["train", "test"], label="Select Data Type to Visualize", value="train")
83
+ ],
84
+ outputs="plot",
85
+ title="ARC Task Visualizer",
86
+ description="Select a task file and data type (train or test) to visualize its input-output grids."
87
+ )
88
+
89
+ if __name__ == "__main__":
90
+ print("Task files:", task_files) # Debugging
91
+ interface.launch()