# Source: Hugging Face Space page (status at capture time: "Runtime error").
import os
import subprocess
import sys
import tempfile

import numpy as np
import torch
import matplotlib.pyplot as plt
import networkx as nx
import gradio as gr
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.patches as mpatches
# Choose the compute device once at import time: the first CUDA GPU when
# PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
class EnhancedMindMapGenerator:
    """Build, lay out and render a mind map described by indented text.

    Input format (see ``parse_input``): one node per line, where a more
    indented line is a child of the closest preceding less-indented line,
    and ``A -> B`` lines add explicit edges between named nodes.
    """

    # Fill colour of the root node.
    ROOT_COLOR = '#FF5733'
    # Fill colours for depths 1, 2, 3, ...; deeper levels cycle through them.
    LEVEL_COLORS = ('#3498DB', '#F39C12', '#2ECC71', '#9B59B6', '#E74C3C')
    # Fill colour for non-root nodes without a hierarchy depth (depth 0),
    # e.g. nodes that only appear in ``->`` relationships.
    DEFAULT_COLOR = '#95A5A6'
    # Lower bound on marker size so very deep nodes stay visible.
    MIN_NODE_SIZE = 400

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all graph and layout state and return a status message."""
        self.graph = nx.DiGraph()  # directed: parent -> child, source -> target
        self.node_positions = {}   # node -> np.array([x, y])
        self.node_colors = {}      # node -> matplotlib colour string
        self.edge_colors = {}      # (source, target) -> 'gray' or 'green'
        self.node_sizes = {}       # node -> matplotlib marker size
        self.node_depth = {}       # node -> hierarchy depth (root = 0)
        self.levels = {}           # depth -> hierarchy nodes, in input order
        return "Mind map reset successfully"

    def parse_input(self, text):
        """Parse mind-map text into nodes and edges.

        Lines without ``->`` form the hierarchy: the first one becomes the
        root, and every later line becomes a child of the closest preceding
        line with smaller indentation (the root when there is none). Lines
        such as ``A -> B`` (or chains ``A -> B -> C``) add "relationship"
        edges and create any node that does not exist yet.

        Args:
            text: The mind-map description; ``None`` is treated as empty.

        Returns:
            A status message naming the root node (``None`` if there is none).
        """
        lines = (text or "").strip().split('\n')
        root_node = None
        # Ancestry chain of (indent, node) for the current line. The root at
        # index 0 is never popped, so every hierarchy line has a parent.
        stack = []

        # First pass: build the hierarchy from indentation.
        for raw_line in lines:
            name = raw_line.strip()
            if not name or '->' in name:
                continue  # blank lines and relationship lines are handled later
            expanded = raw_line.expandtabs(4)
            indent = len(expanded) - len(expanded.lstrip())

            if root_node is None:
                root_node = name
                self.add_node(root_node, is_root=True, depth=0)
                self.levels[0] = [root_node]
                stack = [(indent, root_node)]
                continue

            # Discard siblings and deeper nodes; whatever remains on top is
            # the nearest shallower line, i.e. the parent.
            while len(stack) > 1 and stack[-1][0] >= indent:
                stack.pop()
            parent = stack[-1][1]
            depth = self.node_depth.get(parent, 0) + 1

            is_new = name not in self.graph
            self.add_node(name, depth=depth)
            self.add_edge(parent, name, "hierarchy")
            if is_new:
                # A repeated name keeps its first slot in the layout.
                self.levels.setdefault(depth, []).append(name)
            stack.append((indent, name))

        # Second pass: explicit relationships (->), chained pairwise.
        for raw_line in lines:
            if '->' not in raw_line:
                continue
            names = [part.strip() for part in raw_line.split('->')]
            if not all(names):
                continue  # malformed (e.g. "A ->"): never create an unnamed node
            for source, target in zip(names, names[1:]):
                self.add_edge(source, target, "relationship")

        return f"Parsed mind map with root: {root_node}"

    def add_node(self, node_name, is_root=False, depth=0):
        """Add ``node_name`` with depth-based colour and size.

        Nodes that already exist are left unchanged.
        """
        if node_name in self.graph:
            return
        self.graph.add_node(node_name)
        self.node_depth[node_name] = depth

        if is_root:
            self.node_colors[node_name] = self.ROOT_COLOR
            self.node_sizes[node_name] = 2500
            return

        if depth >= 1:
            # Depth 1 -> first palette entry, ..., depth 5 -> last, 6 wraps.
            color = self.LEVEL_COLORS[(depth - 1) % len(self.LEVEL_COLORS)]
        else:
            color = self.DEFAULT_COLOR
        self.node_colors[node_name] = color
        # Shrink with depth, but never to a zero or negative marker size.
        self.node_sizes[node_name] = max(2000 - depth * 200, self.MIN_NODE_SIZE)

    def add_edge(self, source, target, edge_type="hierarchy"):
        """Add a directed edge, creating missing endpoint nodes.

        ``edge_type == "relationship"`` is drawn green, anything else gray.
        An edge that already exists keeps its original colour.
        """
        for node in (source, target):
            if node not in self.graph:
                self.add_node(node)
        if self.graph.has_edge(source, target):
            return
        self.graph.add_edge(source, target)
        self.edge_colors[(source, target)] = 'green' if edge_type == "relationship" else 'gray'

    def _layout_levels(self):
        """Return ``{row: [nodes]}`` that places every graph node exactly once.

        Rows are the hierarchy depths from ``self.levels``. Nodes that are in
        no level (e.g. created only by ``->`` relationships) go on one extra
        row below the deepest level.
        """
        rows = {}
        placed = set()
        for level in sorted(self.levels):
            row = []
            for node in self.levels[level]:
                if node in self.graph and node not in placed:
                    placed.add(node)
                    row.append(node)
            if row:
                rows[level] = row
        orphans = [node for node in self.graph.nodes if node not in placed]
        if orphans:
            rows[max(rows, default=-1) + 1] = orphans
        return rows

    def calculate_hierarchical_layout(self):
        """Place each layout row on its own horizontal line, centred on x = 0.

        Row ``level`` sits at y = -2 * level and neighbours are 3 units apart.

        Returns:
            ``{node: np.array([x, y])}``; empty for an empty graph.
        """
        pos = {}
        for level, row in self._layout_levels().items():
            y = -level * 2
            offset = (len(row) - 1) / 2
            for i, node in enumerate(row):
                pos[node] = np.array([(i - offset) * 3, y], dtype=float)
        return pos

    def optimize_layout(self):
        """Compute the hierarchical layout, then refine the x positions.

        Uses a PyTorch force-directed pass when ``device`` is CUDA and
        ``networkx.spring_layout`` otherwise. In both cases every node keeps
        the y coordinate of its hierarchy row.

        Returns:
            A status message describing which path ran.
        """
        self.node_positions = self.calculate_hierarchical_layout()
        if not self.node_positions:
            return "Nothing to optimize: the mind map is empty"

        if device.type == "cuda":
            print("Optimizing layout using GPU...")
            self._optimize_layout_torch()
            return "Layout optimized using GPU acceleration while preserving hierarchy"

        self._optimize_layout_networkx()
        return "Layout optimized using CPU while preserving hierarchy"

    def _optimize_layout_torch(self, iterations=50):
        """Refine ``self.node_positions`` in place with tensor ops on ``device``.

        Edge endpoints attract each other, and nodes sharing a layout row
        repel when closer than 3 units. Only x coordinates move.
        """
        nodes = list(self.graph.nodes)
        index = {node: i for i, node in enumerate(nodes)}
        positions = torch.as_tensor(
            np.stack([self.node_positions[node] for node in nodes]),
            dtype=torch.float32,
            device=device,
        )

        # Edge endpoints as index tensors, so all edges update in one step.
        edges = list(self.graph.edges)
        src = torch.tensor([index[u] for u, _ in edges], dtype=torch.long, device=device)
        dst = torch.tensor([index[v] for _, v in edges], dtype=torch.long, device=device)

        # Every unordered pair of nodes on the same row may repel.
        left, right = [], []
        for row in self._layout_levels().values():
            for k, first in enumerate(row):
                for second in row[k + 1:]:
                    left.append(index[first])
                    right.append(index[second])
        pair_i = torch.tensor(left, dtype=torch.long, device=device)
        pair_j = torch.tensor(right, dtype=torch.long, device=device)

        for _ in range(iterations):
            forces = torch.zeros_like(positions)

            # Edge springs: log(d / 2) is positive beyond distance 2 (pull
            # the endpoints together) and negative below it (push apart).
            direction = positions[dst] - positions[src]
            distance = direction.norm(dim=1, keepdim=True) + 1e-5
            pull = direction * torch.log(distance / 2) * 0.1
            forces.index_add_(0, src, pull)
            forces.index_add_(0, dst, -pull)

            # Inverse-distance repulsion, only for same-row pairs closer than 3.
            direction = positions[pair_j] - positions[pair_i]
            distance = direction.norm(dim=1, keepdim=True) + 1e-5
            close = (distance < 3.0).to(positions.dtype)
            push = direction / distance ** 2 * 0.5 * close
            forces.index_add_(0, pair_i, -push)
            forces.index_add_(0, pair_j, push)

            step = forces * 0.1
            step[:, 1] = 0  # keep y fixed so nodes stay on their hierarchy row
            positions = positions + step

        result = positions.cpu().numpy()
        for node, i in index.items():
            self.node_positions[node] = result[i]

    def _optimize_layout_networkx(self):
        """Refine x positions with networkx's spring layout, keeping each y."""
        pos = nx.spring_layout(
            self.graph,
            pos=self.node_positions,
            k=1.5,  # preferred node separation
            iterations=50,
            weight=None,
            # Stay in the hierarchical layout's units. The default scale=1
            # would rescale x into [-1, 1] while y is restored below in its
            # original units (rows 2 apart), squashing every row together.
            scale=None,
        )
        for node, xy in pos.items():
            xy[1] = self.node_positions[node][1]
        self.node_positions = pos

    def _legend_handles(self):
        """Legend patches for the level colours in use plus the two edge kinds."""
        handles = []
        for level in sorted(self.levels):
            members = [node for node in self.levels[level] if node in self.node_colors]
            if members:
                label = 'Root' if level == 0 else f'Level {level}'
                handles.append(mpatches.Patch(color=self.node_colors[members[0]], label=label))
        handles.append(mpatches.Patch(color='green', label='Explicit Relationship'))
        handles.append(mpatches.Patch(color='gray', label='Hierarchical Relationship'))
        return handles

    def visualize(self, output_path=None):
        """Render the mind map to a PNG file and return its path.

        Args:
            output_path: Destination file. By default a fresh temporary file
                is created, so concurrent calls (several users of the web UI)
                never overwrite each other's image. These temporary files are
                not deleted here.

        Returns:
            The PNG path, or ``None`` when the graph has no nodes.
        """
        if not self.graph.nodes:
            return None

        # node_positions is only filled by optimize_layout(); fall back to
        # the plain hierarchical layout if it was skipped or is incomplete.
        pos = self.node_positions
        if any(node not in pos for node in self.graph.nodes):
            pos = self.calculate_hierarchical_layout()
            self.node_positions = pos

        nodes = list(self.graph.nodes)
        sizes = [self.node_sizes.get(node, 1000) for node in nodes]
        green_edges = [e for e in self.graph.edges if self.edge_colors.get(e, 'gray') == 'green']
        other_edges = [e for e in self.graph.edges if self.edge_colors.get(e, 'gray') != 'green']

        # Draw on an explicit figure/axes instead of pyplot's global "current" figure.
        fig, ax = plt.subplots(figsize=(16, 12), dpi=100)
        try:
            # Explicit relationships: curved green arrows. Passing node sizes
            # lets arrows stop at node borders rather than node centres.
            if green_edges:
                nx.draw_networkx_edges(
                    self.graph, pos, edgelist=green_edges, width=2.5,
                    edge_color='green', alpha=0.8, arrows=True, arrowsize=15,
                    connectionstyle="arc3,rad=0.3",
                    nodelist=nodes, node_size=sizes, ax=ax,
                )
            # Hierarchy edges: straight arrows.
            if other_edges:
                nx.draw_networkx_edges(
                    self.graph, pos, edgelist=other_edges, width=1.5,
                    edge_color=[self.edge_colors.get(e, 'gray') for e in other_edges],
                    alpha=0.7, arrows=True, arrowsize=12,
                    nodelist=nodes, node_size=sizes, ax=ax,
                )
            nx.draw_networkx_nodes(
                self.graph, pos, nodelist=nodes,
                node_color=[self.node_colors.get(node, 'blue') for node in nodes],
                node_size=sizes, alpha=0.9, edgecolors='black', linewidths=1, ax=ax,
            )
            # Labels on a white box for readability.
            nx.draw_networkx_labels(
                self.graph, pos, font_size=10, font_family='sans-serif',
                font_weight='bold',
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle='round,pad=0.3'),
                ax=ax,
            )

            ax.legend(handles=self._legend_handles(), loc='upper right')
            ax.set_title("Mind Map Visualization", fontsize=16, fontweight='bold')
            ax.axis('off')
            fig.tight_layout()

            if output_path is None:
                fd, output_path = tempfile.mkstemp(prefix="mindmap_", suffix=".png")
                os.close(fd)
            fig.savefig(output_path, format="png", dpi=300, bbox_inches='tight', facecolor='white')
        finally:
            plt.close(fig)
        return output_path
