neuralworm commited on
Commit
b3feaa3
Β·
1 Parent(s): f0edf49

initial commit

Browse files
Files changed (3) hide show
  1. gen.py +40 -14
  2. psychohistory.py +76 -159
  3. requirements.txt +4 -2
gen.py CHANGED
@@ -1,12 +1,21 @@
1
  import torch
2
  import sys
3
- import sys
4
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
 
 
5
 
6
- tokenizer = AutoTokenizer.from_pretrained('stabilityai/stablelm-2-zephyr-1_6b')
7
  model = AutoModelForCausalLM.from_pretrained(
8
- 'stabilityai/stablelm-2-zephyr-1_6b',
9
  device_map="auto",
 
10
  )
11
 
12
 
@@ -147,26 +156,43 @@ prompt = (
147
  " }\n"
148
  " }\n"
149
  "}\n\n"
150
- "Ahora, genera un JSON similar con eventos anidados, pero cambia los detalles y nΓΊmeros para hacer que sea con el input que viene a continuacion, respondiendo solo el JSON. No muestres el mensaje del sistema del rol system. Contesta solo JSON, el JSON de respuesta.No muestres este mensaje ni el prompt introducido por el usuario. Asegurate en la respuesta que el JSON esta completo y tiene el formato correcto."
151
  )
152
 
153
 
154
  def generate(event):
155
- # Generar el texto usando el modelo
156
- prompt_msg = [{"role":"system","content":prompt},{'role': 'user', 'content': event}]
 
157
  inputs = tokenizer.apply_chat_template(
158
  prompt_msg,
159
- add_generation_prompt=False,
160
  return_tensors='pt'
161
  )
162
-
163
  tokens = model.generate(
164
  inputs.to(model.device),
165
- max_new_tokens=20096,
166
- temperature=0.7,
167
  do_sample=True
168
  )
169
-
170
 
