BioGeek committed on
Commit
e38f067
·
2 Parent(s): 6851f02 44b2355

Merge branch 'citation'

Browse files
Files changed (1) hide show
  1. app.py +140 -63
app.py CHANGED
@@ -28,13 +28,15 @@ except ImportError as e:
28
  raise ImportError(f"Failed to import InstaNovo components: {e}")
29
 
30
  # --- Configuration ---
31
- MODEL_ID = "instanovo-v1.1.0" # Use the desired pretrained model ID
32
  KNAPSACK_DIR = Path("./knapsack_cache")
33
- DEFAULT_CONFIG_PATH = Path("./configs/inference/default.yaml") # Assuming instanovo installs configs locally relative to execution
 
 
34
 
35
  # Determine device
36
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
37
- FP16 = DEVICE == "cuda" # Enable FP16 only on CUDA
38
 
39
  # --- Global Variables (Load Model and Knapsack Once) ---
40
  MODEL: InstaNovo | None = None
@@ -78,9 +80,9 @@ def load_model_and_knapsack():
78
 
79
  # --- Knapsack Handling ---
80
  knapsack_exists = (
81
- (KNAPSACK_DIR / "parameters.pkl").exists() and
82
- (KNAPSACK_DIR / "masses.npy").exists() and
83
- (KNAPSACK_DIR / "chart.npy").exists()
84
  )
85
 
86
  if knapsack_exists:
@@ -96,11 +98,15 @@ def load_model_and_knapsack():
96
  if not knapsack_exists:
97
  logger.info("Knapsack not found or failed to load. Generating knapsack...")
98
  if RESIDUE_SET is None:
99
- raise gr.Error("Cannot generate knapsack because ResidueSet failed to load.")
 
 
100
  try:
101
  # Prepare residue masses for knapsack generation (handle negative/zero masses)
102
  residue_masses_knapsack = dict(RESIDUE_SET.residue_masses.copy())
103
- negative_residues = [k for k, v in residue_masses_knapsack.items() if v <= 0]
 
 
104
  if negative_residues:
105
  logger.info(f"Warning: Non-positive masses found in residues: {negative_residues}. "
106
  "Excluding from knapsack generation.")
@@ -108,19 +114,19 @@ def load_model_and_knapsack():
108
  del residue_masses_knapsack[res]
109
  # Remove special tokens explicitly if they somehow got mass
110
  for special_token in RESIDUE_SET.special_tokens:
111
- if special_token in residue_masses_knapsack:
112
- del residue_masses_knapsack[special_token]
113
 
114
  # Ensure residue indices used match those without special/negative masses
115
  valid_residue_indices = {
116
- res: idx for res, idx in RESIDUE_SET.residue_to_index.items()
 
117
  if res in residue_masses_knapsack
118
  }
119
 
120
-
121
  KNAPSACK = Knapsack.construct_knapsack(
122
  residue_masses=residue_masses_knapsack,
123
- residue_indices=valid_residue_indices, # Use only valid indices
124
  max_mass=MAX_MASS,
125
  mass_scale=MASS_SCALE,
126
  )
@@ -135,6 +141,7 @@ def load_model_and_knapsack():
135
  # Load the model and knapsack when the script starts
136
  load_model_and_knapsack()