# Create the Gradio interface
def create_mind_map(input_text, optimization):
    """Gradio callback: build a mind map from ``input_text``; return the PNG path.

    Args:
        input_text: Mind-map description (indented hierarchy plus ``A -> B``
            lines); ``None`` is treated as empty.
        optimization: When truthy, refine the layout with ``optimize_layout``;
            otherwise the plain hierarchical layout is drawn.

    Returns:
        Path of the rendered PNG, or ``None`` when the input has no nodes.
    """
    generator = EnhancedMindMapGenerator()
    print(generator.parse_input(input_text or ""))
    if not generator.graph.nodes:
        # Empty input: nothing to lay out or draw.
        return None
    if optimization:
        print(generator.optimize_layout())
    else:
        # visualize() draws from node_positions, which only optimize_layout()
        # fills, so compute the un-optimized layout explicitly here.
        generator.node_positions = generator.calculate_hierarchical_layout()
    return generator.visualize()
# Build the Gradio UI and launch it (used both in Colab and when run locally)
def create_and_launch():
    """Build the Gradio Blocks UI, launch it, and return the ``Blocks`` object.

    ``share=True`` requests a public link (needed in Colab). ``debug=True``
    keeps ``launch`` blocking, so this returns only after the server stops.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Enhanced Mind Map Generator")
        gr.Markdown("Enter your mind map structure below. Use indentation to represent hierarchy or use -> for direct relationships.")
        with gr.Row():
            with gr.Column(scale=2):
                input_text = gr.Textbox(
                    # Two spaces per level so the placeholder shows a real hierarchy.
                    placeholder="Project Name\n  Task 1\n    Subtask 1.1\n    Subtask 1.2\n  Task 2\nTask 1 -> Task 2",
                    label="Mind Map Structure",
                    lines=15,
                )
                with gr.Row():
                    optimization = gr.Checkbox(label="Use Layout Optimization", value=True)
                    generate_btn = gr.Button("Generate Mind Map", variant="primary")
                gr.Markdown("### Format Guide:")
                gr.Markdown(
                    "\n"
                    "- Use indentation (spaces) to create parent-child relationships\n"
                    "- Each level of indentation creates a new depth level\n"
                    "- Use '->' to create explicit connections (e.g., 'NodeA -> NodeB')\n"
                    "- The first non-indented line becomes the root node\n"
                )
            with gr.Column(scale=3):
                output_image = gr.Image(label="Generated Mind Map", type="filepath")

        generate_btn.click(fn=create_mind_map, inputs=[input_text, optimization], outputs=output_image)

        # Example inputs: two spaces of indentation per hierarchy level.
        example_input1 = (
            "Software Project\n"
            "  Planning\n"
            "    Requirements Gathering\n"
            "    Project Timeline\n"
            "    Resource Allocation\n"
            "  Development\n"
            "    Frontend\n"
            "      UI Design\n"
            "      React Components\n"
            "    Backend\n"
            "      API Development\n"
            "      Database Setup\n"
            "  Testing\n"
            "    Unit Tests\n"
            "    Integration Tests\n"
            "  Deployment\n"
            "    CI/CD Pipeline\n"
            "    Production Release\n"
            "Planning -> Development\n"
            "Development -> Testing\n"
            "Testing -> Deployment"
        )
        example_input2 = (
            "Business Strategy\n"
            "  Market Analysis\n"
            "    Customer Demographics\n"
            "    Competitor Research\n"
            "    Market Trends\n"
            "  Internal Assessment\n"
            "    SWOT Analysis\n"
            "    Resource Inventory\n"
            "  Strategic Goals\n"
            "    Short-term Objectives\n"
            "    Long-term Vision\n"
            "  Implementation\n"
            "    Action Plans\n"
            "Market Analysis -> Strategic Goals\n"
            "Internal Assessment -> Strategic Goals\n"
            "Strategic Goals -> Implementation"
        )
        gr.Examples(
            examples=[[example_input1, True], [example_input2, True]],
            inputs=[input_text, optimization],
            outputs=output_image,
            fn=create_mind_map,
            cache_examples=True,
        )

    # Launch with sharing enabled for Colab
    demo.launch(share=True, debug=True)
    return demo
# Main execution
def run_in_colab():
    """Install gradio/networkx/matplotlib if missing, then launch the app."""
    print("Installing required packages...")
    try:
        import gradio  # noqa: F401
        import networkx  # noqa: F401
    except ImportError:
        # `!pip install ...` is IPython-only syntax and a SyntaxError in a .py
        # file; run pip through the current interpreter instead.
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "gradio", "networkx", "matplotlib"]
        )
        print("Packages installed!")
    # Create and launch the demo
    print("Launching the Enhanced Mind Map Generator...")
    create_and_launch()
# Entry point: pick the Colab or local launch path.
if __name__ == "__main__":
    # Detect Colab by importing its package. Only ImportError means "not in
    # Colab"; errors raised while launching must not trigger a second launch,
    # which the previous bare `except:` around run_in_colab() did.
    try:
        import google.colab  # noqa: F401
    except ImportError:
        in_colab = False
    else:
        in_colab = True

    if in_colab:
        print("Running in Google Colab environment")
        run_in_colab()
    else:
        print("Running in local environment")
        # If not in Colab, just create and launch
        create_and_launch()