Spaces:
Running
on
Zero
Running
on
Zero
Merge branch 'citation'
Browse files
app.py
CHANGED
|
@@ -28,13 +28,15 @@ except ImportError as e:
|
|
| 28 |
raise ImportError(f"Failed to import InstaNovo components: {e}")
|
| 29 |
|
| 30 |
# --- Configuration ---
|
| 31 |
-
MODEL_ID = "instanovo-v1.1.0"
|
| 32 |
KNAPSACK_DIR = Path("./knapsack_cache")
|
| 33 |
-
DEFAULT_CONFIG_PATH = Path(
|
|
|
|
|
|
|
| 34 |
|
| 35 |
# Determine device
|
| 36 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 37 |
-
FP16 = DEVICE == "cuda"
|
| 38 |
|
| 39 |
# --- Global Variables (Load Model and Knapsack Once) ---
|
| 40 |
MODEL: InstaNovo | None = None
|
|
@@ -78,9 +80,9 @@ def load_model_and_knapsack():
|
|
| 78 |
|
| 79 |
# --- Knapsack Handling ---
|
| 80 |
knapsack_exists = (
|
| 81 |
-
(KNAPSACK_DIR / "parameters.pkl").exists()
|
| 82 |
-
(KNAPSACK_DIR / "masses.npy").exists()
|
| 83 |
-
(KNAPSACK_DIR / "chart.npy").exists()
|
| 84 |
)
|
| 85 |
|
| 86 |
if knapsack_exists:
|
|
@@ -96,11 +98,15 @@ def load_model_and_knapsack():
|
|
| 96 |
if not knapsack_exists:
|
| 97 |
logger.info("Knapsack not found or failed to load. Generating knapsack...")
|
| 98 |
if RESIDUE_SET is None:
|
| 99 |
-
|
|
|
|
|
|
|
| 100 |
try:
|
| 101 |
# Prepare residue masses for knapsack generation (handle negative/zero masses)
|
| 102 |
residue_masses_knapsack = dict(RESIDUE_SET.residue_masses.copy())
|
| 103 |
-
negative_residues = [
|
|
|
|
|
|
|
| 104 |
if negative_residues:
|
| 105 |
logger.info(f"Warning: Non-positive masses found in residues: {negative_residues}. "
|
| 106 |
"Excluding from knapsack generation.")
|
|
@@ -108,19 +114,19 @@ def load_model_and_knapsack():
|
|
| 108 |
del residue_masses_knapsack[res]
|
| 109 |
# Remove special tokens explicitly if they somehow got mass
|
| 110 |
for special_token in RESIDUE_SET.special_tokens:
|
| 111 |
-
|
| 112 |
-
|
| 113 |
|
| 114 |
# Ensure residue indices used match those without special/negative masses
|
| 115 |
valid_residue_indices = {
|
| 116 |
-
res: idx
|
|
|
|
| 117 |
if res in residue_masses_knapsack
|
| 118 |
}
|
| 119 |
|
| 120 |
-
|
| 121 |
KNAPSACK = Knapsack.construct_knapsack(
|
| 122 |
residue_masses=residue_masses_knapsack,
|
| 123 |
-
residue_indices=valid_residue_indices,
|
| 124 |
max_mass=MAX_MASS,
|
| 125 |
mass_scale=MASS_SCALE,
|
| 126 |
)
|
|
@@ -135,6 +141,7 @@ def load_model_and_knapsack():
|
|
| 135 |
# Load the model and knapsack when the script starts
|
| 136 |
load_model_and_knapsack()
|
| 137 |
|
|
|
|
| 138 |
def create_inference_config(
|
| 139 |
input_path: str,
|
| 140 |
output_path: str,
|
|
@@ -143,7 +150,7 @@ def create_inference_config(
|
|
| 143 |
"""Creates the OmegaConf DictConfig needed for prediction."""
|
| 144 |
# Load default config if available, otherwise create from scratch
|
| 145 |
if DEFAULT_CONFIG_PATH.exists():
|
| 146 |
-
|
| 147 |
else:
|
| 148 |
logger.info(f"Warning: Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config.")
|
| 149 |
# Create a minimal config if default is missing
|
|
@@ -206,7 +213,9 @@ def create_inference_config(
|
|
| 206 |
cfg_overrides["use_knapsack"] = False
|
| 207 |
elif "Knapsack" in decoding_method:
|
| 208 |
if KNAPSACK is None:
|
| 209 |
-
raise gr.Error(
|
|
|
|
|
|
|
| 210 |
cfg_overrides["num_beams"] = 5
|
| 211 |
cfg_overrides["use_knapsack"] = True
|
| 212 |
cfg_overrides["knapsack_path"] = str(KNAPSACK_DIR)
|
|
@@ -223,9 +232,9 @@ def predict_peptides(input_file, decoding_method):
|
|
| 223 |
Main function to load data, run prediction, and return results.
|
| 224 |
"""
|
| 225 |
if MODEL is None or RESIDUE_SET is None or MODEL_CONFIG is None:
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
|
| 230 |
if input_file is None:
|
| 231 |
raise gr.Error("Please upload a mass spectrometry file.")
|
|
@@ -248,17 +257,18 @@ def predict_peptides(input_file, decoding_method):
|
|
| 248 |
try:
|
| 249 |
sdf = SpectrumDataFrame.load(
|
| 250 |
config.data_path,
|
| 251 |
-
lazy=False,
|
| 252 |
-
is_annotated=False,
|
| 253 |
column_mapping=config.get("column_map", None),
|
| 254 |
shuffle=False,
|
| 255 |
-
verbose=True
|
| 256 |
)
|
| 257 |
# Apply charge filter like in CLI
|
| 258 |
original_size = len(sdf)
|
| 259 |
max_charge = config.get("max_charge", 10)
|
| 260 |
sdf.filter_rows(
|
| 261 |
-
lambda row: (row["precursor_charge"] <= max_charge)
|
|
|
|
| 262 |
)
|
| 263 |
if len(sdf) < original_size:
|
| 264 |
logger.info(f"Warning: Filtered {original_size - len(sdf)} spectra with charge > {max_charge} or <= 0.")
|
|
@@ -275,16 +285,17 @@ def predict_peptides(input_file, decoding_method):
|
|
| 275 |
sdf,
|
| 276 |
RESIDUE_SET,
|
| 277 |
MODEL_CONFIG.get("n_peaks", 200),
|
| 278 |
-
return_str=True,
|
| 279 |
annotated=False,
|
| 280 |
-
pad_spectrum_max_length=config.get("compile_model", False)
|
|
|
|
| 281 |
bin_spectra=config.get("conv_peak_encoder", False),
|
| 282 |
)
|
| 283 |
dl = DataLoader(
|
| 284 |
ds,
|
| 285 |
batch_size=config.batch_size,
|
| 286 |
-
num_workers=0,
|
| 287 |
-
shuffle=False,
|
| 288 |
collate_fn=collate_batch,
|
| 289 |
)
|
| 290 |
|
|
@@ -293,8 +304,10 @@ def predict_peptides(input_file, decoding_method):
|
|
| 293 |
decoder: Decoder
|
| 294 |
if config.use_knapsack:
|
| 295 |
if KNAPSACK is None:
|
| 296 |
-
|
| 297 |
-
|
|
|
|
|
|
|
| 298 |
# KnapsackBeamSearchDecoder doesn't directly load from path in this version?
|
| 299 |
# We load Knapsack globally, so just pass it.
|
| 300 |
# If it needed path: decoder = KnapsackBeamSearchDecoder.from_file(model=MODEL, path=config.knapsack_path)
|
|
@@ -316,15 +329,22 @@ def predict_peptides(input_file, decoding_method):
|
|
| 316 |
# 5. Run Prediction Loop (Adapted from instanovo/transformer/predict.py)
|
| 317 |
logger.info("Starting prediction...")
|
| 318 |
start_time = time.time()
|
| 319 |
-
results_list: list[
|
|
|
|
|
|
|
| 320 |
|
| 321 |
for i, batch in enumerate(dl):
|
| 322 |
-
spectra, precursors, spectra_mask, _, _ =
|
|
|
|
|
|
|
| 323 |
spectra = spectra.to(DEVICE)
|
| 324 |
precursors = precursors.to(DEVICE)
|
| 325 |
spectra_mask = spectra_mask.to(DEVICE)
|
| 326 |
|
| 327 |
-
with
|
|
|
|
|
|
|
|
|
|
| 328 |
# Beam search decoder might return list[list[ScoredSequence]] if return_beam=True
|
| 329 |
# Greedy decoder returns list[ScoredSequence]
|
| 330 |
# KnapsackBeamSearchDecoder returns list[ScoredSequence] or list[list[ScoredSequence]]
|
|
@@ -334,9 +354,12 @@ def predict_peptides(input_file, decoding_method):
|
|
| 334 |
beam_size=config.num_beams,
|
| 335 |
max_length=config.max_length,
|
| 336 |
# Knapsack/Beam Search specific params if needed
|
| 337 |
-
mass_tolerance=config.get("filter_precursor_ppm", 20)
|
| 338 |
-
|
| 339 |
-
|
|
|
|
|
|
|
|
|
|
| 340 |
)
|
| 341 |
results_list.extend(batch_predictions) # Should be list[ScoredSequence] or list[list]
|
| 342 |
logger.info(f"Processed batch {i+1}/{len(dl)}")
|
|
@@ -349,26 +372,30 @@ def predict_peptides(input_file, decoding_method):
|
|
| 349 |
output_data = []
|
| 350 |
# Use sdf index columns + prediction results
|
| 351 |
index_cols = [col for col in config.index_columns if col in sdf.df.columns]
|
| 352 |
-
base_df_pd = sdf.df.select(index_cols).to_pandas()
|
| 353 |
|
| 354 |
metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
|
| 355 |
|
| 356 |
for i, res in enumerate(results_list):
|
| 357 |
-
row_data = base_df_pd.iloc[i].to_dict()
|
| 358 |
if isinstance(res, ScoredSequence) and res.sequence:
|
| 359 |
sequence_str = "".join(res.sequence)
|
| 360 |
row_data["prediction"] = sequence_str
|
| 361 |
row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
|
| 362 |
# Use metrics to calculate delta mass ppm for the top prediction
|
| 363 |
try:
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
except Exception as e:
|
| 373 |
logger.info(f"Warning: Could not calculate delta mass for prediction {i}: {e}")
|
| 374 |
row_data["delta_mass_ppm"] = "N/A"
|
|
@@ -382,7 +409,14 @@ def predict_peptides(input_file, decoding_method):
|
|
| 382 |
output_df = pl.DataFrame(output_data)
|
| 383 |
|
| 384 |
# Ensure specific columns are present and ordered
|
| 385 |
-
display_cols = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
final_display_cols = []
|
| 387 |
for col in display_cols:
|
| 388 |
if col in output_df.columns:
|
|
@@ -397,7 +431,6 @@ def predict_peptides(input_file, decoding_method):
|
|
| 397 |
|
| 398 |
output_df_display = output_df.select(final_display_cols)
|
| 399 |
|
| 400 |
-
|
| 401 |
# 7. Save full results to CSV
|
| 402 |
logger.info(f"Saving results to {output_csv_path}...")
|
| 403 |
output_df.write_csv(output_csv_path)
|
|
@@ -413,6 +446,7 @@ def predict_peptides(input_file, decoding_method):
|
|
| 413 |
# Re-raise as Gradio error
|
| 414 |
raise gr.Error(f"Prediction failed: {e}")
|
| 415 |
|
|
|
|
| 416 |
# --- Gradio Interface ---
|
| 417 |
css = """
|
| 418 |
.gradio-container { font-family: sans-serif; }
|
|
@@ -422,7 +456,9 @@ footer { display: none !important; }
|
|
| 422 |
.logo-container img { margin-bottom: 1rem; }
|
| 423 |
"""
|
| 424 |
|
| 425 |
-
with gr.Blocks(
|
|
|
|
|
|
|
| 426 |
# --- Logo Display ---
|
| 427 |
gr.Markdown(
|
| 428 |
"""
|
|
@@ -430,7 +466,7 @@ with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hu
|
|
| 430 |
<img src='/gradio_api/file=assets/instanovo.svg' alt="InstaNovo Logo" width="300" style="display: block; margin: 0 auto;">
|
| 431 |
</div>
|
| 432 |
""",
|
| 433 |
-
elem_classes="logo-container"
|
| 434 |
)
|
| 435 |
|
| 436 |
# --- App Content ---
|
|
@@ -445,38 +481,57 @@ with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hu
|
|
| 445 |
with gr.Column(scale=1):
|
| 446 |
input_file = gr.File(
|
| 447 |
label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
|
| 448 |
-
file_types=[".mgf", ".mzml", ".mzxml"]
|
| 449 |
)
|
| 450 |
decoding_method = gr.Radio(
|
| 451 |
-
[
|
|
|
|
|
|
|
|
|
|
| 452 |
label="Decoding Method",
|
| 453 |
-
value="Greedy Search (Fast, resonably accurate)"
|
| 454 |
)
|
| 455 |
submit_btn = gr.Button("Predict Sequences", variant="primary")
|
| 456 |
with gr.Column(scale=2):
|
| 457 |
-
output_df = gr.DataFrame(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
output_file = gr.File(label="Download Full Results (CSV)")
|
| 459 |
|
| 460 |
submit_btn.click(
|
| 461 |
predict_peptides,
|
| 462 |
inputs=[input_file, decoding_method],
|
| 463 |
-
outputs=[output_df, output_file]
|
| 464 |
)
|
| 465 |
|
| 466 |
gr.Examples(
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
)
|
| 475 |
|
| 476 |
gr.Markdown(
|
| 477 |
-
|
| 478 |
**Notes:**
|
| 479 |
-
* Predictions are based on the [InstaNovo](https://github.com/instadeepai/InstaNovo) model
|
| 480 |
* Knapsack Beam Search uses pre-calculated mass constraints and yields better results but takes longer.
|
| 481 |
* `delta_mass_ppm` shows the lowest absolute precursor mass error (in ppm) across potential isotopes (0-1 neutron).
|
| 482 |
* Ensure your input file format is correctly specified. Large files may take time to process.
|
|
@@ -487,10 +542,32 @@ with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hu
|
|
| 487 |
with gr.Accordion("Application Logs", open=True):
|
| 488 |
log_display = Log(log_file, dark=True, height=300)
|
| 489 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
# --- Launch the App ---
|
| 491 |
if __name__ == "__main__":
|
| 492 |
# Set share=True for temporary public link if running locally
|
| 493 |
# Set server_name="0.0.0.0" to allow access from network if needed
|
| 494 |
# demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 495 |
# For Hugging Face Spaces, just demo.launch() is usually sufficient
|
| 496 |
-
demo.launch(share=True)
|
|
|
|
| 28 |
raise ImportError(f"Failed to import InstaNovo components: {e}")
|
| 29 |
|
| 30 |
# --- Configuration ---
|
| 31 |
+
MODEL_ID = "instanovo-v1.1.0" # Use the desired pretrained model ID
|
| 32 |
KNAPSACK_DIR = Path("./knapsack_cache")
|
| 33 |
+
DEFAULT_CONFIG_PATH = Path(
|
| 34 |
+
"./configs/inference/default.yaml"
|
| 35 |
+
) # Assuming instanovo installs configs locally relative to execution
|
| 36 |
|
| 37 |
# Determine device
|
| 38 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 39 |
+
FP16 = DEVICE == "cuda" # Enable FP16 only on CUDA
|
| 40 |
|
| 41 |
# --- Global Variables (Load Model and Knapsack Once) ---
|
| 42 |
MODEL: InstaNovo | None = None
|
|
|
|
| 80 |
|
| 81 |
# --- Knapsack Handling ---
|
| 82 |
knapsack_exists = (
|
| 83 |
+
(KNAPSACK_DIR / "parameters.pkl").exists()
|
| 84 |
+
and (KNAPSACK_DIR / "masses.npy").exists()
|
| 85 |
+
and (KNAPSACK_DIR / "chart.npy").exists()
|
| 86 |
)
|
| 87 |
|
| 88 |
if knapsack_exists:
|
|
|
|
| 98 |
if not knapsack_exists:
|
| 99 |
logger.info("Knapsack not found or failed to load. Generating knapsack...")
|
| 100 |
if RESIDUE_SET is None:
|
| 101 |
+
raise gr.Error(
|
| 102 |
+
"Cannot generate knapsack because ResidueSet failed to load."
|
| 103 |
+
)
|
| 104 |
try:
|
| 105 |
# Prepare residue masses for knapsack generation (handle negative/zero masses)
|
| 106 |
residue_masses_knapsack = dict(RESIDUE_SET.residue_masses.copy())
|
| 107 |
+
negative_residues = [
|
| 108 |
+
k for k, v in residue_masses_knapsack.items() if v <= 0
|
| 109 |
+
]
|
| 110 |
if negative_residues:
|
| 111 |
logger.info(f"Warning: Non-positive masses found in residues: {negative_residues}. "
|
| 112 |
"Excluding from knapsack generation.")
|
|
|
|
| 114 |
del residue_masses_knapsack[res]
|
| 115 |
# Remove special tokens explicitly if they somehow got mass
|
| 116 |
for special_token in RESIDUE_SET.special_tokens:
|
| 117 |
+
if special_token in residue_masses_knapsack:
|
| 118 |
+
del residue_masses_knapsack[special_token]
|
| 119 |
|
| 120 |
# Ensure residue indices used match those without special/negative masses
|
| 121 |
valid_residue_indices = {
|
| 122 |
+
res: idx
|
| 123 |
+
for res, idx in RESIDUE_SET.residue_to_index.items()
|
| 124 |
if res in residue_masses_knapsack
|
| 125 |
}
|
| 126 |
|
|
|
|
| 127 |
KNAPSACK = Knapsack.construct_knapsack(
|
| 128 |
residue_masses=residue_masses_knapsack,
|
| 129 |
+
residue_indices=valid_residue_indices, # Use only valid indices
|
| 130 |
max_mass=MAX_MASS,
|
| 131 |
mass_scale=MASS_SCALE,
|
| 132 |
)
|
|
|
|
| 141 |
# Load the model and knapsack when the script starts
|
| 142 |
load_model_and_knapsack()
|
| 143 |
|
| 144 |
+
|
| 145 |
def create_inference_config(
|
| 146 |
input_path: str,
|
| 147 |
output_path: str,
|
|
|
|
| 150 |
"""Creates the OmegaConf DictConfig needed for prediction."""
|
| 151 |
# Load default config if available, otherwise create from scratch
|
| 152 |
if DEFAULT_CONFIG_PATH.exists():
|
| 153 |
+
base_cfg = OmegaConf.load(DEFAULT_CONFIG_PATH)
|
| 154 |
else:
|
| 155 |
logger.info(f"Warning: Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config.")
|
| 156 |
# Create a minimal config if default is missing
|
|
|
|
| 213 |
cfg_overrides["use_knapsack"] = False
|
| 214 |
elif "Knapsack" in decoding_method:
|
| 215 |
if KNAPSACK is None:
|
| 216 |
+
raise gr.Error(
|
| 217 |
+
"Knapsack is not available. Cannot use Knapsack Beam Search."
|
| 218 |
+
)
|
| 219 |
cfg_overrides["num_beams"] = 5
|
| 220 |
cfg_overrides["use_knapsack"] = True
|
| 221 |
cfg_overrides["knapsack_path"] = str(KNAPSACK_DIR)
|
|
|
|
| 232 |
Main function to load data, run prediction, and return results.
|
| 233 |
"""
|
| 234 |
if MODEL is None or RESIDUE_SET is None or MODEL_CONFIG is None:
|
| 235 |
+
load_model_and_knapsack() # Attempt to reload if None (e.g., after space restart)
|
| 236 |
+
if MODEL is None:
|
| 237 |
+
raise gr.Error("InstaNovo model is not loaded. Cannot perform prediction.")
|
| 238 |
|
| 239 |
if input_file is None:
|
| 240 |
raise gr.Error("Please upload a mass spectrometry file.")
|
|
|
|
| 257 |
try:
|
| 258 |
sdf = SpectrumDataFrame.load(
|
| 259 |
config.data_path,
|
| 260 |
+
lazy=False, # Load eagerly for Gradio simplicity
|
| 261 |
+
is_annotated=False, # De novo mode
|
| 262 |
column_mapping=config.get("column_map", None),
|
| 263 |
shuffle=False,
|
| 264 |
+
verbose=True, # Print loading logs
|
| 265 |
)
|
| 266 |
# Apply charge filter like in CLI
|
| 267 |
original_size = len(sdf)
|
| 268 |
max_charge = config.get("max_charge", 10)
|
| 269 |
sdf.filter_rows(
|
| 270 |
+
lambda row: (row["precursor_charge"] <= max_charge)
|
| 271 |
+
and (row["precursor_charge"] > 0)
|
| 272 |
)
|
| 273 |
if len(sdf) < original_size:
|
| 274 |
logger.info(f"Warning: Filtered {original_size - len(sdf)} spectra with charge > {max_charge} or <= 0.")
|
|
|
|
| 285 |
sdf,
|
| 286 |
RESIDUE_SET,
|
| 287 |
MODEL_CONFIG.get("n_peaks", 200),
|
| 288 |
+
return_str=True, # Needed for greedy/beam search targets later (though not used here)
|
| 289 |
annotated=False,
|
| 290 |
+
pad_spectrum_max_length=config.get("compile_model", False)
|
| 291 |
+
or config.get("use_flash_attention", False),
|
| 292 |
bin_spectra=config.get("conv_peak_encoder", False),
|
| 293 |
)
|
| 294 |
dl = DataLoader(
|
| 295 |
ds,
|
| 296 |
batch_size=config.batch_size,
|
| 297 |
+
num_workers=0, # Required by SpectrumDataFrame
|
| 298 |
+
shuffle=False, # Required by SpectrumDataFrame
|
| 299 |
collate_fn=collate_batch,
|
| 300 |
)
|
| 301 |
|
|
|
|
| 304 |
decoder: Decoder
|
| 305 |
if config.use_knapsack:
|
| 306 |
if KNAPSACK is None:
|
| 307 |
+
# This check should ideally be earlier, but double-check
|
| 308 |
+
raise gr.Error(
|
| 309 |
+
"Knapsack is required for Knapsack Beam Search but is not available."
|
| 310 |
+
)
|
| 311 |
# KnapsackBeamSearchDecoder doesn't directly load from path in this version?
|
| 312 |
# We load Knapsack globally, so just pass it.
|
| 313 |
# If it needed path: decoder = KnapsackBeamSearchDecoder.from_file(model=MODEL, path=config.knapsack_path)
|
|
|
|
| 329 |
# 5. Run Prediction Loop (Adapted from instanovo/transformer/predict.py)
|
| 330 |
logger.info("Starting prediction...")
|
| 331 |
start_time = time.time()
|
| 332 |
+
results_list: list[
|
| 333 |
+
ScoredSequence | list
|
| 334 |
+
] = [] # Store ScoredSequence or empty list
|
| 335 |
|
| 336 |
for i, batch in enumerate(dl):
|
| 337 |
+
spectra, precursors, spectra_mask, _, _ = (
|
| 338 |
+
batch # Ignore peptides/masks for de novo
|
| 339 |
+
)
|
| 340 |
spectra = spectra.to(DEVICE)
|
| 341 |
precursors = precursors.to(DEVICE)
|
| 342 |
spectra_mask = spectra_mask.to(DEVICE)
|
| 343 |
|
| 344 |
+
with (
|
| 345 |
+
torch.no_grad(),
|
| 346 |
+
torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16),
|
| 347 |
+
):
|
| 348 |
# Beam search decoder might return list[list[ScoredSequence]] if return_beam=True
|
| 349 |
# Greedy decoder returns list[ScoredSequence]
|
| 350 |
# KnapsackBeamSearchDecoder returns list[ScoredSequence] or list[list[ScoredSequence]]
|
|
|
|
| 354 |
beam_size=config.num_beams,
|
| 355 |
max_length=config.max_length,
|
| 356 |
# Knapsack/Beam Search specific params if needed
|
| 357 |
+
mass_tolerance=config.get("filter_precursor_ppm", 20)
|
| 358 |
+
* 1e-6, # Convert ppm to relative
|
| 359 |
+
max_isotope=config.isotope_error_range[1]
|
| 360 |
+
if config.isotope_error_range
|
| 361 |
+
else 1,
|
| 362 |
+
return_beam=False, # Only get the top prediction for simplicity
|
| 363 |
)
|
| 364 |
results_list.extend(batch_predictions) # Should be list[ScoredSequence] or list[list]
|
| 365 |
logger.info(f"Processed batch {i+1}/{len(dl)}")
|
|
|
|
| 372 |
output_data = []
|
| 373 |
# Use sdf index columns + prediction results
|
| 374 |
index_cols = [col for col in config.index_columns if col in sdf.df.columns]
|
| 375 |
+
base_df_pd = sdf.df.select(index_cols).to_pandas() # Get base info
|
| 376 |
|
| 377 |
metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
|
| 378 |
|
| 379 |
for i, res in enumerate(results_list):
|
| 380 |
+
row_data = base_df_pd.iloc[i].to_dict() # Get corresponding input data
|
| 381 |
if isinstance(res, ScoredSequence) and res.sequence:
|
| 382 |
sequence_str = "".join(res.sequence)
|
| 383 |
row_data["prediction"] = sequence_str
|
| 384 |
row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
|
| 385 |
# Use metrics to calculate delta mass ppm for the top prediction
|
| 386 |
try:
|
| 387 |
+
_, delta_mass_list = metrics_calc.matches_precursor(
|
| 388 |
+
res.sequence,
|
| 389 |
+
row_data["precursor_mz"],
|
| 390 |
+
row_data["precursor_charge"],
|
| 391 |
+
)
|
| 392 |
+
# Find the smallest absolute ppm error across isotopes
|
| 393 |
+
min_abs_ppm = (
|
| 394 |
+
min(abs(p) for p in delta_mass_list)
|
| 395 |
+
if delta_mass_list
|
| 396 |
+
else float("nan")
|
| 397 |
+
)
|
| 398 |
+
row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}"
|
| 399 |
except Exception as e:
|
| 400 |
logger.info(f"Warning: Could not calculate delta mass for prediction {i}: {e}")
|
| 401 |
row_data["delta_mass_ppm"] = "N/A"
|
|
|
|
| 409 |
output_df = pl.DataFrame(output_data)
|
| 410 |
|
| 411 |
# Ensure specific columns are present and ordered
|
| 412 |
+
display_cols = [
|
| 413 |
+
"scan_number",
|
| 414 |
+
"precursor_mz",
|
| 415 |
+
"precursor_charge",
|
| 416 |
+
"prediction",
|
| 417 |
+
"log_probability",
|
| 418 |
+
"delta_mass_ppm",
|
| 419 |
+
]
|
| 420 |
final_display_cols = []
|
| 421 |
for col in display_cols:
|
| 422 |
if col in output_df.columns:
|
|
|
|
| 431 |
|
| 432 |
output_df_display = output_df.select(final_display_cols)
|
| 433 |
|
|
|
|
| 434 |
# 7. Save full results to CSV
|
| 435 |
logger.info(f"Saving results to {output_csv_path}...")
|
| 436 |
output_df.write_csv(output_csv_path)
|
|
|
|
| 446 |
# Re-raise as Gradio error
|
| 447 |
raise gr.Error(f"Prediction failed: {e}")
|
| 448 |
|
| 449 |
+
|
| 450 |
# --- Gradio Interface ---
|
| 451 |
css = """
|
| 452 |
.gradio-container { font-family: sans-serif; }
|
|
|
|
| 456 |
.logo-container img { margin-bottom: 1rem; }
|
| 457 |
"""
|
| 458 |
|
| 459 |
+
with gr.Blocks(
|
| 460 |
+
css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
|
| 461 |
+
) as demo:
|
| 462 |
# --- Logo Display ---
|
| 463 |
gr.Markdown(
|
| 464 |
"""
|
|
|
|
| 466 |
<img src='/gradio_api/file=assets/instanovo.svg' alt="InstaNovo Logo" width="300" style="display: block; margin: 0 auto;">
|
| 467 |
</div>
|
| 468 |
""",
|
| 469 |
+
elem_classes="logo-container", # Optional class for CSS targeting
|
| 470 |
)
|
| 471 |
|
| 472 |
# --- App Content ---
|
|
|
|
| 481 |
with gr.Column(scale=1):
|
| 482 |
input_file = gr.File(
|
| 483 |
label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
|
| 484 |
+
file_types=[".mgf", ".mzml", ".mzxml"],
|
| 485 |
)
|
| 486 |
decoding_method = gr.Radio(
|
| 487 |
+
[
|
| 488 |
+
"Greedy Search (Fast, resonably accurate)",
|
| 489 |
+
"Knapsack Beam Search (More accurate, but slower)",
|
| 490 |
+
],
|
| 491 |
label="Decoding Method",
|
| 492 |
+
value="Greedy Search (Fast, resonably accurate)", # Default to fast method
|
| 493 |
)
|
| 494 |
submit_btn = gr.Button("Predict Sequences", variant="primary")
|
| 495 |
with gr.Column(scale=2):
|
| 496 |
+
output_df = gr.DataFrame(
|
| 497 |
+
label="Prediction Results",
|
| 498 |
+
headers=[
|
| 499 |
+
"scan_number",
|
| 500 |
+
"precursor_mz",
|
| 501 |
+
"precursor_charge",
|
| 502 |
+
"prediction",
|
| 503 |
+
"log_probability",
|
| 504 |
+
"delta_mass_ppm",
|
| 505 |
+
],
|
| 506 |
+
wrap=True,
|
| 507 |
+
)
|
| 508 |
output_file = gr.File(label="Download Full Results (CSV)")
|
| 509 |
|
| 510 |
submit_btn.click(
|
| 511 |
predict_peptides,
|
| 512 |
inputs=[input_file, decoding_method],
|
| 513 |
+
outputs=[output_df, output_file],
|
| 514 |
)
|
| 515 |
|
| 516 |
gr.Examples(
|
| 517 |
+
[
|
| 518 |
+
["assets/sample_spectra.mgf", "Greedy Search (Fast, resonably accurate)"],
|
| 519 |
+
[
|
| 520 |
+
"assets/sample_spectra.mgf",
|
| 521 |
+
"Knapsack Beam Search (More accurate, but slower)",
|
| 522 |
+
],
|
| 523 |
+
],
|
| 524 |
+
inputs=[input_file, decoding_method],
|
| 525 |
+
outputs=[output_df, output_file],
|
| 526 |
+
fn=predict_peptides,
|
| 527 |
+
cache_examples=False, # Re-run examples if needed
|
| 528 |
+
label="Example Usage",
|
| 529 |
)
|
| 530 |
|
| 531 |
gr.Markdown(
|
| 532 |
+
"""
|
| 533 |
**Notes:**
|
| 534 |
+
* Predictions are based on the [InstaNovo](https://github.com/instadeepai/InstaNovo) model `{MODEL_ID}`.
|
| 535 |
* Knapsack Beam Search uses pre-calculated mass constraints and yields better results but takes longer.
|
| 536 |
* `delta_mass_ppm` shows the lowest absolute precursor mass error (in ppm) across potential isotopes (0-1 neutron).
|
| 537 |
* Ensure your input file format is correctly specified. Large files may take time to process.
|
|
|
|
| 542 |
with gr.Accordion("Application Logs", open=True):
|
| 543 |
log_display = Log(log_file, dark=True, height=300)
|
| 544 |
|
| 545 |
+
gr.Textbox(
|
| 546 |
+
value="""
|
| 547 |
+
@article{eloff_kalogeropoulos_2025_instanovo,
|
| 548 |
+
title = {InstaNovo enables diffusion-powered de novo peptide sequencing in large-scale proteomics experiments},
|
| 549 |
+
author = {Kevin Eloff and Konstantinos Kalogeropoulos and Amandla Mabona and Oliver Morell and Rachel Catzel and
|
| 550 |
+
Esperanza Rivera-de-Torre and Jakob Berg Jespersen and Wesley Williams and Sam P. B. van Beljouw and
|
| 551 |
+
Marcin J. Skwark and Andreas Hougaard Laustsen and Stan J. J. Brouns and Anne Ljungars and Erwin M.
|
| 552 |
+
Schoof and Jeroen Van Goey and Ulrich auf dem Keller and Karim Beguir and Nicolas Lopez Carranza and
|
| 553 |
+
Timothy P. Jenkins},
|
| 554 |
+
year = 2025,
|
| 555 |
+
month = {Mar},
|
| 556 |
+
day = 31,
|
| 557 |
+
journal = {Nature Machine Intelligence},
|
| 558 |
+
doi = {10.1038/s42256-025-01019-5},
|
| 559 |
+
url = {https://www.nature.com/articles/s42256-025-01019-5}
|
| 560 |
+
}
|
| 561 |
+
""",
|
| 562 |
+
show_copy_button=True,
|
| 563 |
+
label="If you use InstaNovo in your research, please cite:",
|
| 564 |
+
interactive=False,
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
# --- Launch the App ---
|
| 568 |
if __name__ == "__main__":
|
| 569 |
# Set share=True for temporary public link if running locally
|
| 570 |
# Set server_name="0.0.0.0" to allow access from network if needed
|
| 571 |
# demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 572 |
# For Hugging Face Spaces, just demo.launch() is usually sufficient
|
| 573 |
+
demo.launch(share=True) # For local testing with public URL
|