# Spaces: Running on Zero
# File size: 6,432 Bytes | revision 96bea52
from diffusers.models.unets.unet_2d_blocks import *
class UNetMidBlock2DCrossAttn(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_groups_out: Optional[int] = None,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
):
super().__init__()
out_channels = out_channels or in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
resnet_groups_out = resnet_groups_out or resnet_groups
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
groups_out=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for i in range(num_layers):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups_out,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
resnets.append(
ResnetBlock2D(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
else:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)
return hidden_states |