# Scraped from a Hugging Face Spaces page (Space status at capture: "Sleeping").
import torch
import torchio as tio

from apps.model import model
# (checkpoint path, device) whose weights are currently held by ``model``.
# Lets repeated calls skip re-reading the checkpoint from disk.
_loaded_weights_key = None


def _ensure_model_loaded(ckpt_path, device):
    """Load ``ckpt_path`` into the shared ``model`` unless it is already loaded.

    The checkpoint is read from disk only when ``(ckpt_path, device)`` differs
    from the pair loaded by a previous call. If loading raises, the cache key is
    left unchanged, so the next call retries.

    Args:
        ckpt_path: Path of a checkpoint dict with a ``'state_dict'`` entry
            matching ``model``'s parameters.
        device: ``torch.device`` the tensors are mapped to and the model is
            moved to.
    """
    global _loaded_weights_key
    key = (str(ckpt_path), str(device))
    if _loaded_weights_key != key:
        # NOTE(review): torch.load unpickles the file. That is acceptable only
        # because the checkpoint ships with the app and is not user-supplied.
        # On torch>=2.6 (weights_only=True by default), a checkpoint that also
        # stores non-tensor objects may need weights_only=False. Confirm this
        # against the deployed torch version.
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        model.to(device)
        _loaded_weights_key = key
    # Re-enter eval mode on every call in case other code switched modes.
    model.eval()


def preprocess_input(
    uploaded_file,
    *,
    ckpt_path='apps/best_model_36.ckpt',
    target_shape=(300, 300, 400),
    patch_size=96,
    patch_overlap=(8, 8, 8),
    batch_size=4,
) -> tuple[torch.Tensor, tio.SubjectsDataset]:
    """Preprocess a CT volume and run patch-based inference with ``model``.

    Despite its name, this function does both preprocessing and inference:

    1. Wraps the input as the ``"CT"`` image of a single-subject dataset.
    2. Applies canonical reorientation, intensity rescaling to [0, 1], and
       resizing to ``target_shape``.
    3. Runs ``model`` over overlapping patches and aggregates the patch
       predictions back into one volume.

    Args:
        uploaded_file: Image source passed to ``tio.ScalarImage``. Presumably a
            file path; confirm against the caller.
        ckpt_path: Checkpoint file to load weights from. The path is resolved
            relative to the current working directory.
        target_shape: Spatial shape passed to ``tio.Resize``.
        patch_size: Patch size passed to ``tio.inference.GridSampler``.
        patch_overlap: Patch overlap passed to ``tio.inference.GridSampler``.
        batch_size: Number of patches per forward pass.

    Returns:
        ``(output_tensor, dataset)``, where:

        - ``output_tensor`` is the result of
          ``GridAggregator.get_output_tensor()`` over the model predictions.
        - ``dataset`` is the single-subject ``tio.SubjectsDataset``. Indexing
          it re-applies the preprocessing transform.
    """
    subject = tio.Subject({"CT": tio.ScalarImage(uploaded_file)})
    preprocess_spatial = tio.Compose([
        tio.ToCanonical(),             # reorient to closest-to-RAS+ orientation
        tio.RescaleIntensity((0, 1)),  # map intensities into [0, 1]
        tio.Resize(target_shape),
    ])
    dataset = tio.SubjectsDataset([subject], transform=preprocess_spatial)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _ensure_model_loaded(ckpt_path, device)

    # Indexing the dataset runs the transform; do it once and reuse the result.
    transformed_subject = dataset[0]
    grid_sampler = tio.inference.GridSampler(
        transformed_subject, patch_size, patch_overlap
    )
    aggregator = tio.inference.GridAggregator(grid_sampler)
    patch_loader = tio.data.SubjectsLoader(grid_sampler, batch_size=batch_size)

    with torch.no_grad():
        for patches_batch in patch_loader:
            input_tensor = patches_batch['CT']["data"].to(device)
            locations = patches_batch[tio.LOCATION]
            pred = model(input_tensor)
            aggregator.add_batch(pred, locations)
    return aggregator.get_output_tensor(), dataset