File size: 1,063 Bytes
6ded986
 
 
2f82475
 
6ded986
2f82475
6ded986
 
 
 
2f82475
 
6ded986
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import pytest
import torch

# Identifier handed to MAEModel.from_pretrained when the test fixture loads the model.
# NOTE(review): the commented-out line below looks like an alternative for
# loading from a local directory instead of downloading — confirm before use.
# huggingface_openphenom_model_dir = "."
huggingface_modelpath = "recursionpharma/OpenPhenom"

from .huggingface_mae import MAEModel


@pytest.fixture(scope="module")
def huggingface_model():
    """Load the pretrained MAE model once per test module, in eval mode.

    ``from_pretrained`` downloads the model to a local cache and takes a bit
    to run, so the fixture is module-scoped: every parametrized test case
    reuses the same instance instead of reloading it each time.

    Because the instance is shared, a test that changes an attribute on the
    model must set that attribute explicitly before relying on it.

    Returns:
        The object returned by ``MAEModel.from_pretrained`` after ``eval()``
        has been called on it.
    """
    model = MAEModel.from_pretrained(huggingface_modelpath)
    model.eval()
    return model


@pytest.mark.parametrize("C", [1, 4, 6, 11])
@pytest.mark.parametrize("return_channelwise_embeddings", [True, False])
def test_model_predict(huggingface_model, C, return_channelwise_embeddings):
    """Check the shape of ``predict`` output for several channel counts.

    A random uint8 batch of shape ``(2, C, 256, 256)`` is created on the
    model's device and passed to ``predict``. The result is expected to have
    shape ``(2, 384 * C)`` when channel-wise embeddings are requested and
    ``(2, 384)`` otherwise.

    Args:
        huggingface_model: Model instance supplied by the fixture.
        C: Number of channels in the random input batch.
        return_channelwise_embeddings: Value assigned to the model attribute
            of the same name before calling ``predict``.
    """
    batch_size = 2
    embedding_dim = 384

    example_input_array = torch.randint(
        low=0,
        # ``high`` is exclusive: 256 is needed for 255 to be drawn at all.
        high=256,
        size=(batch_size, C, 256, 256),
        dtype=torch.uint8,
        device=huggingface_model.device,
    )

    # Always set the flag explicitly so the test does not depend on whatever
    # value the model instance carried beforehand.
    huggingface_model.return_channelwise_embeddings = return_channelwise_embeddings
    embeddings = huggingface_model.predict(example_input_array)

    if return_channelwise_embeddings:
        expected_output_dim = embedding_dim * C
    else:
        expected_output_dim = embedding_dim
    assert embeddings.shape == (batch_size, expected_output_dim)