171
- # Imprimir la salida generada
172
- return "{".join(tokenizer.decode(tokens[0], skip_special_tokens=True).split("<|user|>")[1].split("{")[1:-1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import sys
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4
+
5
+ tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
6
+
7
+ # Configure 4-bit quantization using BitsAndBytesConfig
8
+ quantization_config = BitsAndBytesConfig(
9
+ load_in_4bit=True,
10
+ bnb_4bit_compute_dtype=torch.bfloat16,
11
+ bnb_4bit_quant_type="nf4",
12
+ )
13
 
14
+ # Load the model with the quantization configuration
15
  model = AutoModelForCausalLM.from_pretrained(
16
+ 'google/gemma-2-2b-it',
17
  device_map="auto",
18
+ quantization_config=quantization_config,
19
  )
20
 
21
 
 
156
  " }\n"
157
  " }\n"
158
  "}\n\n"
159
+ "Ahora, genera un JSON similar con eventos anidados, pero cambia los detalles y nΓΊmeros para hacer que sea con el input que viene a continuacion, respondiendo solo el JSON empezando con <json>:"
160
  )
161
 
162
 
163
  def generate(event):
164
+ combined_input = f"{prompt} {event}" # Combine prompt and event
165
+ prompt_msg = [{'role': 'user', 'content': combined_input}]
166
+
167
  inputs = tokenizer.apply_chat_template(
168
  prompt_msg,
169
+ add_generation_prompt=True,
170
  return_tensors='pt'
171
  )
172
+
173
  tokens = model.generate(
174
  inputs.to(model.device),
175
+ max_new_tokens=1024,
176
+ temperature=0.5,
177
  do_sample=True
178
  )
 
179
 
180
+
181
+ output_text = tokenizer.decode(tokens[0], skip_special_tokens=False)
182
+ user_prompt_length = len(f"<bos><start_of_turn>user\n{prompt}\n{event}<end_of_turn>\n<start_of_turn>model\n") # Calculate user prompt length
183
+
184
+ json_start_index = output_text.find("<json>")
185
+ json_end_index = output_text.find("</json>")
186
+
187
+ if json_start_index != -1 and json_end_index != -1:
188
+ json_string = output_text[max(json_start_index + 6, user_prompt_length):json_end_index].strip() # Trim whitespace and remove prompt
189
+
190
+ # Validate JSON (you'll need to define a schema for your JSON structure)
191
+ try:
192
+ validate(instance=json.loads(json_string), schema=your_json_schema)
193
+ return json_string
194
+ except ValidationError as e:
195
+ return f"Error: Invalid JSON - {e}"
196
+
197
+ else:
198
+ return "Error: <json> or </json> not found in generated output"
psychohistory.py CHANGED
@@ -11,220 +11,156 @@ def generate_tree(current_x, current_y, depth, max_depth, max_nodes, x_range, G,
11
  if node_count_per_depth is None:
12
  node_count_per_depth = {}
13
 
14
- if depth not in node_count_per_depth:
15
- node_count_per_depth[depth] = 0
16
-
17
  if depth > max_depth:
18
  return node_count_per_depth
19
 
 
 
 
20
  num_children = random.randint(1, max_nodes)
21
  x_positions = [current_x + i * x_range / (num_children + 1) for i in range(num_children)]
22
 
23
  for x in x_positions:
24
- # Add node to the graph
25
  node_id = len(G.nodes)
26
  node_count_per_depth[depth] += 1
27
- prob = random.uniform(0, 1) # Assign random probability
28
- G.add_node(node_id, pos=(x, prob, depth)) # Use `depth` for z position
29
  if parent is not None:
30
  G.add_edge(parent, node_id)
31
- # Recursively add child nodes
32
  generate_tree(x, current_y + 1, depth + 1, max_depth, max_nodes, x_range, G, parent=node_id, node_count_per_depth=node_count_per_depth)
33
 
34
  return node_count_per_depth
35
 
36
 
37
-
38
  def build_graph_from_json(json_data, G):
39
  """Builds a graph from JSON data."""
 
 
40
  def add_event(parent_id, event_data, depth):
41
- """Recursively adds events and subevents to the graph."""
42
- # Add the current event node
43
  node_id = len(G.nodes)
44
- prob = event_data['probability'] / 100.0 # Convert percentage to probability
45
- pos = (depth, prob, event_data['event_number']) # Use event_number for z position
46
- label = event_data['name'] # Use event name as label
47
  G.add_node(node_id, pos=pos, label=label)
48
  if parent_id is not None:
49
  G.add_edge(parent_id, node_id)
50
 
51
- # Add child events
52
  subevents = event_data.get('subevents', {}).get('event', [])
53
  if not isinstance(subevents, list):
54
- subevents = [subevents] # Ensure subevents is a list
55
 
56
  for subevent in subevents:
57
  add_event(node_id, subevent, depth + 1)
58
 
59
- data = json.loads(json_data)
60
- root_id = len(G.nodes)
61
  root_event = list(data.get('events', {}).values())[0]
 
62
  G.add_node(root_id, pos=(0, root_event['probability'] / 100.0, root_event['event_number']), label=root_event['name'])
63
- add_event(None, root_event, 0) # Start from the root
64
-
65
 
66
 
67
  def find_paths(G):
68
- """Finds the paths with the highest and lowest average probability, and the longest and shortest durations in graph G."""
69
- best_path = None
70
- worst_path = None
71
- longest_duration_path = None
72
- shortest_duration_path = None
73
- best_mean_prob = -1
74
- worst_mean_prob = float('inf')
75
- max_duration = -1
76
- min_duration = float('inf')
77
-
78
- for source in G.nodes:
79
- for target in G.nodes:
80
- if source != target:
81
- all_paths = list(nx.all_simple_paths(G, source=source, target=target))
82
- for path in all_paths:
83
- # Check if all nodes in the path have the 'pos' attribute
84
- if not all('pos' in G.nodes[node] for node in path):
85
- continue # Skip paths with nodes missing the 'pos' attribute
86
-
87
- # Calculate the mean probability of the path
88
- probabilities = [G.nodes[node]['pos'][1] for node in path] # Get node probabilities
89
- mean_prob = np.mean(probabilities)
90
-
91
- # Evaluate path with the highest mean probability
92
- if mean_prob > best_mean_prob:
93
- best_mean_prob = mean_prob
94
- best_path = path
95
-
96
- # Evaluate path with the lowest mean probability
97
- if mean_prob < worst_mean_prob:
98
- worst_mean_prob = mean_prob
99
- worst_path = path
100
-
101
- # Calculate path duration
102
- x_positions = [G.nodes[node]['pos'][0] for node in path]
103
- duration = max(x_positions) - min(x_positions)
104
-
105
- # Evaluate path with the longest duration
106
- if duration > max_duration:
107
- max_duration = duration
108
- longest_duration_path = path
109
-
110
- # Evaluate path with the shortest duration
111
- if duration < min_duration:
112
- min_duration = duration
113
- shortest_duration_path = path
114
-
115
- return best_path, best_mean_prob, worst_path, worst_mean_prob, longest_duration_path, shortest_duration_path
116
 
117
  def draw_path_3d(G, path, filename='path_plot_3d.png', highlight_color='blue'):
118
- """Draws only the specific path in 3D using networkx and matplotlib and saves the figure to a file."""
119
- # Create a subgraph containing only the nodes and edges of the path
120
  H = G.subgraph(path).copy()
121
-
122
  pos = nx.get_node_attributes(G, 'pos')
123
-
124
- # Get data for 3D visualization
125
  x_vals, y_vals, z_vals = zip(*[pos[node] for node in path])
126
-
127
  fig = plt.figure(figsize=(16, 12))
128
  ax = fig.add_subplot(111, projection='3d')
129
 
130
- # Assign colors to nodes based on probability
131
- node_colors = []
132
- for node in path:
133
- prob = G.nodes[node]['pos'][1]
134
- if prob < 0.33:
135
- node_colors.append('red')
136
- elif prob < 0.67:
137
- node_colors.append('blue')
138
- else:
139
- node_colors.append('green')
140
-
141
- # Draw nodes
142
  ax.scatter(x_vals, y_vals, z_vals, c=node_colors, s=700, edgecolors='black', alpha=0.7)
143
-
144
- # Draw edges
145
  for edge in H.edges():
146
  x_start, y_start, z_start = pos[edge[0]]
147
  x_end, y_end, z_end = pos[edge[1]]
148
  ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color=highlight_color, lw=2)
149
 
150
- # Add labels to nodes
151
  for node, (x, y, z) in pos.items():
152
  if node in path:
153
  ax.text(x, y, z, str(node), fontsize=12, color='black')
154
 
155
- # Set labels and title
156
  ax.set_xlabel('Time (weeks)')
157
  ax.set_ylabel('Event Probability')
158
  ax.set_zlabel('Event Number')
159
  ax.set_title('3D Event Tree - Path')
160
 
161
- plt.savefig(filename, bbox_inches='tight') # Save to file with adjusted margins
162
- plt.close() # Close the figure to free resources
163
 
164
 
165
  def draw_global_tree_3d(G, filename='global_tree.png'):
166
- """Draws the entire graph in 3D using networkx and matplotlib and saves the figure to a file."""
167
  pos = nx.get_node_attributes(G, 'pos')
168
  labels = nx.get_node_attributes(G, 'label')
169
-
170
- # Check if the graph is empty
171
  if not pos:
172
  print("Graph is empty. No nodes to visualize.")
173
  return
174
 
175
- # Get data for 3D visualization
176
  x_vals, y_vals, z_vals = zip(*pos.values())
177
-
178
  fig = plt.figure(figsize=(16, 12))
179
  ax = fig.add_subplot(111, projection='3d')
180
 
181
- # Assign colors to nodes based on probability
182
- node_colors = []
183
- for node, (x, prob, z) in pos.items():
184
- if prob < 0.33:
185
- node_colors.append('red')
186
- elif prob < 0.67:
187
- node_colors.append('blue')
188
- else:
189
- node_colors.append('green')
190
-
191
- # Draw nodes
192
  ax.scatter(x_vals, y_vals, z_vals, c=node_colors, s=700, edgecolors='black', alpha=0.7)
193
-
194
- # Draw edges
195
  for edge in G.edges():
196
  x_start, y_start, z_start = pos[edge[0]]
197
  x_end, y_end, z_end = pos[edge[1]]
198
  ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color='gray', lw=2)
