from dataloader import CellLoader


def run_sequence_prediction(
    sequence_input,
    nucleus_image,
    protein_image,
    model,
    device,
):
    """
    Predict masked amino acids in a protein sequence with a CELL-E model.

    The sequence is tokenized with a ``CellLoader`` configured for ESM-2
    (``vocab="esm2"``), then passed together with the nucleus and protein
    images to ``model.celle.sample_text``.

    :param sequence_input: Amino-acid sequence string. Positions to predict
        are expected to be written as the ESM-2 mask token ``<mask>``.
    :param nucleus_image: Conditioning image passed to the model as
        ``condition``. Must support ``.to(device)`` (presumably a torch
        tensor -- confirm against caller).
    :param protein_image: Image passed to the model as ``image``. Must support
        ``.to(device)``; ``None`` is NOT handled here.
    :param model: Loaded model exposing ``model.celle.sample_text``.
    :param device: Device the inputs are moved to before sampling.
    :return: The predicted sequence, i.e. the second element of the tuple
        returned by ``model.celle.sample_text``.
    :raises ValueError: If ``sequence_input`` is empty or ``None``.
    """
    # Instantiate dataset object; only its tokenizer is used below.
    # NOTE(review): this is rebuilt on every call -- consider caching it if
    # the function is called repeatedly.
    dataset = CellLoader(
        sequence_mode="embedding",
        vocab="esm2",
        split_key="val",
        crop_method="center",
        resize=600,
        crop_size=256,
        text_seq_len=1000,
        pad_mode="end",
        threshold="median",
    )

    # Check if sequence is provided and valid. A truthiness test also rejects
    # None with the intended ValueError instead of an opaque TypeError.
    if not sequence_input:
        raise ValueError("Sequence must be provided.")

    # The previous check (`"" not in sequence_input`) could never be true,
    # since the empty string is a substring of every string; look for the
    # actual ESM-2 mask token instead.
    if "<mask>" not in sequence_input:
        print("Warning: Sequence does not contain any masked positions to predict.")

    # Convert the raw sequence string into model tokens.
    sequence = dataset.tokenize_sequence(sequence_input)

    # Sample from the model using the tokenized sequence, the nucleus image
    # (as condition) and the protein image; force_aas keeps predictions to
    # amino-acid tokens.
    _, predicted_sequence, _ = model.celle.sample_text(
        text=sequence.to(device),
        condition=nucleus_image.to(device),
        image=protein_image.to(device),
        force_aas=True,
        temperature=1,
        progress=False,
    )

    return predicted_sequence