BioGeek commited on
Commit
44b2355
·
1 Parent(s): e649e86

feat: add citation

Browse files
Files changed (1) hide show
  1. app.py +246 -133
app.py CHANGED
@@ -26,13 +26,15 @@ except ImportError as e:
26
  raise ImportError("Failed to import InstaNovo components: {e}")
27
 
28
  # --- Configuration ---
29
- MODEL_ID = "instanovo-v1.1.0" # Use the desired pretrained model ID
30
  KNAPSACK_DIR = Path("./knapsack_cache")
31
- DEFAULT_CONFIG_PATH = Path("./configs/inference/default.yaml") # Assuming instanovo installs configs locally relative to execution
 
 
32
 
33
  # Determine device
34
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
35
- FP16 = DEVICE == "cuda" # Enable FP16 only on CUDA
36
 
37
  # --- Global Variables (Load Model and Knapsack Once) ---
38
  MODEL: InstaNovo | None = None
@@ -41,7 +43,7 @@ MODEL_CONFIG: DictConfig | None = None
41
  RESIDUE_SET: ResidueSet | None = None
42
 
43
  # Assets
44
- gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
45
 
46
 
47
  def load_model_and_knapsack():
@@ -64,9 +66,9 @@ def load_model_and_knapsack():
64
 
65
  # --- Knapsack Handling ---
66
  knapsack_exists = (
67
- (KNAPSACK_DIR / "parameters.pkl").exists() and
68
- (KNAPSACK_DIR / "masses.npy").exists() and
69
- (KNAPSACK_DIR / "chart.npy").exists()
70
  )
71
 
72
  if knapsack_exists:
@@ -76,51 +78,61 @@ def load_model_and_knapsack():
76
  print("Knapsack loaded successfully.")
77
  except Exception as e:
78
  print(f"Error loading knapsack: {e}. Will attempt to regenerate.")
79
- KNAPSACK = None # Force regeneration
80
- knapsack_exists = False # Ensure generation happens
81
 
82
  if not knapsack_exists:
83
  print("Knapsack not found or failed to load. Generating knapsack...")
84
  if RESIDUE_SET is None:
85
- raise gr.Error("Cannot generate knapsack because ResidueSet failed to load.")
 
 
86
  try:
87
  # Prepare residue masses for knapsack generation (handle negative/zero masses)
88
  residue_masses_knapsack = dict(RESIDUE_SET.residue_masses.copy())
89
- negative_residues = [k for k, v in residue_masses_knapsack.items() if v <= 0]
 
 
90
  if negative_residues:
91
- print(f"Warning: Non-positive masses found in residues: {negative_residues}. "
92
- "Excluding from knapsack generation.")
 
 
93
  for res in negative_residues:
94
  del residue_masses_knapsack[res]
95
  # Remove special tokens explicitly if they somehow got mass
96
  for special_token in RESIDUE_SET.special_tokens:
97
- if special_token in residue_masses_knapsack:
98
- del residue_masses_knapsack[special_token]
99
 
100
  # Ensure residue indices used match those without special/negative masses
101
  valid_residue_indices = {
102
- res: idx for res, idx in RESIDUE_SET.residue_to_index.items()
 
103
  if res in residue_masses_knapsack
104
  }
105
 
106
-
107
  KNAPSACK = Knapsack.construct_knapsack(
108
  residue_masses=residue_masses_knapsack,
109
- residue_indices=valid_residue_indices, # Use only valid indices
110
  max_mass=MAX_MASS,
111
  mass_scale=MASS_SCALE,
112
  )
113
  print(f"Knapsack generated. Saving to {KNAPSACK_DIR}...")
114
- KNAPSACK.save(str(KNAPSACK_DIR)) # Save for future runs
115
  print("Knapsack saved.")
116
  except Exception as e:
117
  print(f"Error generating or saving knapsack: {e}")
118
- gr.Warning("Failed to generate Knapsack. Knapsack Beam Search will not be available. {e}")
119
- KNAPSACK = None # Ensure it's None if generation failed
 
 
 
120
 
121
  # Load the model and knapsack when the script starts
122
  load_model_and_knapsack()
