File size: 3,916 Bytes
2b7bf83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# -*- coding: utf-8 -*-

# Copyright 2020 MINH ANH (@dathudeptrai)
#  MIT License (https://opensource.org/licenses/MIT)

"""Tensorflow Layer modules compatible with pytorch."""

import tensorflow as tf


class TFReflectionPad1d(tf.keras.layers.Layer):
    """Reflection padding applied to the second axis of a 4-D tensor."""

    def __init__(self, padding_size):
        """Store the amount of padding to apply.

        Args:
            padding_size (int): Number of elements added on each side of
                the second axis.

        """
        super(TFReflectionPad1d, self).__init__()
        self.padding_size = padding_size

    @tf.function
    def call(self, x):
        """Pad the input by reflection along its second axis.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Padded tensor (B, T + 2 * padding_size, 1, C).

        """
        time_pad = [self.padding_size, self.padding_size]
        # Only axis 1 is padded; the remaining three axes are left untouched.
        paddings = [[0, 0], time_pad, [0, 0], [0, 0]]
        return tf.pad(x, paddings, mode="REFLECT")


class TFConvTranspose1d(tf.keras.layers.Layer):
    """Tensorflow ConvTranspose1d module backed by a Conv2DTranspose layer."""

    def __init__(self, channels, kernel_size, stride, padding):
        """Initialize TFConvTranspose1d module.

        Args:
            channels (int): Number of filters of the transposed convolution.
            kernel_size (int): Kernel size along the first spatial axis.
            stride (int): Stride along the first spatial axis.
            padding (str): Padding type ("same" or "valid").

        """
        super(TFConvTranspose1d, self).__init__()
        # The second spatial axis always uses a kernel size and stride of 1,
        # so the 2-D layer only acts along the first spatial axis.
        kernel_shape = (kernel_size, 1)
        stride_shape = (stride, 1)
        self.conv1d_transpose = tf.keras.layers.Conv2DTranspose(
            filters=channels,
            kernel_size=kernel_shape,
            strides=stride_shape,
            padding=padding,
        )

    @tf.function
    def call(self, x):
        """Apply the transposed convolution to the input.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T', 1, C').

        """
        return self.conv1d_transpose(x)


class TFResidualStack(tf.keras.layers.Layer):
    """Tensorflow ResidualStack module."""

    def __init__(
        self,
        kernel_size,
        channels,
        dilation,
        bias,
        nonlinear_activation,
        nonlinear_activation_params,
        padding,
    ):
        """Initialize TFResidualStack module.

        Args:
            kernel_size (int): Kernel size of the dilated convolution. It
                should be odd so that the residual branch keeps the same
                length as the shortcut branch.
            channels (int): Number of channels.
            dilation (int): Dilation rate of the dilated convolution.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation function module name
                (looked up as an attribute of ``tf.keras.layers``).
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            padding (str): Currently unused; the dilated convolution is always
                built with "valid" padding after explicit reflection padding.

        """
        super(TFResidualStack, self).__init__()
        # The "valid" dilated convolution shortens the time axis by
        # (kernel_size - 1) * dilation, so half of that is reflected on each
        # side to keep the branch as long as the shortcut. For kernel_size=3
        # this equals `dilation`, the value that was previously hard-coded.
        pad_size = (kernel_size - 1) // 2 * dilation
        self.block = [
            getattr(tf.keras.layers, nonlinear_activation)(
                **nonlinear_activation_params
            ),
            TFReflectionPad1d(pad_size),
            tf.keras.layers.Conv2D(
                filters=channels,
                kernel_size=(kernel_size, 1),
                dilation_rate=(dilation, 1),
                use_bias=bias,
                padding="valid",
            ),
            getattr(tf.keras.layers, nonlinear_activation)(
                **nonlinear_activation_params
            ),
            tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias),
        ]
        # 1x1 convolution applied to the untouched input for the skip path.
        self.shortcut = tf.keras.layers.Conv2D(
            filters=channels, kernel_size=1, use_bias=bias
        )

    @tf.function
    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T, 1, C).

        """
        _x = tf.identity(x)
        # Run the residual branch layer by layer, in construction order.
        for layer in self.block:
            _x = layer(_x)
        shortcut = self.shortcut(x)
        return shortcut + _x