sachinkidzure's picture
initial (#1)
135b069 verified
raw
history blame
4.02 kB
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flax.linen as nn
import jax
import jax.numpy as jnp
class FlaxUpsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
def __call__(self, hidden_states):
batch, height, width, channels = hidden_states.shape
hidden_states = jax.image.resize(
hidden_states,
shape=(batch, height * 2, width * 2, channels),
method="nearest",
)
hidden_states = self.conv(hidden_states)
return hidden_states
class FlaxDownsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
strides=(2, 2),
padding=((1, 1), (1, 1)), # padding="VALID",
dtype=self.dtype,
)
def __call__(self, hidden_states):
# pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
# hidden_states = jnp.pad(hidden_states, pad_width=pad)
hidden_states = self.conv(hidden_states)
return hidden_states
class FlaxResnetBlock2D(nn.Module):
in_channels: int
out_channels: int = None
dropout_prob: float = 0.0
use_nin_shortcut: bool = None
dtype: jnp.dtype = jnp.float32
def setup(self):
out_channels = self.in_channels if self.out_channels is None else self.out_channels
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
self.conv1 = nn.Conv(
out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
self.dropout = nn.Dropout(self.dropout_prob)
self.conv2 = nn.Conv(
out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
self.conv_shortcut = None
if use_nin_shortcut:
self.conv_shortcut = nn.Conv(
out_channels,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
def __call__(self, hidden_states, temb, deterministic=True):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = nn.swish(hidden_states)
hidden_states = self.conv1(hidden_states)
temb = self.time_emb_proj(nn.swish(temb))
temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
hidden_states = nn.swish(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)
return hidden_states + residual