ashu316 commited on
Commit
52ca193
·
verified ·
1 Parent(s): 4e02570

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -22
app.py CHANGED
@@ -112,11 +112,11 @@ def get_random_negative_node(source_node, destination_node, total_nodes):
112
 
113
 
114
  def predict(u_id, i_id, timestamp):
115
- # Convert inputs to batch format
116
- #u_tensor = torch.tensor([int(u_id)], dtype=torch.long) # Single source node, batch size 1
117
- #i_tensor = torch.tensor([int(i_id)], dtype=torch.long) # Single destination node, batch size 1
118
- #ts_tensor = torch.LongTensor([int(timestamp)])
119
 
 
120
  u_array = np.array([int(u_id)]) # List of source nodes
121
  i_array = np.array([int(i_id)]) # List of destination nodes
122
  ts_array = np.array([int(timestamp)]) # List of timestamps
@@ -144,14 +144,8 @@ def predict(u_id, i_id, timestamp):
144
  # Negative node sampling: choose random negative nodes for this example (should be handled in your dataset)
145
  total_nodes = len(full_data.unique_nodes) # or whatever gives your total node count
146
  random_negative_node = get_random_negative_node(u_id, i_id, total_nodes)
147
- #negative_nodes_tensor = torch.tensor([random_negative_node], dtype=torch.long)
148
  negative_nodes_array = np.array([random_negative_node]) # List of negative nodes
149
 
150
- # Make sure you batch them together properly (even if it is just a single edge for now)
151
- #positive_probs, negative_probs = tgn.compute_edge_probabilities(
152
- # u_array, i_array, negative_nodes_array, ts_array, edge_idx_array
153
- #)
154
-
155
  # Call compute_edge_probabilities
156
  # You can pass edge_idxs or dummy edge features here depending on how your TGN expects it
157
  try:
@@ -183,18 +177,6 @@ def predict(u_id, i_id, timestamp):
183
  return f"Top {len(top_values)} predicted interaction probabilities:\n{formatted_result}"
184
 
185
 
186
- #demo = gr.Interface(
187
- # fn=predict,
188
- # inputs=[
189
- # gr.Number(label="Source Node ID"),
190
- # gr.Number(label="Destination Node ID"),
191
- # gr.Number(label="Timestamp"),
192
- # ],
193
- # outputs="text",
194
- # title="🧠 TGN Playground (Wikipedia)",
195
- # description="Enter node IDs and timestamp to predict future interaction probability using Temporal Graph Networks (TGN).",
196
- #)
197
-
198
  with gr.Blocks() as demo:
199
  gr.Markdown("## 🧠 TGN Playground (Wikipedia)")
200
  gr.Markdown("Enter node IDs and timestamp to predict future interaction probability using Temporal Graph Networks (TGN).")
 
112
 
113
 
114
  def predict(u_id, i_id, timestamp):
115
+ # Before prediction
116
+ tgn.memory.__init_memory__() # Re-initialize memory
117
+ tgn.embeddings.__init_memory__()
 
118
 
119
+ # Then run prediction
120
  u_array = np.array([int(u_id)]) # List of source nodes
121
  i_array = np.array([int(i_id)]) # List of destination nodes
122
  ts_array = np.array([int(timestamp)]) # List of timestamps
 
144
  # Negative node sampling: choose random negative nodes for this example (should be handled in your dataset)
145
  total_nodes = len(full_data.unique_nodes) # or whatever gives your total node count
146
  random_negative_node = get_random_negative_node(u_id, i_id, total_nodes)
 
147
  negative_nodes_array = np.array([random_negative_node]) # List of negative nodes
148
 
 
 
 
 
 
149
  # Call compute_edge_probabilities
150
  # You can pass edge_idxs or dummy edge features here depending on how your TGN expects it
151
  try:
 
177
  return f"Top {len(top_values)} predicted interaction probabilities:\n{formatted_result}"
178
 
179
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  with gr.Blocks() as demo:
181
  gr.Markdown("## 🧠 TGN Playground (Wikipedia)")
182
  gr.Markdown("Enter node IDs and timestamp to predict future interaction probability using Temporal Graph Networks (TGN).")