import torch
class LockedDropout(torch.nn.Module):
"""
Implementation of locked (or variational) dropout.
Randomly drops out entire parameters in embedding space.
:param dropout_rate: represent the fraction of the input unit to be dropped. It will be from 0 to 1.
:param batch_first: represent if the drop will perform in an ascending manner
:param inplace:
"""
    def __init__(self, dropout_rate=0.5, batch_first=True, inplace=False):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.batch_first = batch_first
        self.inplace = inplace
    def forward(self, x):
        if not self.training or not self.dropout_rate:
            return x
        if self.batch_first:
            # Input (batch, seq_len, features): one mask per sample,
            # broadcast over the time dimension (dim 1).
            m = x.new_empty(x.size(0), 1, x.size(2)).bernoulli_(1 - self.dropout_rate)
        else:
            # Input (seq_len, batch, features): time is dim 0 instead.
            m = x.new_empty(1, x.size(1), x.size(2)).bernoulli_(1 - self.dropout_rate)
        # Inverted-dropout scaling keeps the expected activation magnitude;
        # the deprecated torch.autograd.Variable wrapper is no longer needed.
        mask = m.div_(1 - self.dropout_rate).expand_as(x)
        return mask * x
    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)
class WordDropout(torch.nn.Module):
"""
Implementation of word dropout. Randomly drops out entire words
(or characters) in embedding space.
"""
    def __init__(self, dropout_rate=0.05, inplace=False):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.inplace = inplace
    def forward(self, x):
        if not self.training or not self.dropout_rate:
            return x
        # One Bernoulli draw per position, shape (batch, seq_len, 1): the mask
        # broadcasts over the feature dimension, zeroing whole word vectors.
        m = x.new_empty(x.size(0), x.size(1), 1).bernoulli_(1 - self.dropout_rate)
        return m * x
    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)