Spaces: Running on Zero
feat: add citation
app.py CHANGED
@@ -26,13 +26,15 @@ except ImportError as e:
     raise ImportError("Failed to import InstaNovo components: {e}")
 
 # --- Configuration ---
-MODEL_ID = "instanovo-v1.1.0"
+MODEL_ID = "instanovo-v1.1.0"  # Use the desired pretrained model ID
 KNAPSACK_DIR = Path("./knapsack_cache")
-DEFAULT_CONFIG_PATH = Path(
+DEFAULT_CONFIG_PATH = Path(
+    "./configs/inference/default.yaml"
+)  # Assuming instanovo installs configs locally relative to execution
 
 # Determine device
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-FP16 = DEVICE == "cuda"
+FP16 = DEVICE == "cuda"  # Enable FP16 only on CUDA
 
 # --- Global Variables (Load Model and Knapsack Once) ---
 MODEL: InstaNovo | None = None
@@ -41,7 +43,7 @@ MODEL_CONFIG: DictConfig | None = None
 RESIDUE_SET: ResidueSet | None = None
 
 # Assets
-gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
+gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
 
 
 def load_model_and_knapsack():
@@ -64,9 +66,9 @@ def load_model_and_knapsack():
 
     # --- Knapsack Handling ---
     knapsack_exists = (
-        (KNAPSACK_DIR / "parameters.pkl").exists()
-        (KNAPSACK_DIR / "masses.npy").exists()
-        (KNAPSACK_DIR / "chart.npy").exists()
+        (KNAPSACK_DIR / "parameters.pkl").exists()
+        and (KNAPSACK_DIR / "masses.npy").exists()
+        and (KNAPSACK_DIR / "chart.npy").exists()
     )
 
     if knapsack_exists:
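Note: the updated existence check joins the three cache-file checks with `and`. A minimal standalone sketch of what the fixed expression computes (pathlib only; the `knapsack_cache_complete` helper name is ours, not part of the app):

```python
from pathlib import Path

KNAPSACK_DIR = Path("./knapsack_cache")
REQUIRED_FILES = ("parameters.pkl", "masses.npy", "chart.npy")


def knapsack_cache_complete(cache_dir: Path = KNAPSACK_DIR) -> bool:
    """True only when every pre-computed knapsack artifact is on disk."""
    return all((cache_dir / name).exists() for name in REQUIRED_FILES)
```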
@@ -76,51 +78,61 @@ def load_model_and_knapsack():
             print("Knapsack loaded successfully.")
         except Exception as e:
             print(f"Error loading knapsack: {e}. Will attempt to regenerate.")
-            KNAPSACK = None
-            knapsack_exists = False
+            KNAPSACK = None  # Force regeneration
+            knapsack_exists = False  # Ensure generation happens
 
     if not knapsack_exists:
         print("Knapsack not found or failed to load. Generating knapsack...")
         if RESIDUE_SET is None:
+            raise gr.Error(
+                "Cannot generate knapsack because ResidueSet failed to load."
+            )
         try:
             # Prepare residue masses for knapsack generation (handle negative/zero masses)
            residue_masses_knapsack = dict(RESIDUE_SET.residue_masses.copy())
+            negative_residues = [
+                k for k, v in residue_masses_knapsack.items() if v <= 0
+            ]
             if negative_residues:
+                print(
+                    f"Warning: Non-positive masses found in residues: {negative_residues}. "
+                    "Excluding from knapsack generation."
+                )
                 for res in negative_residues:
                     del residue_masses_knapsack[res]
             # Remove special tokens explicitly if they somehow got mass
             for special_token in RESIDUE_SET.special_tokens:
+                if special_token in residue_masses_knapsack:
+                    del residue_masses_knapsack[special_token]
 
             # Ensure residue indices used match those without special/negative masses
             valid_residue_indices = {
+                res: idx
+                for res, idx in RESIDUE_SET.residue_to_index.items()
                 if res in residue_masses_knapsack
             }
 
             KNAPSACK = Knapsack.construct_knapsack(
                 residue_masses=residue_masses_knapsack,
-                residue_indices=valid_residue_indices,
+                residue_indices=valid_residue_indices,  # Use only valid indices
                 max_mass=MAX_MASS,
                 mass_scale=MASS_SCALE,
             )
             print(f"Knapsack generated. Saving to {KNAPSACK_DIR}...")
-            KNAPSACK.save(str(KNAPSACK_DIR))
+            KNAPSACK.save(str(KNAPSACK_DIR))  # Save for future runs
             print("Knapsack saved.")
         except Exception as e:
             print(f"Error generating or saving knapsack: {e}")
+            gr.Warning(
+                "Failed to generate Knapsack. Knapsack Beam Search will not be available. {e}"
+            )
+            KNAPSACK = None  # Ensure it's None if generation failed
+
 
 # Load the model and knapsack when the script starts
 load_model_and_knapsack()
 
+
 def create_inference_config(
     input_path: str,
     output_path: str,
@@ -129,53 +141,72 @@ def create_inference_config(
     """Creates the OmegaConf DictConfig needed for prediction."""
     # Load default config if available, otherwise create from scratch
     if DEFAULT_CONFIG_PATH.exists():
+        base_cfg = OmegaConf.load(DEFAULT_CONFIG_PATH)
     else:
+        print(
+            f"Warning: Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config."
+        )
+        # Create a minimal config if default is missing
+        base_cfg = OmegaConf.create(
+            {
+                "data_path": None,
+                "instanovo_model": MODEL_ID,
+                "output_path": None,
+                "knapsack_path": str(KNAPSACK_DIR),
+                "denovo": True,
+                "refine": False,  # Not doing refinement here
+                "num_beams": 1,
+                "max_length": 40,
+                "max_charge": 10,
+                "isotope_error_range": [0, 1],
+                "subset": 1.0,
+                "use_knapsack": False,
+                "save_beams": False,
+                "batch_size": 64,  # Adjust as needed
+                "device": DEVICE,
+                "fp16": FP16,
+                "log_interval": 500,  # Less relevant for Gradio app
+                "use_basic_logging": True,
+                "filter_precursor_ppm": 20,
+                "filter_confidence": 1e-4,
+                "filter_fdr_threshold": 0.05,
+                "residue_remapping": {  # Add default mappings
+                    "M(ox)": "M[UNIMOD:35]",
+                    "M(+15.99)": "M[UNIMOD:35]",
+                    "S(p)": "S[UNIMOD:21]",
+                    "T(p)": "T[UNIMOD:21]",
+                    "Y(p)": "Y[UNIMOD:21]",
+                    "S(+79.97)": "S[UNIMOD:21]",
+                    "T(+79.97)": "T[UNIMOD:21]",
+                    "Y(+79.97)": "Y[UNIMOD:21]",
+                    "Q(+0.98)": "Q[UNIMOD:7]",
+                    "N(+0.98)": "N[UNIMOD:7]",
+                    "Q(+.98)": "Q[UNIMOD:7]",
+                    "N(+.98)": "N[UNIMOD:7]",
+                    "C(+57.02)": "C[UNIMOD:4]",
+                    "(+42.01)": "[UNIMOD:1]",
+                    "(+43.01)": "[UNIMOD:5]",
+                    "(-17.03)": "[UNIMOD:385]",
+                },
+                "column_map": {  # Add default mappings
+                    "Modified sequence": "modified_sequence",
+                    "MS/MS m/z": "precursor_mz",
+                    "Mass": "precursor_mass",
+                    "Charge": "precursor_charge",
+                    "Mass values": "mz_array",
+                    "Mass spectrum": "mz_array",
+                    "Intensity": "intensity_array",
+                    "Raw intensity spectrum": "intensity_array",
+                    "Scan number": "scan_number",
+                },
+                "index_columns": [
+                    "scan_number",
+                    "precursor_mz",
+                    "precursor_charge",
+                ],
+                # Add other defaults if needed based on errors
+            }
+        )
 
     # Override specific parameters
     cfg_overrides = {
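Note: when the packaged YAML is missing, the fallback builds the base configuration directly with OmegaConf. For orientation, a minimal sketch of how such a base config and the per-request overrides can be combined; the merge itself happens outside the hunks shown here, so treat this as an assumption about the surrounding code, with toy values:

```python
from omegaconf import OmegaConf

# Toy stand-ins for base_cfg and cfg_overrides from the hunk above.
base_cfg = OmegaConf.create({"num_beams": 1, "use_knapsack": False})
cfg_overrides = {"num_beams": 5, "use_knapsack": True}

config = OmegaConf.merge(base_cfg, OmegaConf.create(cfg_overrides))
print(config.num_beams)  # 5 -- overrides win over the base config
```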
@@ -192,7 +223,9 @@ def create_inference_config(
         cfg_overrides["use_knapsack"] = False
     elif "Knapsack" in decoding_method:
         if KNAPSACK is None:
-            raise gr.Error(
+            raise gr.Error(
+                "Knapsack is not available. Cannot use Knapsack Beam Search."
+            )
         cfg_overrides["num_beams"] = 5
         cfg_overrides["use_knapsack"] = True
         cfg_overrides["knapsack_path"] = str(KNAPSACK_DIR)
@@ -209,14 +242,14 @@ def predict_peptides(input_file, decoding_method):
     Main function to load data, run prediction, and return results.
     """
     if MODEL is None or RESIDUE_SET is None or MODEL_CONFIG is None:
+        load_model_and_knapsack()  # Attempt to reload if None (e.g., after space restart)
+        if MODEL is None:
+            raise gr.Error("InstaNovo model is not loaded. Cannot perform prediction.")
 
     if input_file is None:
         raise gr.Error("Please upload a mass spectrometry file.")
 
-    input_path = input_file.name
+    input_path = input_file.name  # Gradio provides the path in .name
     print(f"Processing file: {input_path}")
     print(f"Using decoding method: {decoding_method}")
 
@@ -234,23 +267,28 @@ def predict_peptides(input_file, decoding_method):
     try:
         sdf = SpectrumDataFrame.load(
             config.data_path,
-            lazy=False,
-            is_annotated=False,
+            lazy=False,  # Load eagerly for Gradio simplicity
+            is_annotated=False,  # De novo mode
             column_mapping=config.get("column_map", None),
             shuffle=False,
-            verbose=True
+            verbose=True,  # Print loading logs
         )
         # Apply charge filter like in CLI
         original_size = len(sdf)
         max_charge = config.get("max_charge", 10)
         sdf.filter_rows(
+            lambda row: (row["precursor_charge"] <= max_charge)
+            and (row["precursor_charge"] > 0)
         )
         if len(sdf) < original_size:
+            print(
+                f"Warning: Filtered {original_size - len(sdf)} spectra with charge > {max_charge} or <= 0."
+            )
 
         if len(sdf) == 0:
+            raise gr.Error(
+                "No valid spectra found in the uploaded file after filtering."
+            )
         print(f"Data loaded: {len(sdf)} spectra.")
     except Exception as e:
         print(f"Error loading data: {e}")
@@ -261,16 +299,17 @@ def predict_peptides(input_file, decoding_method):
         sdf,
         RESIDUE_SET,
         MODEL_CONFIG.get("n_peaks", 200),
-        return_str=True,
+        return_str=True,  # Needed for greedy/beam search targets later (though not used here)
         annotated=False,
+        pad_spectrum_max_length=config.get("compile_model", False)
+        or config.get("use_flash_attention", False),
         bin_spectra=config.get("conv_peak_encoder", False),
     )
     dl = DataLoader(
         ds,
         batch_size=config.batch_size,
-        num_workers=0,
-        shuffle=False,
+        num_workers=0,  # Required by SpectrumDataFrame
+        shuffle=False,  # Required by SpectrumDataFrame
         collate_fn=collate_batch,
     )
 
@@ -279,38 +318,51 @@ def predict_peptides(input_file, decoding_method):
     decoder: Decoder
     if config.use_knapsack:
         if KNAPSACK is None:
+            # This check should ideally be earlier, but double-check
+            raise gr.Error(
+                "Knapsack is required for Knapsack Beam Search but is not available."
+            )
         # KnapsackBeamSearchDecoder doesn't directly load from path in this version?
         # We load Knapsack globally, so just pass it.
         # If it needed path: decoder = KnapsackBeamSearchDecoder.from_file(model=MODEL, path=config.knapsack_path)
         decoder = KnapsackBeamSearchDecoder(model=MODEL, knapsack=KNAPSACK)
     elif config.num_beams > 1:
+        # BeamSearchDecoder is available but not explicitly requested, use Greedy for num_beams=1
+        print(
+            f"Warning: num_beams={config.num_beams} > 1 but only Greedy and Knapsack Beam Search are implemented in this app. Defaulting to Greedy."
+        )
+        decoder = GreedyDecoder(model=MODEL, mass_scale=MASS_SCALE)
     else:
+        decoder = GreedyDecoder(
+            model=MODEL,
+            mass_scale=MASS_SCALE,
+            # Add suppression options if needed from config
+            suppressed_residues=config.get("suppressed_residues", None),
+            disable_terminal_residues_anywhere=config.get(
+                "disable_terminal_residues_anywhere", True
+            ),
+        )
     print(f"Using decoder: {type(decoder).__name__}")
 
     # 5. Run Prediction Loop (Adapted from instanovo/transformer/predict.py)
     print("Starting prediction...")
     start_time = time.time()
+    results_list: list[
+        ScoredSequence | list
+    ] = []  # Store ScoredSequence or empty list
 
     for i, batch in enumerate(dl):
+        spectra, precursors, spectra_mask, _, _ = (
+            batch  # Ignore peptides/masks for de novo
+        )
         spectra = spectra.to(DEVICE)
         precursors = precursors.to(DEVICE)
         spectra_mask = spectra_mask.to(DEVICE)
 
+        with (
+            torch.no_grad(),
+            torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16),
+        ):
             # Beam search decoder might return list[list[ScoredSequence]] if return_beam=True
             # Greedy decoder returns list[ScoredSequence]
             # KnapsackBeamSearchDecoder returns list[ScoredSequence] or list[list[ScoredSequence]]
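Note: the new inference context uses the parenthesized multi-manager `with` form, which requires Python 3.10 or newer, together with `torch.amp.autocast` and an explicit device type. A minimal sketch of an equivalent block for older interpreters; the `run_batch` helper and the plain `model(...)` call are our illustration, not the app's decoder API:

```python
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
FP16 = DEVICE == "cuda"  # half precision only makes sense on GPU here


def run_batch(model, spectra, precursors, spectra_mask):
    # Nested form of the same two context managers, valid before Python 3.10.
    with torch.no_grad():
        with torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
            return model(spectra, precursors, spectra_mask)
```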
@@ -320,12 +372,17 @@ def predict_peptides(input_file, decoding_method):
                 beam_size=config.num_beams,
                 max_length=config.max_length,
                 # Knapsack/Beam Search specific params if needed
+                mass_tolerance=config.get("filter_precursor_ppm", 20)
+                * 1e-6,  # Convert ppm to relative
+                max_isotope=config.isotope_error_range[1]
+                if config.isotope_error_range
+                else 1,
+                return_beam=False,  # Only get the top prediction for simplicity
             )
+            results_list.extend(
+                batch_predictions
+            )  # Should be list[ScoredSequence] or list[list]
+            print(f"Processed batch {i + 1}/{len(dl)}")
 
     end_time = time.time()
     print(f"Prediction finished in {end_time - start_time:.2f} seconds.")
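Note: the decode call converts the precursor tolerance from ppm to a relative tolerance by multiplying by 1e-6. A quick illustrative calculation of what that means in absolute mass terms (the numbers are examples only):

```python
ppm = 20                              # config.get("filter_precursor_ppm", 20)
relative_tol = ppm * 1e-6             # 2e-05, the value handed to mass_tolerance
precursor_mass = 1500.0               # Da, an arbitrary example precursor
print(precursor_mass * relative_tol)  # 0.03 -> roughly a +/-0.03 Da window at 1500 Da
```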
@@ -335,29 +392,35 @@ def predict_peptides(input_file, decoding_method):
     output_data = []
     # Use sdf index columns + prediction results
     index_cols = [col for col in config.index_columns if col in sdf.df.columns]
-    base_df_pd = sdf.df.select(index_cols).to_pandas()
+    base_df_pd = sdf.df.select(index_cols).to_pandas()  # Get base info
 
     metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
 
     for i, res in enumerate(results_list):
-        row_data = base_df_pd.iloc[i].to_dict()
+        row_data = base_df_pd.iloc[i].to_dict()  # Get corresponding input data
         if isinstance(res, ScoredSequence) and res.sequence:
             sequence_str = "".join(res.sequence)
             row_data["prediction"] = sequence_str
             row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
             # Use metrics to calculate delta mass ppm for the top prediction
             try:
+                _, delta_mass_list = metrics_calc.matches_precursor(
+                    res.sequence,
+                    row_data["precursor_mz"],
+                    row_data["precursor_charge"],
+                )
+                # Find the smallest absolute ppm error across isotopes
+                min_abs_ppm = (
+                    min(abs(p) for p in delta_mass_list)
+                    if delta_mass_list
+                    else float("nan")
+                )
+                row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}"
             except Exception as e:
+                print(
+                    f"Warning: Could not calculate delta mass for prediction {i}: {e}"
+                )
+                row_data["delta_mass_ppm"] = "N/A"
 
         else:
             row_data["prediction"] = ""
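Note: `matches_precursor(...)` is used here to obtain a list of delta-mass values across the configured isotope error range, and the app reports the smallest absolute one. For reference, the standard ppm mass-error definition behind that number (a generic formula, not the InstaNovo implementation):

```python
def delta_mass_ppm(observed_mass: float, theoretical_mass: float) -> float:
    """Signed precursor mass error in parts per million."""
    return (observed_mass - theoretical_mass) / theoretical_mass * 1e6


# Observing 1500.015 Da for a peptide whose theoretical mass is 1500.000 Da:
print(delta_mass_ppm(1500.015, 1500.000))  # ~10.0 ppm
```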
@@ -368,13 +431,20 @@ def predict_peptides(input_file, decoding_method):
     output_df = pl.DataFrame(output_data)
 
     # Ensure specific columns are present and ordered
+    display_cols = [
+        "scan_number",
+        "precursor_mz",
+        "precursor_charge",
+        "prediction",
+        "log_probability",
+        "delta_mass_ppm",
+    ]
     final_display_cols = []
     for col in display_cols:
         if col in output_df.columns:
             final_display_cols.append(col)
         else:
+            print(f"Warning: Expected display column '{col}' not found in results.")
 
     # Add any remaining index columns that weren't in display_cols
     for col in index_cols:
@@ -383,7 +453,6 @@ def predict_peptides(input_file, decoding_method):
 
     output_df_display = output_df.select(final_display_cols)
 
-
     # 7. Save full results to CSV
     print(f"Saving results to {output_csv_path}...")
     output_df.write_csv(output_csv_path)
@@ -399,6 +468,7 @@ def predict_peptides(input_file, decoding_method):
         # Re-raise as Gradio error
         raise gr.Error(f"Prediction failed: {e}")
 
+
 # --- Gradio Interface ---
 css = """
 .gradio-container { font-family: sans-serif; }
@@ -408,7 +478,9 @@ footer { display: none !important; }
 .logo-container img { margin-bottom: 1rem; }
 """
 
-with gr.Blocks(
+with gr.Blocks(
+    css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")
+) as demo:
     # --- Logo Display ---
     gr.Markdown(
         """
@@ -416,7 +488,7 @@ with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hu
             <img src='/gradio_api/file=assets/instanovo.svg' alt="InstaNovo Logo" width="300" style="display: block; margin: 0 auto;">
         </div>
         """,
-        elem_classes="logo-container"
+        elem_classes="logo-container",  # Optional class for CSS targeting
     )
 
     # --- App Content ---
@@ -431,36 +503,55 @@ with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hu
         with gr.Column(scale=1):
             input_file = gr.File(
                 label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
-                file_types=[".mgf", ".mzml", ".mzxml"]
+                file_types=[".mgf", ".mzml", ".mzxml"],
             )
             decoding_method = gr.Radio(
+                [
+                    "Greedy Search (Fast, resonably accurate)",
+                    "Knapsack Beam Search (More accurate, but slower)",
+                ],
                 label="Decoding Method",
-                value="Greedy Search (Fast, resonably accurate)"
+                value="Greedy Search (Fast, resonably accurate)",  # Default to fast method
             )
             submit_btn = gr.Button("Predict Sequences", variant="primary")
         with gr.Column(scale=2):
+            output_df = gr.DataFrame(
+                label="Prediction Results",
+                headers=[
+                    "scan_number",
+                    "precursor_mz",
+                    "precursor_charge",
+                    "prediction",
+                    "log_probability",
+                    "delta_mass_ppm",
+                ],
+                wrap=True,
+            )
             output_file = gr.File(label="Download Full Results (CSV)")
 
     submit_btn.click(
         predict_peptides,
         inputs=[input_file, decoding_method],
-        outputs=[output_df, output_file]
+        outputs=[output_df, output_file],
     )
 
     gr.Examples(
+        [
+            ["assets/sample_spectra.mgf", "Greedy Search (Fast, resonably accurate)"],
+            [
+                "assets/sample_spectra.mgf",
+                "Knapsack Beam Search (More accurate, but slower)",
+            ],
+        ],
+        inputs=[input_file, decoding_method],
+        outputs=[output_df, output_file],
+        fn=predict_peptides,
+        cache_examples=False,  # Re-run examples if needed
+        label="Example Usage",
     )
 
     gr.Markdown(
+        """
 **Notes:**
 * Predictions are based on the [InstaNovo](https://github.com/instadeepai/InstaNovo) model ({MODEL_ID}).
 * Knapsack Beam Search uses pre-calculated mass constraints and yields better results but takes longer.
@@ -469,10 +560,32 @@ with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hu
         """.format(MODEL_ID=MODEL_ID)
     )
 
+    gr.Textbox(
+        value="""
+@article{eloff_kalogeropoulos_2025_instanovo,
+    title = {InstaNovo enables diffusion-powered de novo peptide sequencing in large-scale proteomics experiments},
+    author = {Kevin Eloff and Konstantinos Kalogeropoulos and Amandla Mabona and Oliver Morell and Rachel Catzel and
+              Esperanza Rivera-de-Torre and Jakob Berg Jespersen and Wesley Williams and Sam P. B. van Beljouw and
+              Marcin J. Skwark and Andreas Hougaard Laustsen and Stan J. J. Brouns and Anne Ljungars and Erwin M.
+              Schoof and Jeroen Van Goey and Ulrich auf dem Keller and Karim Beguir and Nicolas Lopez Carranza and
+              Timothy P. Jenkins},
+    year = 2025,
+    month = {Mar},
+    day = 31,
+    journal = {Nature Machine Intelligence},
+    doi = {10.1038/s42256-025-01019-5},
+    url = {https://www.nature.com/articles/s42256-025-01019-5}
+}
+""",
+        show_copy_button=True,
+        label="If you use InstaNovo in your research, please cite:",
+        interactive=False,
+    )
+
 # --- Launch the App ---
 if __name__ == "__main__":
     # Set share=True for temporary public link if running locally
     # Set server_name="0.0.0.0" to allow access from network if needed
     # demo.launch(server_name="0.0.0.0", server_port=7860)
     # For Hugging Face Spaces, just demo.launch() is usually sufficient
-    demo.launch(share=True)
+    demo.launch(share=True)  # For local testing with public URL
|