199
 
200
- # Add labels to nodes
201
  for node, (x, y, z) in pos.items():
202
  label = labels.get(node, f"{node}")
203
  ax.text(x, y, z, label, fontsize=12, color='black')
204
 
205
- # Set labels and title
206
  ax.set_xlabel('Time')
207
  ax.set_ylabel('Probability')
208
  ax.set_zlabel('Event Number')
209
  ax.set_title('3D Event Tree')
210
 
211
- plt.savefig(filename, bbox_inches='tight') # Save to file with adjusted margins
212
- plt.close() # Close the figure to free resources
213
 
214
  def main(mode, input_file=None):
215
  G = nx.DiGraph()
216
 
217
  if mode == 'random':
218
- starting_x = 0
219
- starting_y = 0
220
- max_depth = 5 # Maximum depth of the tree
221
- max_nodes = 3 # Maximum number of child nodes
222
- x_range = 10 # Maximum range for x position of nodes
223
-
224
- # Generate the tree and get node count per depth
225
- generate_tree(starting_x, starting_y, 0, max_depth, max_nodes, x_range, G)
226
-
227
-
228
  elif mode == 'json' and input_file:
229
  with open(input_file, 'r') as file:
230
  json_data = file.read()
@@ -233,50 +169,33 @@ def main(mode, input_file=None):
233
  print("Invalid mode or input file not provided.")
