kfahn commited on
Commit
cfc764d
1 Parent(s): f1d03bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -55,9 +55,9 @@ class Generator(nn.Module):
55
  model2 += [ResidualBlock(in_features)]
56
  self.model2 = nn.Sequential(*model2)
57
 
58
- # Upsampling
59
  model3 = []
60
- out_features = in_features//2
61
  for _ in range(2):
62
  model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
63
  norm_layer(out_features),
@@ -87,9 +87,13 @@ model1 = Generator(3, 1, 3)
87
  model1.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
88
  model1.eval()
89
 
90
- model2 = Generator(3, 1, 3)
91
- model2.load_state_dict(torch.load('model2.pth', map_location=torch.device('cpu')))
92
- model2.eval()
 
 
 
 
93
 
94
  def predict(input_img):
95
  input_img = Image.open(input_img)
@@ -99,7 +103,7 @@ def predict(input_img):
99
 
100
  drawing = 0
101
  with torch.no_grad():
102
- drawing = model1(input_img)[0].detach()
103
 
104
  drawing = transforms.ToPILImage()(drawing)
105
  return drawing
 
55
  model2 += [ResidualBlock(in_features)]
56
  self.model2 = nn.Sequential(*model2)
57
 
58
+ # More downsampling
59
  model3 = []
60
+ out_features = in_features*3
61
  for _ in range(2):
62
  model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
63
  norm_layer(out_features),
 
87
  model1.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
88
  model1.eval()
89
 
90
+ model3 = Generator(3, 1, 3)
91
+ model3.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
92
+ model3.eval()
93
+
94
+ # model2 = Generator(3, 1, 3)
95
+ # model2.load_state_dict(torch.load('model2.pth', map_location=torch.device('cpu')))
96
+ # model2.eval()
97
 
98
  def predict(input_img):
99
  input_img = Image.open(input_img)
 
103
 
104
  drawing = 0
105
  with torch.no_grad():
106
+ drawing = model3(input_img)[0].detach()
107
 
108
  drawing = transforms.ToPILImage()(drawing)
109
  return drawing