123
 
 
124
  def create_inference_config(
125
  input_path: str,
126
  output_path: str,
@@ -129,53 +141,72 @@ def create_inference_config(
129
  """Creates the OmegaConf DictConfig needed for prediction."""
130
  # Load default config if available, otherwise create from scratch
131
  if DEFAULT_CONFIG_PATH.exists():
132
- base_cfg = OmegaConf.load(DEFAULT_CONFIG_PATH)
133
  else:
134
- print(f"Warning: Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config.")
135
- # Create a minimal config if default is missing
136
- base_cfg = OmegaConf.create({
137
- "data_path": None,
138
- "instanovo_model": MODEL_ID,
139
- "output_path": None,
140
- "knapsack_path": str(KNAPSACK_DIR),
141
- "denovo": True,
142
- "refine": False, # Not doing refinement here
143
- "num_beams": 1,
144
- "max_length": 40,
145
- "max_charge": 10,
146
- "isotope_error_range": [0, 1],
147
- "subset": 1.0,
148
- "use_knapsack": False,
149
- "save_beams": False,
150
- "batch_size": 64, # Adjust as needed
151
- "device": DEVICE,
152
- "fp16": FP16,
153
- "log_interval": 500, # Less relevant for Gradio app
154
- "use_basic_logging": True,
155
- "filter_precursor_ppm": 20,
156
- "filter_confidence": 1e-4,
157
- "filter_fdr_threshold": 0.05,
158
- "residue_remapping": { # Add default mappings
159
- "M(ox)": "M[UNIMOD:35]", "M(+15.99)": "M[UNIMOD:35]",
160
- "S(p)": "S[UNIMOD:21]", "T(p)": "T[UNIMOD:21]", "Y(p)": "Y[UNIMOD:21]",
161
- "S(+79.97)": "S[UNIMOD:21]", "T(+79.97)": "T[UNIMOD:21]", "Y(+79.97)": "Y[UNIMOD:21]",
162
- "Q(+0.98)": "Q[UNIMOD:7]", "N(+0.98)": "N[UNIMOD:7]",
163
- "Q(+.98)": "Q[UNIMOD:7]", "N(+.98)": "N[UNIMOD:7]",
164
- "C(+57.02)": "C[UNIMOD:4]",
165
- "(+42.01)": "[UNIMOD:1]", "(+43.01)": "[UNIMOD:5]", "(-17.03)": "[UNIMOD:385]",
166
- },
167
- "column_map": { # Add default mappings
168
- "Modified sequence": "modified_sequence", "MS/MS m/z": "precursor_mz",
169
- "Mass": "precursor_mass", "Charge": "precursor_charge",
170
- "Mass values": "mz_array", "Mass spectrum": "mz_array",
171
- "Intensity": "intensity_array", "Raw intensity spectrum": "intensity_array",
172
- "Scan number": "scan_number"
173
- },
174
- "index_columns": [
175
- "scan_number", "precursor_mz", "precursor_charge",
176
- ],
177
- # Add other defaults if needed based on errors
178
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  # Override specific parameters
181
  cfg_overrides = {
@@ -192,7 +223,9 @@ def create_inference_config(
192
  cfg_overrides["use_knapsack"] = False
193
  elif "Knapsack" in decoding_method:
194
  if KNAPSACK is None:
195
- raise gr.Error("Knapsack is not available. Cannot use Knapsack Beam Search.")
 
 
196
  cfg_overrides["num_beams"] = 5
197
  cfg_overrides["use_knapsack"] = True
198
  cfg_overrides["knapsack_path"] = str(KNAPSACK_DIR)
@@ -209,14 +242,14 @@ def predict_peptides(input_file, decoding_method):
209
  Main function to load data, run prediction, and return results.
210
  """
211
  if MODEL is None or RESIDUE_SET is None or MODEL_CONFIG is None:
212
- load_model_and_knapsack() # Attempt to reload if None (e.g., after space restart)
213
- if MODEL is None:
214
- raise gr.Error("InstaNovo model is not loaded. Cannot perform prediction.")
215
 
216
  if input_file is None:
217
  raise gr.Error("Please upload a mass spectrometry file.")
218
 
219
- input_path = input_file.name # Gradio provides the path in .name
220
  print(f"Processing file: {input_path}")
221
  print(f"Using decoding method: {decoding_method}")
222
 
@@ -234,23 +267,28 @@ def predict_peptides(input_file, decoding_method):
234
  try:
235
  sdf = SpectrumDataFrame.load(
236
  config.data_path,
237
- lazy=False, # Load eagerly for Gradio simplicity
238
- is_annotated=False, # De novo mode
239
  column_mapping=config.get("column_map", None),
240
  shuffle=False,
241
- verbose=True # Print loading logs
242
  )
243
  # Apply charge filter like in CLI
244
  original_size = len(sdf)
245
  max_charge = config.get("max_charge", 10)
246
  sdf.filter_rows(
247
- lambda row: (row["precursor_charge"] <= max_charge) and (row["precursor_charge"] > 0)
 
248
  )
249
  if len(sdf) < original_size:
250
- print(f"Warning: Filtered {original_size - len(sdf)} spectra with charge > {max_charge} or <= 0.")
 
 
251
 
252
  if len(sdf) == 0:
253
- raise gr.Error("No valid spectra found in the uploaded file after filtering.")
 
 
254
  print(f"Data loaded: {len(sdf)} spectra.")
255
  except Exception as e:
256
  print(f"Error loading data: {e}")
@@ -261,16 +299,17 @@ def predict_peptides(input_file, decoding_method):
261
  sdf,
262
  RESIDUE_SET,
263
  MODEL_CONFIG.get("n_peaks", 200),
264
- return_str=True, # Needed for greedy/beam search targets later (though not used here)
265
  annotated=False,
266
- pad_spectrum_max_length=config.get("compile_model", False) or config.get("use_flash_attention", False),
 
267
  bin_spectra=config.get("conv_peak_encoder", False),
268
  )
269
  dl = DataLoader(
270
  ds,
271
  batch_size=config.batch_size,
272
- num_workers=0, # Required by SpectrumDataFrame
273
- shuffle=False, # Required by SpectrumDataFrame
274
  collate_fn=collate_batch,
275
  )
276
 
@@ -279,38 +318,51 @@ def predict_peptides(input_file, decoding_method):
279
  decoder: Decoder
280
  if config.use_knapsack:
281
  if KNAPSACK is None:
282
- # This check should ideally be earlier, but double-check
283
- raise gr.Error("Knapsack is required for Knapsack Beam Search but is not available.")
 
 
284
  # KnapsackBeamSearchDecoder doesn't directly load from path in this version?
285
  # We load Knapsack globally, so just pass it.
286
  # If it needed path: decoder = KnapsackBeamSearchDecoder.from_file(model=MODEL, path=config.knapsack_path)
287
  decoder = KnapsackBeamSearchDecoder(model=MODEL, knapsack=KNAPSACK)
288
  elif config.num_beams > 1:
289
- # BeamSearchDecoder is available but not explicitly requested, use Greedy for num_beams=1
290
- print(f"Warning: num_beams={config.num_beams} > 1 but only Greedy and Knapsack Beam Search are implemented in this app. Defaulting to Greedy.")
291
- decoder = GreedyDecoder(model=MODEL, mass_scale=MASS_SCALE)
 
 
292
  else:
293
- decoder = GreedyDecoder(
294
- model=MODEL,
295
- mass_scale=MASS_SCALE,
296
- # Add suppression options if needed from config
297
- suppressed_residues=config.get("suppressed_residues", None),
298
- disable_terminal_residues_anywhere=config.get("disable_terminal_residues_anywhere", True),
299
- )
 
 
300
  print(f"Using decoder: {type(decoder).__name__}")
301
 
302
  # 5. Run Prediction Loop (Adapted from instanovo/transformer/predict.py)
303
  print("Starting prediction...")
304
  start_time = time.time()
305
- results_list: list[ScoredSequence | list] = [] # Store ScoredSequence or empty list
 
 
306
 
307
  for i, batch in enumerate(dl):
308
- spectra, precursors, spectra_mask, _, _ = batch # Ignore peptides/masks for de novo
 
 
309
  spectra = spectra.to(DEVICE)
310
  precursors = precursors.to(DEVICE)
311
  spectra_mask = spectra_mask.to(DEVICE)
312
 
313
- with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
 
 
 
314
  # Beam search decoder might return list[list[ScoredSequence]] if return_beam=True
315
  # Greedy decoder returns list[ScoredSequence]
316
  # KnapsackBeamSearchDecoder returns list[ScoredSequence] or list[list[ScoredSequence]]
@@ -320,12 +372,17 @@ def predict_peptides(input_file, decoding_method):
320
  beam_size=config.num_beams,
321
  max_length=config.max_length,
322
  # Knapsack/Beam Search specific params if needed
323
- mass_tolerance=config.get("filter_precursor_ppm", 20) * 1e-6, # Convert ppm to relative
324
- max_isotope=config.isotope_error_range[1] if config.isotope_error_range else 1,
325
- return_beam=False # Only get the top prediction for simplicity
 
 
 
326
  )
327
- results_list.extend(batch_predictions) # Should be list[ScoredSequence] or list[list]
328
- print(f"Processed batch {i+1}/{len(dl)}")
 
 
329
 
330
  end_time = time.time()
331
  print(f"Prediction finished in {end_time - start_time:.2f} seconds.")
@@ -335,29 +392,35 @@ def predict_peptides(input_file, decoding_method):
335
  output_data = []
336
  # Use sdf index columns + prediction results
337
  index_cols = [col for col in config.index_columns if col in sdf.df.columns]
338
- base_df_pd = sdf.df.select(index_cols).to_pandas() # Get base info
339
 
340
  metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
341
 
342
  for i, res in enumerate(results_list):
343
- row_data = base_df_pd.iloc[i].to_dict() # Get corresponding input data
344
  if isinstance(res, ScoredSequence) and res.sequence:
345
  sequence_str = "".join(res.sequence)
346
  row_data["prediction"] = sequence_str
347
  row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
348
  # Use metrics to calculate delta mass ppm for the top prediction
349
  try:
350
- _, delta_mass_list = metrics_calc.matches_precursor(
351
- res.sequence,
352
- row_data["precursor_mz"],
353
- row_data["precursor_charge"]
354
- )
355
- # Find the smallest absolute ppm error across isotopes
356
- min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float('nan')
357
- row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}"
 
 
 
 
358
  except Exception as e:
359
- print(f"Warning: Could not calculate delta mass for prediction {i}: {e}")
360
- row_data["delta_mass_ppm"] = "N/A"
 
 
361
 
362
  else:
363
  row_data["prediction"] = ""
@@ -368,13 +431,20 @@ def predict_peptides(input_file, decoding_method):
368
  output_df = pl.DataFrame(output_data)
369
 
370
  # Ensure specific columns are present and ordered
371
- display_cols = ["scan_number", "precursor_mz", "precursor_charge", "prediction", "log_probability", "delta_mass_ppm"]
 
 
 
 
 
 
 
372
  final_display_cols = []
373
  for col in display_cols:
374
  if col in output_df.columns:
375
  final_display_cols.append(col)
376
  else:
377
- print(f"Warning: Expected display column '{col}' not found in results.")
378
 
379
  # Add any remaining index columns that weren't in display_cols
380
  for col in index_cols:
@@ -383,7 +453,6 @@ def predict_peptides(input_file, decoding_method):
383
 
384
  output_df_display = output_df.select(final_display_cols)
385
 
386
-
387
  # 7. Save full results to CSV
388
  print(f"Saving results to {output_csv_path}...")
389
  output_df.write_csv(output_csv_path)
@@ -399,6 +468,7 @@ def predict_peptides(input_file, decoding_method):
399
  # Re-raise as Gradio error
400
  raise gr.Error(f"Prediction failed: {e}")
401
 
 
402
  # --- Gradio Interface ---
403
  css = """
404
  .gradio-container { font-family: sans-serif; }
@@ -408,7 +478,9 @@ footer { display: none !important; }
408
  .logo-container img { margin-bottom: 1rem; }
409
  """
410
 
411
- with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo:
 
 
412
  # --- Logo Display ---
413
  gr.Markdown(
414
  """
@@ -416,7 +488,7 @@ with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hu
416
  <img src='/gradio_api/file=assets/instanovo.svg' alt="InstaNovo Logo" width="300" style="display: block; margin: 0 auto;">
417
  </div>
418
  """,
419
- elem_classes="logo-container" # Optional class for CSS targeting
420
  )
421
 
422
  # --- App Content ---
@@ -431,36 +503,55 @@ with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hu
431
  with gr.Column(scale=1):
432
  input_file = gr.File(
433
  label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
434
- file_types=[".mgf", ".mzml", ".mzxml"]
435
  )
436
  decoding_method = gr.Radio(
437
- ["Greedy Search (Fast, resonably accurate)", "Knapsack Beam Search (More accurate, but slower)"],
 
 
 
438
  label="Decoding Method",
439
- value="Greedy Search (Fast, resonably accurate)" # Default to fast method
440
  )
441
  submit_btn = gr.Button("Predict Sequences", variant="primary")
442
  with gr.Column(scale=2):
443
- output_df = gr.DataFrame(label="Prediction Results", headers=["scan_number", "precursor_mz", "precursor_charge", "prediction", "log_probability", "delta_mass_ppm"], wrap=True)
 
 
 
 
 
 
 
 
 
 
 
444
  output_file = gr.File(label="Download Full Results (CSV)")
445
 
446
  submit_btn.click(
447
  predict_peptides,
448
  inputs=[input_file, decoding_method],
449
- outputs=[output_df, output_file]
450
  )
451
 
452
  gr.Examples(
453
- [["assets/sample_spectra.mgf", "Greedy Search (Fast, resonably accurate)" ],
454
- ["assets/sample_spectra.mgf", "Knapsack Beam Search (More accurate, but slower)" ]],
455
- inputs=[input_file, decoding_method],
456
- outputs=[output_df, output_file],
457
- fn=predict_peptides,
458
- cache_examples=False, # Re-run examples if needed
459
- label="Example Usage"
 
 
 
 
 
460
  )
461
 
462
  gr.Markdown(
463
- """
464
  **Notes:**
465
  * Predictions are based on the [InstaNovo](https://github.com/instadeepai/InstaNovo) model ({MODEL_ID}).
466
  * Knapsack Beam Search uses pre-calculated mass constraints and yields better results but takes longer.
@@ -469,10 +560,32 @@ with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hu
469
  """.format(MODEL_ID=MODEL_ID)
470
  )
471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  # --- Launch the App ---
473
  if __name__ == "__main__":
474
  # Set share=True for temporary public link if running locally
475
  # Set server_name="0.0.0.0" to allow access from network if needed
476
  # demo.launch(server_name="0.0.0.0", server_port=7860)
477
  # For Hugging Face Spaces, just demo.launch() is usually sufficient
478
- demo.launch(share=True) # For local testing with public URL
 
26
  raise ImportError("Failed to import InstaNovo components: {e}")
27
 
28
  # --- Configuration ---
29
+ MODEL_ID = "instanovo-v1.1.0" # Use the desired pretrained model ID
30
  KNAPSACK_DIR = Path("./knapsack_cache")
31
+ DEFAULT_CONFIG_PATH = Path(
32
+ "./configs/inference/default.yaml"
33
+ ) # Assuming instanovo installs configs locally relative to execution
34
 
35
  # Determine device
36
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
37
+ FP16 = DEVICE == "cuda" # Enable FP16 only on CUDA
38
 
39
  # --- Global Variables (Load Model and Knapsack Once) ---
40
  MODEL: InstaNovo | None = None
 
43
  RESIDUE_SET: ResidueSet | None = None
44
 
45
  # Assets
46
+ gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
47
 
48
 
49
  def load_model_and_knapsack():
 
66
 
67
  # --- Knapsack Handling ---
68
  knapsack_exists = (
69
+ (KNAPSACK_DIR / "parameters.pkl").exists()
70
+ and (KNAPSACK_DIR / "masses.npy").exists()
71
+ and (KNAPSACK_DIR / "chart.npy").exists()
72
  )
73
 
74
  if knapsack_exists:
 
78
  print("Knapsack loaded successfully.")
79
  except Exception as e:
80
  print(f"Error loading knapsack: {e}. Will attempt to regenerate.")
81
+ KNAPSACK = None # Force regeneration
82
+ knapsack_exists = False # Ensure generation happens
83
 
84
  if not knapsack_exists:
85
  print("Knapsack not found or failed to load. Generating knapsack...")
86
  if RESIDUE_SET is None:
87
+ raise gr.Error(
88
+ "Cannot generate knapsack because ResidueSet failed to load."
89
+ )
90
  try:
91
  # Prepare residue masses for knapsack generation (handle negative/zero masses)
92
  residue_masses_knapsack = dict(RESIDUE_SET.residue_masses.copy())
93
+ negative_residues = [
94
+ k for k, v in residue_masses_knapsack.items() if v <= 0
95
+ ]
96
  if negative_residues:
97
+ print(
98
+ f"Warning: Non-positive masses found in residues: {negative_residues}. "
99
+ "Excluding from knapsack generation."
100
+ )
101
  for res in negative_residues:
102
  del residue_masses_knapsack[res]
103
  # Remove special tokens explicitly if they somehow got mass
104
  for special_token in RESIDUE_SET.special_tokens:
105
+ if special_token in residue_masses_knapsack:
106
+ del residue_masses_knapsack[special_token]
107
 
108
  # Ensure residue indices used match those without special/negative masses
109
  valid_residue_indices = {
110
+ res: idx
111
+ for res, idx in RESIDUE_SET.residue_to_index.items()
112
  if res in residue_masses_knapsack
113
  }
114
 
 
115
  KNAPSACK = Knapsack.construct_knapsack(
116
  residue_masses=residue_masses_knapsack,
117
+ residue_indices=valid_residue_indices, # Use only valid indices
118
  max_mass=MAX_MASS,
119
  mass_scale=MASS_SCALE,
120
  )
121
  print(f"Knapsack generated. Saving to {KNAPSACK_DIR}...")
122
+ KNAPSACK.save(str(KNAPSACK_DIR)) # Save for future runs
123
  print("Knapsack saved.")
124
  except Exception as e:
125
  print(f"Error generating or saving knapsack: {e}")
126
+ gr.Warning(
127
+ "Failed to generate Knapsack. Knapsack Beam Search will not be available. {e}"
128
+ )
129
+ KNAPSACK = None # Ensure it's None if generation failed
130
+
131
 
132
  # Load the model and knapsack when the script starts
133
  load_model_and_knapsack()
134
 
135
+
136
  def create_inference_config(
137
  input_path: str,
138
  output_path: str,
 
141
  """Creates the OmegaConf DictConfig needed for prediction."""
142
  # Load default config if available, otherwise create from scratch
143
  if DEFAULT_CONFIG_PATH.exists():
144
+ base_cfg = OmegaConf.load(DEFAULT_CONFIG_PATH)
145
  else:
146
+ print(
147
+ f"Warning: Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config."
148
+ )
149
+ # Create a minimal config if default is missing
150
+ base_cfg = OmegaConf.create(
151
+ {
152
+ "data_path": None,
153
+ "instanovo_model": MODEL_ID,
154
+ "output_path": None,
155
+ "knapsack_path": str(KNAPSACK_DIR),
156
+ "denovo": True,
157
+ "refine": False, # Not doing refinement here
158
+ "num_beams": 1,
159
+ "max_length": 40,
160
+ "max_charge": 10,
161
+ "isotope_error_range": [0, 1],
162
+ "subset": 1.0,
163
+ "use_knapsack": False,
164
+ "save_beams": False,
165
+ "batch_size": 64, # Adjust as needed
166
+ "device": DEVICE,
167
+ "fp16": FP16,
168
+ "log_interval": 500, # Less relevant for Gradio app
169
+ "use_basic_logging": True,
170
+ "filter_precursor_ppm": 20,
171
+ "filter_confidence": 1e-4,
172
+ "filter_fdr_threshold": 0.05,
173
+ "residue_remapping": { # Add default mappings
174
+ "M(ox)": "M[UNIMOD:35]",
175
+ "M(+15.99)": "M[UNIMOD:35]",
176
+ "S(p)": "S[UNIMOD:21]",
177
+ "T(p)": "T[UNIMOD:21]",
178
+ "Y(p)": "Y[UNIMOD:21]",
179
+ "S(+79.97)": "S[UNIMOD:21]",
180
+ "T(+79.97)": "T[UNIMOD:21]",
181
+ "Y(+79.97)": "Y[UNIMOD:21]",
182
+ "Q(+0.98)": "Q[UNIMOD:7]",
183
+ "N(+0.98)": "N[UNIMOD:7]",
184
+ "Q(+.98)": "Q[UNIMOD:7]",
185
+ "N(+.98)": "N[UNIMOD:7]",
186
+ "C(+57.02)": "C[UNIMOD:4]",
187
+ "(+42.01)": "[UNIMOD:1]",
188
+ "(+43.01)": "[UNIMOD:5]",
189
+ "(-17.03)": "[UNIMOD:385]",
190
+ },
191
+ "column_map": { # Add default mappings
192
+ "Modified sequence": "modified_sequence",
193
+ "MS/MS m/z": "precursor_mz",
194
+ "Mass": "precursor_mass",
195
+ "Charge": "precursor_charge",
196
+ "Mass values": "mz_array",
197
+ "Mass spectrum": "mz_array",
198
+ "Intensity": "intensity_array",
199
+ "Raw intensity spectrum": "intensity_array",
200
+ "Scan number": "scan_number",
201
+ },
202
+ "index_columns": [
203
+ "scan_number",
204
+ "precursor_mz",
205
+ "precursor_charge",
206
+ ],
207
+ # Add other defaults if needed based on errors
208
+ }
209
+ )
210
 
211
  # Override specific parameters
212
  cfg_overrides = {
 
223
  cfg_overrides["use_knapsack"] = False
224
  elif "Knapsack" in decoding_method:
225
  if KNAPSACK is None:
226
+ raise gr.Error(
227
+ "Knapsack is not available. Cannot use Knapsack Beam Search."
228
+ )
229
  cfg_overrides["num_beams"] = 5
230
  cfg_overrides["use_knapsack"] = True
231
  cfg_overrides["knapsack_path"] = str(KNAPSACK_DIR)
 
242
  Main function to load data, run prediction, and return results.
243
  """
244
  if MODEL is None or RESIDUE_SET is None or MODEL_CONFIG is None:
245
+ load_model_and_knapsack() # Attempt to reload if None (e.g., after space restart)
246
+ if MODEL is None:
247
+ raise gr.Error("InstaNovo model is not loaded. Cannot perform prediction.")
248
 
249
  if input_file is None:
250
  raise gr.Error("Please upload a mass spectrometry file.")
251
 
252
+ input_path = input_file.name # Gradio provides the path in .name
253
  print(f"Processing file: {input_path}")
254
  print(f"Using decoding method: {decoding_method}")
255
 
 
267
  try:
268
  sdf = SpectrumDataFrame.load(
269
  config.data_path,
270
+ lazy=False, # Load eagerly for Gradio simplicity
271
+ is_annotated=False, # De novo mode
272
  column_mapping=config.get("column_map", None),
273
  shuffle=False,
274
+ verbose=True, # Print loading logs
275
  )
276
  # Apply charge filter like in CLI
277
  original_size = len(sdf)
278
  max_charge = config.get("max_charge", 10)
279
  sdf.filter_rows(
280
+ lambda row: (row["precursor_charge"] <= max_charge)
281
+ and (row["precursor_charge"] > 0)
282
  )
283
  if len(sdf) < original_size:
284
+ print(
285
+ f"Warning: Filtered {original_size - len(sdf)} spectra with charge > {max_charge} or <= 0."
286
+ )
287
 
288
  if len(sdf) == 0:
289
+ raise gr.Error(
290
+ "No valid spectra found in the uploaded file after filtering."
291
+ )
292
  print(f"Data loaded: {len(sdf)} spectra.")
293
  except Exception as e:
294
  print(f"Error loading data: {e}")
 
299
  sdf,
300
  RESIDUE_SET,
301
  MODEL_CONFIG.get("n_peaks", 200),
302
+ return_str=True, # Needed for greedy/beam search targets later (though not used here)
303
  annotated=False,
304
+ pad_spectrum_max_length=config.get("compile_model", False)
305
+ or config.get("use_flash_attention", False),
306
  bin_spectra=config.get("conv_peak_encoder", False),
307
  )
308
  dl = DataLoader(
309
  ds,
310
  batch_size=config.batch_size,
311
+ num_workers=0, # Required by SpectrumDataFrame
312
+ shuffle=False, # Required by SpectrumDataFrame
313
  collate_fn=collate_batch,
314
  )
315
 
 
318
  decoder: Decoder
319
  if config.use_knapsack:
320
  if KNAPSACK is None:
321
+ # This check should ideally be earlier, but double-check
322
+ raise gr.Error(
323
+ "Knapsack is required for Knapsack Beam Search but is not available."
324
+ )
325
  # KnapsackBeamSearchDecoder doesn't directly load from path in this version?
326
  # We load Knapsack globally, so just pass it.
327
  # If it needed path: decoder = KnapsackBeamSearchDecoder.from_file(model=MODEL, path=config.knapsack_path)
328
  decoder = KnapsackBeamSearchDecoder(model=MODEL, knapsack=KNAPSACK)
329
  elif config.num_beams > 1:
330
+ # BeamSearchDecoder is available but not explicitly requested, use Greedy for num_beams=1
331
+ print(
332
+ f"Warning: num_beams={config.num_beams} > 1 but only Greedy and Knapsack Beam Search are implemented in this app. Defaulting to Greedy."
333
+ )
334
+ decoder = GreedyDecoder(model=MODEL, mass_scale=MASS_SCALE)
335
  else:
336
+ decoder = GreedyDecoder(
337
+ model=MODEL,
338
+ mass_scale=MASS_SCALE,
339
+ # Add suppression options if needed from config
340
+ suppressed_residues=config.get("suppressed_residues", None),
341
+ disable_terminal_residues_anywhere=config.get(
342
+ "disable_terminal_residues_anywhere", True
343
+ ),
344
+ )
345
  print(f"Using decoder: {type(decoder).__name__}")
346
 
347
  # 5. Run Prediction Loop (Adapted from instanovo/transformer/predict.py)
348
  print("Starting prediction...")
349
  start_time = time.time()
350
+ results_list: list[
351
+ ScoredSequence | list
352
+ ] = [] # Store ScoredSequence or empty list
353
 
354
  for i, batch in enumerate(dl):
355
+ spectra, precursors, spectra_mask, _, _ = (
356
+ batch # Ignore peptides/masks for de novo
357
+ )
358
  spectra = spectra.to(DEVICE)
359
  precursors = precursors.to(DEVICE)
360
  spectra_mask = spectra_mask.to(DEVICE)
361
 
362
+ with (
363
+ torch.no_grad(),
364
+ torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16),
365
+ ):
366
  # Beam search decoder might return list[list[ScoredSequence]] if return_beam=True
367
  # Greedy decoder returns list[ScoredSequence]
368
  # KnapsackBeamSearchDecoder returns list[ScoredSequence] or list[list[ScoredSequence]]
 
372
  beam_size=config.num_beams,
373
  max_length=config.max_length,
374
  # Knapsack/Beam Search specific params if needed
375
+ mass_tolerance=config.get("filter_precursor_ppm", 20)
376
+ * 1e-6, # Convert ppm to relative
377
+ max_isotope=config.isotope_error_range[1]
378
+ if config.isotope_error_range
379
+ else 1,
380
+ return_beam=False, # Only get the top prediction for simplicity
381
  )
382
+ results_list.extend(
383
+ batch_predictions
384
+ ) # Should be list[ScoredSequence] or list[list]
385
+ print(f"Processed batch {i + 1}/{len(dl)}")
386
 
387
  end_time = time.time()
388
  print(f"Prediction finished in {end_time - start_time:.2f} seconds.")
 
392
  output_data = []
393
  # Use sdf index columns + prediction results
394
  index_cols = [col for col in config.index_columns if col in sdf.df.columns]
395
+ base_df_pd = sdf.df.select(index_cols).to_pandas() # Get base info
396
 
397
  metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
398
 
399
  for i, res in enumerate(results_list):
400
+ row_data = base_df_pd.iloc[i].to_dict() # Get corresponding input data
401
  if isinstance(res, ScoredSequence) and res.sequence:
402
  sequence_str = "".join(res.sequence)
403
  row_data["prediction"] = sequence_str
404
  row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
405
  # Use metrics to calculate delta mass ppm for the top prediction
406
  try:
407
+ _, delta_mass_list = metrics_calc.matches_precursor(
408
+ res.sequence,
409
+ row_data["precursor_mz"],
410
+ row_data["precursor_charge"],
411
+ )
412
+ # Find the smallest absolute ppm error across isotopes
413
+ min_abs_ppm = (
414
+ min(abs(p) for p in delta_mass_list)
415
+ if delta_mass_list
416
+ else float("nan")
417
+ )
418
+ row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}"
419
  except Exception as e:
420
+ print(
421
+ f"Warning: Could not calculate delta mass for prediction {i}: {e}"
422
+ )
423
+ row_data["delta_mass_ppm"] = "N/A"
424
 
425
  else:
426
  row_data["prediction"] = ""
 
431
  output_df = pl.DataFrame(output_data)
432
 
433
  # Ensure specific columns are present and ordered
434
+ display_cols = [
435
+ "scan_number",
436
+ "precursor_mz",
437
+ "precursor_charge",
438
+ "prediction",
439
+ "log_probability",
440
+ "delta_mass_ppm",
441
+ ]
442
  final_display_cols = []
443
  for col in display_cols:
444
  if col in output_df.columns:
445
  final_display_cols.append(col)
446
  else:
447
+ print(f"Warning: Expected display column '{col}' not found in results.")
448
 
449
  # Add any remaining index columns that weren't in display_cols
450
  for col in index_cols:
 
453
 
454
  output_df_display = output_df.select(final_display_cols)
455
 
 
456
  # 7. Save full results to CSV
457
  print(f"Saving results to {output_csv_path}...")
458
  output_df.write_csv(output_csv_path)
 
468
  # Re-raise as Gradio error
469
  raise gr.Error(f"Prediction failed: {e}")
470
 
471
+
472
  # --- Gradio Interface ---
473
  css = """
474
  .gradio-container { font-family: sans-serif; }
 
478
  .logo-container img { margin-bottom: 1rem; }
479
  """
480
 
481
+ with gr.Blocks(
482
+ css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
483
+ ) as demo:
484
  # --- Logo Display ---
485
  gr.Markdown(
486
  """
 
488
  <img src='/gradio_api/file=assets/instanovo.svg' alt="InstaNovo Logo" width="300" style="display: block; margin: 0 auto;">
489
  </div>
490
  """,
491
+ elem_classes="logo-container", # Optional class for CSS targeting
492
  )
493
 
494
  # --- App Content ---
 
503
  with gr.Column(scale=1):
504
  input_file = gr.File(
505
  label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
506
+ file_types=[".mgf", ".mzml", ".mzxml"],
507
  )
508
  decoding_method = gr.Radio(
509
+ [
510
+ "Greedy Search (Fast, resonably accurate)",
511
+ "Knapsack Beam Search (More accurate, but slower)",
512
+ ],
513
  label="Decoding Method",
514
+ value="Greedy Search (Fast, resonably accurate)", # Default to fast method
515
  )
516
  submit_btn = gr.Button("Predict Sequences", variant="primary")
517
  with gr.Column(scale=2):
518
+ output_df = gr.DataFrame(
519
+ label="Prediction Results",
520
+ headers=[
521
+ "scan_number",
522
+ "precursor_mz",
523
+ "precursor_charge",
524
+ "prediction",
525
+ "log_probability",
526
+ "delta_mass_ppm",
527
+ ],
528
+ wrap=True,
529
+ )
530
  output_file = gr.File(label="Download Full Results (CSV)")
531
 
532
  submit_btn.click(
533
  predict_peptides,
534
  inputs=[input_file, decoding_method],
535
+ outputs=[output_df, output_file],
536
  )
537
 
538
  gr.Examples(
539
+ [
540
+ ["assets/sample_spectra.mgf", "Greedy Search (Fast, resonably accurate)"],
541
+ [
542
+ "assets/sample_spectra.mgf",
543
+ "Knapsack Beam Search (More accurate, but slower)",
544
+ ],
545
+ ],
546
+ inputs=[input_file, decoding_method],
547
+ outputs=[output_df, output_file],
548
+ fn=predict_peptides,
549
+ cache_examples=False, # Re-run examples if needed
550
+ label="Example Usage",
551
  )
552
 
553
  gr.Markdown(
554
+ """
555
  **Notes:**
556
  * Predictions are based on the [InstaNovo](https://github.com/instadeepai/InstaNovo) model ({MODEL_ID}).
557
  * Knapsack Beam Search uses pre-calculated mass constraints and yields better results but takes longer.
 
560
  """.format(MODEL_ID=MODEL_ID)
561
  )
562
 
563
+ gr.Textbox(
564
+ value="""
565
+ @article{eloff_kalogeropoulos_2025_instanovo,
566
+ title = {InstaNovo enables diffusion-powered de novo peptide sequencing in large-scale proteomics experiments},
567
+ author = {Kevin Eloff and Konstantinos Kalogeropoulos and Amandla Mabona and Oliver Morell and Rachel Catzel and
568
+ Esperanza Rivera-de-Torre and Jakob Berg Jespersen and Wesley Williams and Sam P. B. van Beljouw and
569
+ Marcin J. Skwark and Andreas Hougaard Laustsen and Stan J. J. Brouns and Anne Ljungars and Erwin M.
570
+ Schoof and Jeroen Van Goey and Ulrich auf dem Keller and Karim Beguir and Nicolas Lopez Carranza and
571
+ Timothy P. Jenkins},
572
+ year = 2025,
573
+ month = {Mar},
574
+ day = 31,
575
+ journal = {Nature Machine Intelligence},
576
+ doi = {10.1038/s42256-025-01019-5},
577
+ url = {https://www.nature.com/articles/s42256-025-01019-5}
578
+ }
579
+ """,
580
+ show_copy_button=True,
581
+ label="If you use InstaNovo in your research, please cite:",
582
+ interactive=False,
583
+ )
584
+
585
  # --- Launch the App ---
586
  if __name__ == "__main__":
587
  # Set share=True for temporary public link if running locally
588
  # Set server_name="0.0.0.0" to allow access from network if needed
589
  # demo.launch(server_name="0.0.0.0", server_port=7860)
590
  # For Hugging Face Spaces, just demo.launch() is usually sufficient
591
+ demo.launch(share=True) # For local testing with public URL