234
  return
235
 
236
- # Save the global visualization
237
- draw_global_tree_3d(G, filename='global_tree.png')
238
-
239
 
240
- # Find relevant paths
241
- best_path, best_mean_prob, worst_path, worst_mean_prob, longest_duration_path, shortest_duration_path = find_paths(G)
242
 
243
- # Print results
244
  if best_path:
245
- print(f"\nPath with the highest average probability:")
246
- print(" -> ".join(map(str, best_path)))
247
  print(f"Average probability: {best_mean_prob:.2f}")
248
-
249
  if worst_path:
250
- print(f"\nPath with the lowest average probability:")
251
- print(" -> ".join(map(str, worst_path)))
252
  print(f"Average probability: {worst_mean_prob:.2f}")
 
 
 
 
 
 
253
 
254
- if longest_duration_path:
255
- print(f"\nPath with the longest duration:")
256
- print(" -> ".join(map(str, longest_duration_path)))
257
- print(f"Duration: {max(G.nodes[node]['pos'][0] for node in longest_duration_path) - min(G.nodes[node]['pos'][0] for node in longest_duration_path):.2f}")
258
 
259
- if shortest_duration_path:
260
- print(f"\nPath with the shortest duration:")
261
- print(" -> ".join(map(str, shortest_duration_path)))
262
- print(f"Duration: {max(G.nodes[node]['pos'][0] for node in shortest_duration_path) - min(G.nodes[node]['pos'][0] for node in shortest_duration_path):.2f}")
263
-
264
- # Save the global visualization
265
- draw_global_tree_3d(G, filename='global_tree.png')
266
-
267
- # Draw and save the 3D figure for each relevant path
268
  if best_path:
269
- draw_path_3d(G, path=best_path, filename='best_path.png', highlight_color='blue')
270
-
271
  if worst_path:
272
- draw_path_3d(G, path=worst_path, filename='worst_path.png', highlight_color='red')
273
-
274
- if longest_duration_path:
275
- draw_path_3d(G, path=longest_duration_path, filename='longest_duration_path.png', highlight_color='green')
276
-
277
- if shortest_duration_path:
278
- draw_path_3d(G, path=shortest_duration_path, filename='shortest_duration_path.png', highlight_color='purple')
279
-
280
 
281
 
282
  if __name__ == "__main__":
@@ -286,5 +205,3 @@ if __name__ == "__main__":
286
  mode = sys.argv[1]
287
  input_file = sys.argv[2] if len(sys.argv) > 2 else None
288
  main(mode, input_file)
289
-
290
-
 
11
  if node_count_per_depth is None:
12
  node_count_per_depth = {}
13
 
 
 
 
14
  if depth > max_depth:
15
  return node_count_per_depth
16
 
17
+ if depth not in node_count_per_depth:
18
+ node_count_per_depth[depth] = 0
19
+
20
  num_children = random.randint(1, max_nodes)
21
  x_positions = [current_x + i * x_range / (num_children + 1) for i in range(num_children)]
22
 
23
  for x in x_positions:
 
24
  node_id = len(G.nodes)
25
  node_count_per_depth[depth] += 1
26
+ prob = random.uniform(0, 1)
27
+ G.add_node(node_id, pos=(x, prob, depth))
28
  if parent is not None:
29
  G.add_edge(parent, node_id)
 