137
 
 
138
  def create_inference_config(
139
  input_path: str,
140
  output_path: str,
@@ -143,7 +150,7 @@ def create_inference_config(
143
  """Creates the OmegaConf DictConfig needed for prediction."""
144
  # Load default config if available, otherwise create from scratch
145
  if DEFAULT_CONFIG_PATH.exists():
146
- base_cfg = OmegaConf.load(DEFAULT_CONFIG_PATH)
147
  else:
148
  logger.info(f"Warning: Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config.")
149
  # Create a minimal config if default is missing
@@ -206,7 +213,9 @@ def create_inference_config(
206
  cfg_overrides["use_knapsack"] = False
207
  elif "Knapsack" in decoding_method:
208
  if KNAPSACK is None:
209
- raise gr.Error("Knapsack is not available. Cannot use Knapsack Beam Search.")
 
 
210
  cfg_overrides["num_beams"] = 5
211
  cfg_overrides["use_knapsack"] = True
212
  cfg_overrides["knapsack_path"] = str(KNAPSACK_DIR)
@@ -223,9 +232,9 @@ def predict_peptides(input_file, decoding_method):
223
  Main function to load data, run prediction, and return results.
224
  """
225
  if MODEL is None or RESIDUE_SET is None or MODEL_CONFIG is None:
226
- load_model_and_knapsack() # Attempt to reload if None (e.g., after space restart)
227
- if MODEL is None:
228
- raise gr.Error("InstaNovo model is not loaded. Cannot perform prediction.")
229
 
230
  if input_file is None:
231
  raise gr.Error("Please upload a mass spectrometry file.")
@@ -248,17 +257,18 @@ def predict_peptides(input_file, decoding_method):
248
  try:
249
  sdf = SpectrumDataFrame.load(
250
  config.data_path,
251
- lazy=False, # Load eagerly for Gradio simplicity
252
- is_annotated=False, # De novo mode
253
  column_mapping=config.get("column_map", None),
254
  shuffle=False,
255
- verbose=True # Print loading logs
256
  )
257
  # Apply charge filter like in CLI
258
  original_size = len(sdf)
259
  max_charge = config.get("max_charge", 10)
260
  sdf.filter_rows(
261
- lambda row: (row["precursor_charge"] <= max_charge) and (row["precursor_charge"] > 0)
 
262
  )
263
  if len(sdf) < original_size:
264
  logger.info(f"Warning: Filtered {original_size - len(sdf)} spectra with charge > {max_charge} or <= 0.")
@@ -275,16 +285,17 @@ def predict_peptides(input_file, decoding_method):
275
  sdf,
276
  RESIDUE_SET,
277
  MODEL_CONFIG.get("n_peaks", 200),
278
- return_str=True, # Needed for greedy/beam search targets later (though not used here)
279
  annotated=False,
280
- pad_spectrum_max_length=config.get("compile_model", False) or config.get("use_flash_attention", False),
 
281
  bin_spectra=config.get("conv_peak_encoder", False),
282
  )
283
  dl = DataLoader(
284
  ds,
285
  batch_size=config.batch_size,
286
- num_workers=0, # Required by SpectrumDataFrame
287
- shuffle=False, # Required by SpectrumDataFrame
288
  collate_fn=collate_batch,
289
  )
290
 
@@ -293,8 +304,10 @@ def predict_peptides(input_file, decoding_method):
293
  decoder: Decoder
294
  if config.use_knapsack:
295
  if KNAPSACK is None:
296
- # This check should ideally be earlier, but double-check
297
- raise gr.Error("Knapsack is required for Knapsack Beam Search but is not available.")
 
 
298
  # KnapsackBeamSearchDecoder doesn't directly load from path in this version?
299
  # We load Knapsack globally, so just pass it.
300
  # If it needed path: decoder = KnapsackBeamSearchDecoder.from_file(model=MODEL, path=config.knapsack_path)
@@ -316,15 +329,22 @@ def predict_peptides(input_file, decoding_method):
316
  # 5. Run Prediction Loop (Adapted from instanovo/transformer/predict.py)
317
  logger.info("Starting prediction...")
318
  start_time = time.time()
319
- results_list: list[ScoredSequence | list] = [] # Store ScoredSequence or empty list
 
 
320
 
321
  for i, batch in enumerate(dl):
322
- spectra, precursors, spectra_mask, _, _ = batch # Ignore peptides/masks for de novo
 
 
323
  spectra = spectra.to(DEVICE)
324
  precursors = precursors.to(DEVICE)
325
  spectra_mask = spectra_mask.to(DEVICE)
326
 
327
- with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
 
 
 
328
  # Beam search decoder might return list[list[ScoredSequence]] if return_beam=True
329
  # Greedy decoder returns list[ScoredSequence]
330
  # KnapsackBeamSearchDecoder returns list[ScoredSequence] or list[list[ScoredSequence]]
@@ -334,9 +354,12 @@ def predict_peptides(input_file, decoding_method):
334
  beam_size=config.num_beams,
335
  max_length=config.max_length,
336
  # Knapsack/Beam Search specific params if needed
337
- mass_tolerance=config.get("filter_precursor_ppm", 20) * 1e-6, # Convert ppm to relative
338
- max_isotope=config.isotope_error_range[1] if config.isotope_error_range else 1,
339
- return_beam=False # Only get the top prediction for simplicity
 
 
 
340
  )
341
  results_list.extend(batch_predictions) # Should be list[ScoredSequence] or list[list]
342
  logger.info(f"Processed batch {i+1}/{len(dl)}")
@@ -349,26 +372,30 @@ def predict_peptides(input_file, decoding_method):
349
  output_data = []
350
  # Use sdf index columns + prediction results
351
  index_cols = [col for col in config.index_columns if col in sdf.df.columns]
352
- base_df_pd = sdf.df.select(index_cols).to_pandas() # Get base info
353
 
354
  metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
355
 
356
  for i, res in enumerate(results_list):
357
- row_data = base_df_pd.iloc[i].to_dict() # Get corresponding input data
358
  if isinstance(res, ScoredSequence) and res.sequence:
359
  sequence_str = "".join(res.sequence)
360
  row_data["prediction"] = sequence_str
361
  row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
362
  # Use metrics to calculate delta mass ppm for the top prediction
363
  try:
364
- _, delta_mass_list = metrics_calc.matches_precursor(
365
- res.sequence,
366
- row_data["precursor_mz"],
367
- row_data["precursor_charge"]
368
- )
369
- # Find the smallest absolute ppm error across isotopes
370
- min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float('nan')
371
- row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}"
 
 
 
 
372
  except Exception as e:
373
  logger.info(f"Warning: Could not calculate delta mass for prediction {i}: {e}")
374
  row_data["delta_mass_ppm"] = "N/A"
@@ -382,7 +409,14 @@ def predict_peptides(input_file, decoding_method):
382
  output_df = pl.DataFrame(output_data)
383
 
384
  # Ensure specific columns are present and ordered
385
- display_cols = ["scan_number", "precursor_mz", "precursor_charge", "prediction", "log_probability", "delta_mass_ppm"]
 
 
 
 
 
 
 
386
  final_display_cols = []
387
  for col in display_cols:
388
  if col in output_df.columns:
@@ -397,7 +431,6 @@ def predict_peptides(input_file, decoding_method):
397
 
398
  output_df_display = output_df.select(final_display_cols)
399
 
400
-
401
  # 7. Save full results to CSV
402
  logger.info(f"Saving results to {output_csv_path}...")
403
  output_df.write_csv(output_csv_path)
@@ -413,6 +446,7 @@ def predict_peptides(input_file, decoding_method):
413
  # Re-raise as Gradio error
414
  raise gr.Error(f"Prediction failed: {e}")
415
 
 
416
  # --- Gradio Interface ---
417
  css = """
418
  .gradio-container { font-family: sans-serif; }
@@ -422,7 +456,9 @@ footer { display: none !important; }
422
  .logo-container img { margin-bottom: 1rem; }
423
  """
424
 
425
- with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo:
 
 
426
  # --- Logo Display ---
427
  gr.Markdown(
428
  """
@@ -430,7 +466,7 @@ with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hu
430
  <img src='/gradio_api/file=assets/instanovo.svg' alt="InstaNovo Logo" width="300" style="display: block; margin: 0 auto;">
431
  </div>
432
  """,
433
- elem_classes="logo-container" # Optional class for CSS targeting
434
  )
435
 
436
  # --- App Content ---
@@ -445,38 +481,57 @@ with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hu
445
  with gr.Column(scale=1):
446
  input_file = gr.File(
447
  label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
448
- file_types=[".mgf", ".mzml", ".mzxml"]
449
  )
450
  decoding_method = gr.Radio(
451
- ["Greedy Search (Fast, resonably accurate)", "Knapsack Beam Search (More accurate, but slower)"],
 
 
 
452
  label="Decoding Method",
453
- value="Greedy Search (Fast, resonably accurate)" # Default to fast method
454
  )
455
  submit_btn = gr.Button("Predict Sequences", variant="primary")
456
  with gr.Column(scale=2):
457
- output_df = gr.DataFrame(label="Prediction Results", headers=["scan_number", "precursor_mz", "precursor_charge", "prediction", "log_probability", "delta_mass_ppm"], wrap=True)
 
 
 
 
 
 
 
 
 
 
 
458
  output_file = gr.File(label="Download Full Results (CSV)")
459
 
460
  submit_btn.click(
461
  predict_peptides,
462
  inputs=[input_file, decoding_method],
463
- outputs=[output_df, output_file]
464
  )
465
 
466
  gr.Examples(
467
- [["assets/sample_spectra.mgf", "Greedy Search (Fast, resonably accurate)" ],
468
- ["assets/sample_spectra.mgf", "Knapsack Beam Search (More accurate, but slower)" ]],
469
- inputs=[input_file, decoding_method],
470
- outputs=[output_df, output_file],
471
- fn=predict_peptides,
472
- cache_examples=False, # Re-run examples if needed
473
- label="Example Usage"
 
 
 
 
 
474
  )
475
 
476
  gr.Markdown(
477
- """
478
  **Notes:**
479
- * Predictions are based on the [InstaNovo](https://github.com/instadeepai/InstaNovo) model ({MODEL_ID}).
480
  * Knapsack Beam Search uses pre-calculated mass constraints and yields better results but takes longer.
481
  * `delta_mass_ppm` shows the lowest absolute precursor mass error (in ppm) across potential isotopes (0-1 neutron).
482
  * Ensure your input file format is correctly specified. Large files may take time to process.
@@ -487,10 +542,32 @@ with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hu
487
  with gr.Accordion("Application Logs", open=True):
488
  log_display = Log(log_file, dark=True, height=300)
489
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
  # --- Launch the App ---
491
  if __name__ == "__main__":
492
  # Set share=True for temporary public link if running locally
493
  # Set server_name="0.0.0.0" to allow access from network if needed
494
  # demo.launch(server_name="0.0.0.0", server_port=7860)
495
  # For Hugging Face Spaces, just demo.launch() is usually sufficient
496
- demo.launch(share=True) # For local testing with public URL
 
28
  raise ImportError(f"Failed to import InstaNovo components: {e}")
29
 
30
  # --- Configuration ---
31
+ MODEL_ID = "instanovo-v1.1.0" # Use the desired pretrained model ID
32
  KNAPSACK_DIR = Path("./knapsack_cache")
33
+ DEFAULT_CONFIG_PATH = Path(
34
+ "./configs/inference/default.yaml"
35
+ ) # Assuming instanovo installs configs locally relative to execution
36
 
37
  # Determine device
38
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
39
+ FP16 = DEVICE == "cuda" # Enable FP16 only on CUDA
40
 
41
  # --- Global Variables (Load Model and Knapsack Once) ---
42
  MODEL: InstaNovo | None = None
 
80
 
81
  # --- Knapsack Handling ---
82
  knapsack_exists = (
83
+ (KNAPSACK_DIR / "parameters.pkl").exists()
84
+ and (KNAPSACK_DIR / "masses.npy").exists()
85
+ and (KNAPSACK_DIR / "chart.npy").exists()
86
  )
87
 
88
  if knapsack_exists:
 
98
  if not knapsack_exists:
99
  logger.info("Knapsack not found or failed to load. Generating knapsack...")
100
  if RESIDUE_SET is None:
101
+ raise gr.Error(
102
+ "Cannot generate knapsack because ResidueSet failed to load."
103
+ )
104
  try:
105
  # Prepare residue masses for knapsack generation (handle negative/zero masses)
106
  residue_masses_knapsack = dict(RESIDUE_SET.residue_masses.copy())
107
+ negative_residues = [
108
+ k for k, v in residue_masses_knapsack.items() if v <= 0
109
+ ]
110
  if negative_residues:
111
  logger.info(f"Warning: Non-positive masses found in residues: {negative_residues}. "
112
  "Excluding from knapsack generation.")
 
114
  del residue_masses_knapsack[res]
115
  # Remove special tokens explicitly if they somehow got mass
116
  for special_token in RESIDUE_SET.special_tokens:
117
+ if special_token in residue_masses_knapsack:
118
+ del residue_masses_knapsack[special_token]
119
 
120
  # Ensure residue indices used match those without special/negative masses
121
  valid_residue_indices = {
122
+ res: idx
123
+ for res, idx in RESIDUE_SET.residue_to_index.items()
124
  if res in residue_masses_knapsack
125
  }
126
 
 
127
  KNAPSACK = Knapsack.construct_knapsack(
128
  residue_masses=residue_masses_knapsack,
129
+ residue_indices=valid_residue_indices, # Use only valid indices
130
  max_mass=MAX_MASS,
131
  mass_scale=MASS_SCALE,
132
  )
 
141
  # Load the model and knapsack when the script starts
142
  load_model_and_knapsack()
143
 
144
+
145
  def create_inference_config(
146
  input_path: str,
147
  output_path: str,
 
150
  """Creates the OmegaConf DictConfig needed for prediction."""
151
  # Load default config if available, otherwise create from scratch
152
  if DEFAULT_CONFIG_PATH.exists():
153
+ base_cfg = OmegaConf.load(DEFAULT_CONFIG_PATH)
154
  else:
155
  logger.info(f"Warning: Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config.")
156
  # Create a minimal config if default is missing
 
213
  cfg_overrides["use_knapsack"] = False
214
  elif "Knapsack" in decoding_method:
215
  if KNAPSACK is None:
216
+ raise gr.Error(
217
+ "Knapsack is not available. Cannot use Knapsack Beam Search."
218
+ )
219
  cfg_overrides["num_beams"] = 5
220
  cfg_overrides["use_knapsack"] = True
221
  cfg_overrides["knapsack_path"] = str(KNAPSACK_DIR)
 
232
  Main function to load data, run prediction, and return results.
233
  """
234
  if MODEL is None or RESIDUE_SET is None or MODEL_CONFIG is None:
235
+ load_model_and_knapsack() # Attempt to reload if None (e.g., after space restart)
236
+ if MODEL is None:
237
+ raise gr.Error("InstaNovo model is not loaded. Cannot perform prediction.")
238
 
239
  if input_file is None:
240
  raise gr.Error("Please upload a mass spectrometry file.")
 
257
  try:
258
  sdf = SpectrumDataFrame.load(
259
  config.data_path,
260
+ lazy=False, # Load eagerly for Gradio simplicity
261
+ is_annotated=False, # De novo mode
262
  column_mapping=config.get("column_map", None),
263
  shuffle=False,
264
+ verbose=True, # Print loading logs
265
  )
266
  # Apply charge filter like in CLI
267
  original_size = len(sdf)
268
  max_charge = config.get("max_charge", 10)
269
  sdf.filter_rows(
270
+ lambda row: (row["precursor_charge"] <= max_charge)
271
+ and (row["precursor_charge"] > 0)
272
  )
273
  if len(sdf) < original_size:
274
  logger.info(f"Warning: Filtered {original_size - len(sdf)} spectra with charge > {max_charge} or <= 0.")
 
285
  sdf,
286
  RESIDUE_SET,
287
  MODEL_CONFIG.get("n_peaks", 200),
288
+ return_str=True, # Needed for greedy/beam search targets later (though not used here)
289
  annotated=False,
290
+ pad_spectrum_max_length=config.get("compile_model", False)
291
+ or config.get("use_flash_attention", False),
292
  bin_spectra=config.get("conv_peak_encoder", False),
293
  )
294
  dl = DataLoader(
295
  ds,
296
  batch_size=config.batch_size,
297
+ num_workers=0, # Required by SpectrumDataFrame
298
+ shuffle=False, # Required by SpectrumDataFrame
299
  collate_fn=collate_batch,
300
  )
301
 
 
304
  decoder: Decoder
305
  if config.use_knapsack:
306
  if KNAPSACK is None:
307
+ # This check should ideally be earlier, but double-check
308
+ raise gr.Error(
309
+ "Knapsack is required for Knapsack Beam Search but is not available."
310
+ )
311
  # KnapsackBeamSearchDecoder doesn't directly load from path in this version?
312
  # We load Knapsack globally, so just pass it.
313
  # If it needed path: decoder = KnapsackBeamSearchDecoder.from_file(model=MODEL, path=config.knapsack_path)
 
329
  # 5. Run Prediction Loop (Adapted from instanovo/transformer/predict.py)
330
  logger.info("Starting prediction...")
331
  start_time = time.time()
332
+ results_list: list[
333
+ ScoredSequence | list
334
+ ] = [] # Store ScoredSequence or empty list
335
 
336
  for i, batch in enumerate(dl):
337
+ spectra, precursors, spectra_mask, _, _ = (
338
+ batch # Ignore peptides/masks for de novo
339
+ )
340
  spectra = spectra.to(DEVICE)
341
  precursors = precursors.to(DEVICE)
342
  spectra_mask = spectra_mask.to(DEVICE)
343
 
344
+ with (
345
+ torch.no_grad(),
346
+ torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16),
347
+ ):
348
  # Beam search decoder might return list[list[ScoredSequence]] if return_beam=True
349
  # Greedy decoder returns list[ScoredSequence]
350
  # KnapsackBeamSearchDecoder returns list[ScoredSequence] or list[list[ScoredSequence]]
 
354
  beam_size=config.num_beams,
355
  max_length=config.max_length,
356
  # Knapsack/Beam Search specific params if needed
357
+ mass_tolerance=config.get("filter_precursor_ppm", 20)
358
+ * 1e-6, # Convert ppm to relative
359
+ max_isotope=config.isotope_error_range[1]
360
+ if config.isotope_error_range
361
+ else 1,
362
+ return_beam=False, # Only get the top prediction for simplicity
363
  )
364
  results_list.extend(batch_predictions) # Should be list[ScoredSequence] or list[list]
365
  logger.info(f"Processed batch {i+1}/{len(dl)}")
 
372
  output_data = []
373
  # Use sdf index columns + prediction results
374
  index_cols = [col for col in config.index_columns if col in sdf.df.columns]
375
+ base_df_pd = sdf.df.select(index_cols).to_pandas() # Get base info
376
 
377
  metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
378
 
379
  for i, res in enumerate(results_list):
380
+ row_data = base_df_pd.iloc[i].to_dict() # Get corresponding input data
381
  if isinstance(res, ScoredSequence) and res.sequence:
382
  sequence_str = "".join(res.sequence)
383
  row_data["prediction"] = sequence_str
384
  row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
385
  # Use metrics to calculate delta mass ppm for the top prediction
386
  try:
387
+ _, delta_mass_list = metrics_calc.matches_precursor(
388
+ res.sequence,
389
+ row_data["precursor_mz"],
390
+ row_data["precursor_charge"],
391
+ )
392
+ # Find the smallest absolute ppm error across isotopes
393
+ min_abs_ppm = (
394
+ min(abs(p) for p in delta_mass_list)
395
+ if delta_mass_list
396
+ else float("nan")
397
+ )
398
+ row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}"
399
  except Exception as e:
400
  logger.info(f"Warning: Could not calculate delta mass for prediction {i}: {e}")
401
  row_data["delta_mass_ppm"] = "N/A"
 
409
  output_df = pl.DataFrame(output_data)
410
 
411
  # Ensure specific columns are present and ordered
412
+ display_cols = [
413
+ "scan_number",
414
+ "precursor_mz",
415
+ "precursor_charge",
416
+ "prediction",
417
+ "log_probability",
418
+ "delta_mass_ppm",
419
+ ]
420
  final_display_cols = []
421
  for col in display_cols:
422
  if col in output_df.columns:
 
431
 
432
  output_df_display = output_df.select(final_display_cols)
433
 
 
434
  # 7. Save full results to CSV
435
  logger.info(f"Saving results to {output_csv_path}...")
436
  output_df.write_csv(output_csv_path)
 
446
  # Re-raise as Gradio error
447
  raise gr.Error(f"Prediction failed: {e}")
448
 
449
+
450
  # --- Gradio Interface ---
451
  css = """
452
  .gradio-container { font-family: sans-serif; }
 
456
  .logo-container img { margin-bottom: 1rem; }
457
  """
458
 
459
+ with gr.Blocks(
460
+ css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
461
+ ) as demo:
462
  # --- Logo Display ---
463
  gr.Markdown(
464
  """
 
466
  <img src='/gradio_api/file=assets/instanovo.svg' alt="InstaNovo Logo" width="300" style="display: block; margin: 0 auto;">
467
  </div>
468
  """,
469
+ elem_classes="logo-container", # Optional class for CSS targeting
470
  )
471
 
472
  # --- App Content ---
 
481
  with gr.Column(scale=1):
482
  input_file = gr.File(
483
  label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
484
+ file_types=[".mgf", ".mzml", ".mzxml"],
485
  )
486
  decoding_method = gr.Radio(
487
+ [
488
+ "Greedy Search (Fast, resonably accurate)",
489
+ "Knapsack Beam Search (More accurate, but slower)",
490
+ ],
491
  label="Decoding Method",
492
+ value="Greedy Search (Fast, resonably accurate)", # Default to fast method
493
  )
494
  submit_btn = gr.Button("Predict Sequences", variant="primary")
495
  with gr.Column(scale=2):
496
+ output_df = gr.DataFrame(
497
+ label="Prediction Results",
498
+ headers=[
499
+ "scan_number",
500
+ "precursor_mz",
501
+ "precursor_charge",
502
+ "prediction",
503
+ "log_probability",
504
+ "delta_mass_ppm",
505
+ ],
506
+ wrap=True,
507
+ )
508
  output_file = gr.File(label="Download Full Results (CSV)")
509
 
510
  submit_btn.click(
511
  predict_peptides,
512
  inputs=[input_file, decoding_method],
513
+ outputs=[output_df, output_file],
514
  )
515
 
516
  gr.Examples(
517
+ [
518
+ ["assets/sample_spectra.mgf", "Greedy Search (Fast, resonably accurate)"],
519
+ [
520
+ "assets/sample_spectra.mgf",
521
+ "Knapsack Beam Search (More accurate, but slower)",
522
+ ],
523
+ ],
524
+ inputs=[input_file, decoding_method],
525
+ outputs=[output_df, output_file],
526
+ fn=predict_peptides,
527
+ cache_examples=False, # Re-run examples if needed
528
+ label="Example Usage",
529
  )
530
 
531
  gr.Markdown(
532
+ """
533
  **Notes:**
534
+ * Predictions are based on the [InstaNovo](https://github.com/instadeepai/InstaNovo) model `{MODEL_ID}`.
535
  * Knapsack Beam Search uses pre-calculated mass constraints and yields better results but takes longer.
536
  * `delta_mass_ppm` shows the lowest absolute precursor mass error (in ppm) across potential isotopes (0-1 neutron).
537
  * Ensure your input file format is correctly specified. Large files may take time to process.
 
542
  with gr.Accordion("Application Logs", open=True):
543
  log_display = Log(log_file, dark=True, height=300)
544
 
545
+ gr.Textbox(
546
+ value="""
547
+ @article{eloff_kalogeropoulos_2025_instanovo,
548
+ title = {InstaNovo enables diffusion-powered de novo peptide sequencing in large-scale proteomics experiments},
549
+ author = {Kevin Eloff and Konstantinos Kalogeropoulos and Amandla Mabona and Oliver Morell and Rachel Catzel and
550
+ Esperanza Rivera-de-Torre and Jakob Berg Jespersen and Wesley Williams and Sam P. B. van Beljouw and
551
+ Marcin J. Skwark and Andreas Hougaard Laustsen and Stan J. J. Brouns and Anne Ljungars and Erwin M.
552
+ Schoof and Jeroen Van Goey and Ulrich auf dem Keller and Karim Beguir and Nicolas Lopez Carranza and
553
+ Timothy P. Jenkins},
554
+ year = 2025,
555
+ month = {Mar},
556
+ day = 31,
557
+ journal = {Nature Machine Intelligence},
558
+ doi = {10.1038/s42256-025-01019-5},
559
+ url = {https://www.nature.com/articles/s42256-025-01019-5}
560
+ }
561
+ """,
562
+ show_copy_button=True,
563
+ label="If you use InstaNovo in your research, please cite:",
564
+ interactive=False,
565
+ )
566
+
567
  # --- Launch the App ---
568
  if __name__ == "__main__":
569
  # Set share=True for temporary public link if running locally
570
  # Set server_name="0.0.0.0" to allow access from network if needed
571
  # demo.launch(server_name="0.0.0.0", server_port=7860)
572
  # For Hugging Face Spaces, just demo.launch() is usually sufficient
573
+ demo.launch(share=True) # For local testing with public URL