nielsr HF staff commited on
Commit
b59562f
1 Parent(s): a08472f
Files changed (2) hide show
  1. app.py +1 -6
  2. models/seemore.py +3 -1
app.py CHANGED
@@ -9,7 +9,6 @@ from PIL import Image
9
  from copy import deepcopy
10
  from torch.nn.parallel import DataParallel, DistributedDataParallel
11
 
12
- from huggingface_hub import hf_hub_download
13
  from gradio_imageslider import ImageSlider
14
 
15
  ## local code
@@ -71,7 +70,6 @@ def load_network(net, load_path, strict=True, param_key='params'):
71
  net.load_state_dict(load_net, strict=strict)
72
 
73
  CONFIG = "configs/eval_seemore_t_x4.yml"
74
- hf_hub_download(repo_id="eduardzamfir/SeemoRe-T", filename="SeemoRe_T_X4.pth", local_dir="./")
75
  MODEL_NAME = "SeemoRe_T_X4.pth"
76
 
77
  # parse config file
@@ -81,10 +79,7 @@ with open(os.path.join(CONFIG), "r") as f:
81
  cfg = dict2namespace(config)
82
 
83
  device = torch.device("cpu")
84
- model = seemore.SeemoRe(scale=cfg.model.scale, in_chans=cfg.model.in_chans,
85
- num_experts=cfg.model.num_experts, num_layers=cfg.model.num_layers, embedding_dim=cfg.model.embedding_dim,
86
- img_range=cfg.model.img_range, use_shuffle=cfg.model.use_shuffle, global_kernel_size=cfg.model.global_kernel_size,
87
- recursive=cfg.model.recursive, lr_space=cfg.model.lr_space, topk=cfg.model.topk)
88
 
89
  model = model.to(device)
90
  print ("IMAGE MODEL CKPT:", MODEL_NAME)
 
9
  from copy import deepcopy
10
  from torch.nn.parallel import DataParallel, DistributedDataParallel
11
 
 
12
  from gradio_imageslider import ImageSlider
13
 
14
  ## local code
 
70
  net.load_state_dict(load_net, strict=strict)
71
 
72
  CONFIG = "configs/eval_seemore_t_x4.yml"
 
73
  MODEL_NAME = "SeemoRe_T_X4.pth"
74
 
75
  # parse config file
 
79
  cfg = dict2namespace(config)
80
 
81
  device = torch.device("cpu")
82
+ model = seemore.SeemoRe.from_pretrained("eduardzamfir/SeemoRe-T")
 
 
 
83
 
84
  model = model.to(device)
85
  print ("IMAGE MODEL CKPT:", MODEL_NAME)
models/seemore.py CHANGED
@@ -6,11 +6,13 @@ import torch.nn as nn
6
  import torch.nn.functional as F
7
  from einops.layers.torch import Rearrange
8
 
 
 
9
 
10
  ######################
11
  # Meta Architecture
12
  ######################
13
- class SeemoRe(nn.Module):
14
  def __init__(self,
15
  scale: int = 4,
16
  in_chans: int = 3,
 
6
  import torch.nn.functional as F
7
  from einops.layers.torch import Rearrange
8
 
9
+ from huggingface_hub import PyTorchModelHubMixin
10
+
11
 
12
  ######################
13
  # Meta Architecture
14
  ######################
15
+ class SeemoRe(nn.Module, PyTorchModelHubMixin):
16
  def __init__(self,
17
  scale: int = 4,
18
  in_chans: int = 3,