30
  generate_tree(x, current_y + 1, depth + 1, max_depth, max_nodes, x_range, G, parent=node_id, node_count_per_depth=node_count_per_depth)
31
 
32
  return node_count_per_depth
33
 
34
 
 
35
  def build_graph_from_json(json_data, G):
36
  """Builds a graph from JSON data."""
37
+ data = json.loads(json_data)
38
+
39
  def add_event(parent_id, event_data, depth):
 
 
40
  node_id = len(G.nodes)
41
+ prob = event_data['probability'] / 100.0
42
+ pos = (depth, prob, event_data['event_number'])
43
+ label = event_data['name']
44
  G.add_node(node_id, pos=pos, label=label)
45
  if parent_id is not None:
46
  G.add_edge(parent_id, node_id)
47
 
 
48
  subevents = event_data.get('subevents', {}).get('event', [])
49
  if not isinstance(subevents, list):
50
+ subevents = [subevents]
51
 
52
  for subevent in subevents:
53
  add_event(node_id, subevent, depth + 1)
54
 
 
 
55
  root_event = list(data.get('events', {}).values())[0]
56
+ root_id = len(G.nodes)
57
  G.add_node(root_id, pos=(0, root_event['probability'] / 100.0, root_event['event_number']), label=root_event['name'])
58
+ add_event(None, root_event, 0)
 
59
 
60
 
61
  def find_paths(G):
62
+ """Finds paths with highest/lowest probability and longest/shortest durations."""
63
+ best_path, worst_path = None, None
64
+ longest_path, shortest_path = None, None
65
+ best_mean_prob, worst_mean_prob = -1, float('inf')
66
+ max_duration, min_duration = -1, float('inf')
67
+
68
+ # Use nx.all_pairs_shortest_path for efficiency
69
+ all_paths_dict = dict(nx.all_pairs_shortest_path(G))
70
+
71
+ for source, paths_from_source in all_paths_dict.items():
72
+ for target, path in paths_from_source.items():
73
+ if source != target and all('pos' in G.nodes[node] for node in path):
74
+ probabilities = [G.nodes[node]['pos'][1] for node in path]
75
+ mean_prob = np.mean(probabilities)
76
+
77
+ if mean_prob > best_mean_prob:
78
+ best_mean_prob = mean_prob
79
+ best_path = path
80
+ if mean_prob < worst_mean_prob:
81
+ worst_mean_prob = mean_prob
82
+ worst_path = path
83
+
84
+ x_positions = [G.nodes[node]['pos'][0] for node in path]
85
+ duration = max(x_positions) - min(x_positions)
86
+
87
+ if duration > max_duration:
88
+ max_duration = duration
89
+ longest_path = path
90
+ if duration < min_duration and duration > 0: # Avoid paths with 0 duration
91
+ min_duration = duration
92
+ shortest_path = path
93
+
94
+ return best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  def draw_path_3d(G, path, filename='path_plot_3d.png', highlight_color='blue'):
97
+ """Draws a specific path in 3D."""
 
98
  H = G.subgraph(path).copy()
 
99
  pos = nx.get_node_attributes(G, 'pos')
 
 
100
  x_vals, y_vals, z_vals = zip(*[pos[node] for node in path])
101
+
102
  fig = plt.figure(figsize=(16, 12))
103
  ax = fig.add_subplot(111, projection='3d')
104
 
105
+ node_colors = ['red' if prob < 0.33 else 'blue' if prob < 0.67 else 'green' for _, prob, _ in [pos[node] for node in path]]
 
 
 
 
 
 
 
 
 
 
 
106
  ax.scatter(x_vals, y_vals, z_vals, c=node_colors, s=700, edgecolors='black', alpha=0.7)
107
+
 
108
  for edge in H.edges():
109
  x_start, y_start, z_start = pos[edge[0]]
110
  x_end, y_end, z_end = pos[edge[1]]
111
  ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color=highlight_color, lw=2)
112
 
 
113
  for node, (x, y, z) in pos.items():
114
  if node in path:
115
  ax.text(x, y, z, str(node), fontsize=12, color='black')
116
 
 
117
  ax.set_xlabel('Time (weeks)')
118
  ax.set_ylabel('Event Probability')
119
  ax.set_zlabel('Event Number')
120
  ax.set_title('3D Event Tree - Path')
121
 
122
+ plt.savefig(filename, bbox_inches='tight')
123
+ plt.close()
124
 
