jamino30 commited on
Commit
b33b2b4
·
verified ·
1 Parent(s): 877adbb

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. u2net/train.py +3 -2
app.py CHANGED
@@ -9,6 +9,7 @@ import torchvision.models as models
9
  import numpy as np
10
  import gradio as gr
11
  from gradio_imageslider import ImageSlider
 
12
 
13
  from utils import preprocess_img, preprocess_img_from_path, postprocess_img
14
  from vgg.vgg19 import VGG_19
@@ -22,8 +23,7 @@ print('DEVICE:', device)
22
  if device == 'cuda': print('CUDA DEVICE:', torch.cuda.get_device_name())
23
 
24
  def load_model_without_module(model, model_path):
25
- state_dict = torch.load(model_path, map_location=device, weights_only=False)
26
-
27
  new_state_dict = {}
28
  for k, v in state_dict.items():
29
  name = k[7:] if k.startswith('module.') else k
 
9
  import numpy as np
10
  import gradio as gr
11
  from gradio_imageslider import ImageSlider
12
+ from safetensors.torch import load_file
13
 
14
  from utils import preprocess_img, preprocess_img_from_path, postprocess_img
15
  from vgg.vgg19 import VGG_19
 
23
  if device == 'cuda': print('CUDA DEVICE:', torch.cuda.get_device_name())
24
 
25
  def load_model_without_module(model, model_path):
26
+ state_dict = load_file(model_path, device=device)
 
27
  new_state_dict = {}
28
  for k, v in state_dict.items():
29
  name = k[7:] if k.startswith('module.') else k
u2net/train.py CHANGED
@@ -6,6 +6,7 @@ import torch.nn as nn
6
  import torch.optim as optim
7
  from torch.utils.data import DataLoader, ConcatDataset
8
  from torch.amp import autocast, GradScaler
 
9
 
10
  from data_loader import DUTSDataset, MSRADataset
11
  from model import U2Net
@@ -78,11 +79,11 @@ if __name__ == '__main__':
78
 
79
  if val_loss < best_val_loss:
80
  best_val_loss = val_loss
81
- torch.save(model.state_dict(), f'results/best-{model_name}.pt')
82
  print('Best model saved.')
83
 
84
  print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f} (Best: {best_val_loss:.4f})')
85
 
86
- torch.save(model.state_dict(), f'results/{model_name}.pt')
87
  with open('results/loss.txt', 'wb') as f:
88
  pickle.dump(losses, f)
 
6
  import torch.optim as optim
7
  from torch.utils.data import DataLoader, ConcatDataset
8
  from torch.amp import autocast, GradScaler
9
+ from safetensors.torch import save_file
10
 
11
  from data_loader import DUTSDataset, MSRADataset
12
  from model import U2Net
 
79
 
80
  if val_loss < best_val_loss:
81
  best_val_loss = val_loss
82
+ save_file(model.state_dict(), f'results/best-{model_name}.safetensors')
83
  print('Best model saved.')
84
 
85
  print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f} (Best: {best_val_loss:.4f})')
86
 
87
+ save_file(model.state_dict(), f'results/{model_name}.safetensors')
88
  with open('results/loss.txt', 'wb') as f:
89
  pickle.dump(losses, f)