Kororinpa commited on
Commit
3d9d64b
1 Parent(s): 546fd2b

Delete losses.py

Browse files
Files changed (1) hide show
  1. losses.py +0 -61
losses.py DELETED
@@ -1,61 +0,0 @@
1
- import torch
2
- from torch.nn import functional as F
3
-
4
- import commons
5
-
6
-
7
- def feature_loss(fmap_r, fmap_g):
8
- loss = 0
9
- for dr, dg in zip(fmap_r, fmap_g):
10
- for rl, gl in zip(dr, dg):
11
- rl = rl.float().detach()
12
- gl = gl.float()
13
- loss += torch.mean(torch.abs(rl - gl))
14
-
15
- return loss * 2
16
-
17
-
18
- def discriminator_loss(disc_real_outputs, disc_generated_outputs):
19
- loss = 0
20
- r_losses = []
21
- g_losses = []
22
- for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
23
- dr = dr.float()
24
- dg = dg.float()
25
- r_loss = torch.mean((1-dr)**2)
26
- g_loss = torch.mean(dg**2)
27
- loss += (r_loss + g_loss)
28
- r_losses.append(r_loss.item())
29
- g_losses.append(g_loss.item())
30
-
31
- return loss, r_losses, g_losses
32
-
33
-
34
- def generator_loss(disc_outputs):
35
- loss = 0
36
- gen_losses = []
37
- for dg in disc_outputs:
38
- dg = dg.float()
39
- l = torch.mean((1-dg)**2)
40
- gen_losses.append(l)
41
- loss += l
42
-
43
- return loss, gen_losses
44
-
45
-
46
- def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
47
- """
48
- z_p, logs_q: [b, h, t_t]
49
- m_p, logs_p: [b, h, t_t]
50
- """
51
- z_p = z_p.float()
52
- logs_q = logs_q.float()
53
- m_p = m_p.float()
54
- logs_p = logs_p.float()
55
- z_mask = z_mask.float()
56
-
57
- kl = logs_p - logs_q - 0.5
58
- kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
59
- kl = torch.sum(kl * z_mask)
60
- l = kl / torch.sum(z_mask)
61
- return l