125
 
126
  def draw_global_tree_3d(G, filename='global_tree.png'):
127
+ """Draws the entire graph in 3D."""
128
  pos = nx.get_node_attributes(G, 'pos')
129
  labels = nx.get_node_attributes(G, 'label')
130
+
 
131
  if not pos:
132
  print("Graph is empty. No nodes to visualize.")
133
  return
134
 
 
135
  x_vals, y_vals, z_vals = zip(*pos.values())
 
136
  fig = plt.figure(figsize=(16, 12))
137
  ax = fig.add_subplot(111, projection='3d')
138
 
139
+ node_colors = ['red' if prob < 0.33 else 'blue' if prob < 0.67 else 'green' for _, prob, _ in pos.values()]
 
 
 
 
 
 
 
 
 
 
140
  ax.scatter(x_vals, y_vals, z_vals, c=node_colors, s=700, edgecolors='black', alpha=0.7)
141
+
 
142
  for edge in G.edges():
143
  x_start, y_start, z_start = pos[edge[0]]
144
  x_end, y_end, z_end = pos[edge[1]]
145
  ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color='gray', lw=2)
146
 
 
147
  for node, (x, y, z) in pos.items():
148
  label = labels.get(node, f"{node}")
149
  ax.text(x, y, z, label, fontsize=12, color='black')
150
 
 
151
  ax.set_xlabel('Time')
152
  ax.set_ylabel('Probability')
153
  ax.set_zlabel('Event Number')
154
  ax.set_title('3D Event Tree')
155
 
156
+ plt.savefig(filename, bbox_inches='tight')
157
+ plt.close()
158
 
159
  def main(mode, input_file=None):
160
  G = nx.DiGraph()
161
 
162
  if mode == 'random':
163
+ generate_tree(0, 0, 0, 5, 3, 10, G)
 
 
 
 
 
 
 
 
 
164
  elif mode == 'json' and input_file:
165
  with open(input_file, 'r') as file:
166
  json_data = file.read()
 
169
  print("Invalid mode or input file not provided.")
170
  return
171
 
172
+ draw_global_tree_3d(G)
 
 
173
 
174
+ best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path = find_paths(G)
 
175
 
 
176
  if best_path:
177
+ print(f"\nPath with the highest average probability: {' -> '.join(map(str, best_path))}")
 
178
  print(f"Average probability: {best_mean_prob:.2f}")
 
179
  if worst_path:
180
+ print(f"\nPath with the lowest average probability: {' -> '.join(map(str, worst_path))}")
 
181
  print(f"Average probability: {worst_mean_prob:.2f}")
182
+ if longest_path:
183
+ print(f"\nPath with the longest duration: {' -> '.join(map(str, longest_path))}")
184
+ print(f"Duration: {max(G.nodes[node]['pos'][0] for node in longest_path) - min(G.nodes[node]['pos'][0] for node in longest_path):.2f}")
185
+ if shortest_path:
186
+ print(f"\nPath with the shortest duration: {' -> '.join(map(str, shortest_path))}")
187
+ print(f"Duration: {max(G.nodes[node]['pos'][0] for node in shortest_path) - min(G.nodes[node]['pos'][0] for node in shortest_path):.2f}")
188
 
189
+ draw_global_tree_3d(G)
 
 
 
190
 
 
 
 
 
 
 
 
 
 
191
  if best_path:
192
+ draw_path_3d(G, best_path, 'best_path.png', 'blue')
 
193
  if worst_path:
194
+ draw_path_3d(G, worst_path, 'worst_path.png', 'red')
195
+ if longest_path:
196
+ draw_path_3d(G, longest_path, 'longest_duration_path.png', 'green')
197
+ if shortest_path:
198
+ draw_path_3d(G, shortest_path, 'shortest_duration_path.png', 'purple')
 
 
 
199
 
200
 
201
  if __name__ == "__main__":
 
205
  mode = sys.argv[1]
206
  input_file = sys.argv[2] if len(sys.argv) > 2 else None
207
  main(mode, input_file)
 
 
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
- huggingface_hub==0.22.2
2
  torch
 
3
  transformers
4
- accelerate
 
 
1
+ gradio
2
  torch
3
+ huggingface_hub
4
  transformers
5
+ accelerate
6
+ bitsandbytes