File size: 1,533 Bytes
548170b
 
22f2c54
548170b
 
22f2c54
ab753bf
548170b
 
 
 
 
 
 
 
 
 
 
22f2c54
548170b
 
 
 
 
 
 
 
 
 
 
 
 
22f2c54
 
 
 
 
 
 
548170b
 
 
 
22f2c54
80a49a4
 
22f2c54
 
548170b
80a49a4
548170b
22f2c54
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from dataloader import CellLoader

def run_sequence_prediction(
    sequence_input,
    nucleus_image,
    protein_image,
    model,
    device
):
    """
    Predict masked residues of a protein sequence with a CELL-E model.

    The sequence string is tokenized with ``CellLoader.tokenize_sequence``
    and passed to ``model.celle.sample_text`` together with the nucleus and
    protein images. Nothing is displayed; the prediction is returned.

    :param sequence_input: Amino-acid sequence string. Positions to predict
        are marked with ``<mask>``; if none are present a warning is printed
        and sampling still runs.
    :param nucleus_image: Nucleus image, passed as ``condition``. Must support
        ``.to(device)`` (presumably a torch tensor; confirm the expected
        shape against the model).
    :param protein_image: Protein image, passed as ``image``. Must support
        ``.to(device)``. This argument is not optional here, because it is
        moved to ``device`` unconditionally.
    :param model: Object exposing ``model.celle.sample_text``.
    :param device: Target device for the inputs (anything accepted by
        ``.to()``).
    :return: The second element of the tuple returned by
        ``model.celle.sample_text`` (the predicted sequence).
    :raises ValueError: If ``sequence_input`` is empty or ``None``.
    """
    # Validate first so invalid input fails without constructing the
    # CellLoader below.
    if not sequence_input:
        raise ValueError("Sequence must be provided.")

    if "<mask>" not in sequence_input:
        print("Warning: Sequence does not contain any masked positions to predict.")

    # The dataset object is only used for its tokenizer. These settings
    # presumably have to match the model's training configuration; confirm.
    dataset = CellLoader(
        sequence_mode="embedding",
        vocab="esm2",
        split_key="val",
        crop_method="center",
        resize=600,
        crop_size=256,
        text_seq_len=1000,
        pad_mode="end",
        threshold="median",
    )

    # Convert the raw sequence string into the model's token representation.
    sequence = dataset.tokenize_sequence(sequence_input)

    # Sample the sequence conditioned on the nucleus and protein images.
    # Only the second element of the returned tuple is used.
    _, predicted_sequence, _ = model.celle.sample_text(
        text=sequence.to(device),
        condition=nucleus_image.to(device),
        image=protein_image.to(device),
        force_aas=True,
        temperature=1,
        progress=False,
    )

    return predicted_sequence