Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
def device(gpu_id=0):
    """Return the torch device computations should run on.

    Selects ``cuda:<gpu_id>`` when CUDA is available; otherwise falls back to
    the CPU.

    Args:
        gpu_id: Index of the CUDA device to use. Ignored when CUDA is not
            available.

    Returns:
        A ``torch.device``.
    """
    name = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
    return torch.device(name)
def load_matching_state_dict(model: nn.Module, state_dict):
    """Load the entries of ``state_dict`` whose keys also exist in ``model``.

    Keys present in ``state_dict`` but unknown to the model are dropped.
    Model parameters/buffers that ``state_dict`` does not provide keep their
    current values. This allows warm-starting from a checkpoint whose key set
    only partially overlaps the model's.

    Args:
        model: Module to load into; it is modified in place.
        state_dict: Mapping of parameter/buffer names to tensors, e.g. the
            output of ``Module.state_dict()``.

    Raises:
        RuntimeError: from ``nn.Module.load_state_dict`` if a shared key holds
            a tensor whose shape differs from the model's.
    """
    model_keys = model.state_dict().keys()
    filtered_dict = {k: v for k, v in state_dict.items() if k in model_keys}
    # strict=False: the model may own keys the checkpoint lacks. With the
    # default strict=True those would raise "Missing key(s)" and defeat the
    # purpose of loading only the matching subset.
    model.load_state_dict(filtered_dict, strict=False)
def resize(t: torch.Tensor, size: int) -> torch.Tensor:
    """Resize the last two dims of a 4-D tensor to ``size x size`` by averaging.

    When H and W are both multiples of ``size``, each output pixel is the mean
    of a non-overlapping ``(H // size) x (W // size)`` block, computed with a
    cheap reshape. Otherwise that reshape is impossible, so the function falls
    back to ``adaptive_avg_pool2d``, which averages over (possibly
    overlapping) windows and yields the same output shape.

    Args:
        t: 4-D tensor ``(B, C, H, W)``.
        size: Side length of the square output.

    Returns:
        Tensor of shape ``(B, C, size, size)``.
    """
    B, C, H, W = t.shape
    if H % size or W % size:
        # The block-reshape below requires exact divisibility; it would
        # raise a reshape error here, so use adaptive pooling instead.
        return nn.functional.adaptive_avg_pool2d(t, size)
    t = t.reshape(B, C, size, H // size, size, W // size)
    return t.mean([3, 5])
def make_image(tensor):
    """Convert a 4-D tensor with values nominally in ``[-1, 1]`` to uint8 pixels.

    Values are clamped to ``[-1, 1]``, mapped linearly onto ``[0, 255]``, and
    truncated (not rounded) to ``uint8``. Axis 1 is moved to the end, so a
    ``(B, C, H, W)`` input becomes a ``(B, H, W, C)`` array. The input tensor
    itself is left unmodified.

    Args:
        tensor: 4-D tensor; presumably model output in ``[-1, 1]``. Verify
            against callers.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8``, copied to the CPU.
    """
    return (
        tensor.detach()
        # Out-of-place clamp: ``detach()`` shares storage with ``tensor``, so
        # an in-place ``clamp_`` here would overwrite the caller's data and
        # bump its autograd version counter.
        .clamp(min=-1, max=1)
        .add(1)
        .div_(2)  # in place on the fresh tensor produced by add()
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to('cpu')
        .numpy()
    )