parokshsaxena committed
Commit
d52990b
·
1 Parent(s): 72b00c6

Use enhanced garment net based on Claude's suggestions

src/enhanced_garment_net.py ADDED
@@ -0,0 +1,123 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(ResidualBlock, self).__init__()
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+         self.bn1 = nn.BatchNorm2d(out_channels)
+         self.bn2 = nn.BatchNorm2d(out_channels)
+         self.relu = nn.ReLU(inplace=True)
+         # 1x1 projection so the identity branch matches the output width
+         self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None
+
+     def forward(self, x):
+         residual = x
+         out = self.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         if self.downsample is not None:
+             residual = self.downsample(x)
+         out += residual
+         return self.relu(out)
+
+ class EnhancedGarmentNet(nn.Module):
+     def __init__(self, in_channels=3, base_channels=64, num_residual_blocks=4):
+         super(EnhancedGarmentNet, self).__init__()
+
+         self.initial = nn.Sequential(
+             nn.Conv2d(in_channels, base_channels, kernel_size=7, padding=3),
+             nn.BatchNorm2d(base_channels),
+             nn.ReLU(inplace=True)
+         )
+
+         self.encoder1 = self._make_layer(base_channels, base_channels, num_residual_blocks)
+         self.encoder2 = self._make_layer(base_channels, base_channels*2, num_residual_blocks)
+         self.encoder3 = self._make_layer(base_channels*2, base_channels*4, num_residual_blocks)
+
+         self.bridge = self._make_layer(base_channels*4, base_channels*8, num_residual_blocks)
+
+         # Each decoder stage consumes the upsampled deeper features concatenated
+         # with the skip connection, so in_channels must account for both.
+         self.decoder3 = self._make_layer(base_channels*8 + base_channels*4, base_channels*4, num_residual_blocks)
+         self.decoder2 = self._make_layer(base_channels*4 + base_channels*2, base_channels*2, num_residual_blocks)
+         self.decoder1 = self._make_layer(base_channels*2 + base_channels, base_channels, num_residual_blocks)
+
+         self.final = nn.Conv2d(base_channels, in_channels, kernel_size=7, padding=3)
+
+         self.downsample = nn.MaxPool2d(2)
+         self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+
+     def _make_layer(self, in_channels, out_channels, num_blocks):
+         layers = []
+         layers.append(ResidualBlock(in_channels, out_channels))
+         for _ in range(1, num_blocks):
+             layers.append(ResidualBlock(out_channels, out_channels))
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         # Initial convolution
+         x = self.initial(x)
+
+         # Encoder
+         e1 = self.encoder1(x)
+         e2 = self.encoder2(self.downsample(e1))
+         e3 = self.encoder3(self.downsample(e2))
+
+         # Bridge
+         b = self.bridge(self.downsample(e3))
+
+         # Decoder with skip connections
+         d3 = self.decoder3(torch.cat([self.upsample(b), e3], dim=1))
+         d2 = self.decoder2(torch.cat([self.upsample(d3), e2], dim=1))
+         d1 = self.decoder1(torch.cat([self.upsample(d2), e1], dim=1))
+
+         # Final convolution
+         out = self.final(d1)
+
+         return out, [e1, e2, e3, b]
+
+ class EnhancedGarmentNetWithTimestep(nn.Module):
+     def __init__(self, in_channels=3, base_channels=64, num_residual_blocks=4, time_emb_dim=256):
+         super(EnhancedGarmentNetWithTimestep, self).__init__()
+
+         self.garment_net = EnhancedGarmentNet(in_channels, base_channels, num_residual_blocks)
+
+         # Timestep embedding
+         self.time_mlp = nn.Sequential(
+             nn.Linear(1, time_emb_dim),
+             nn.SiLU(),
+             nn.Linear(time_emb_dim, time_emb_dim)
+         )
+
+         # Projection for text embeddings (assumes 768-dimensional pooled embeddings)
+         self.text_proj = nn.Linear(768, time_emb_dim)
+
+         # 1x1 convs that fuse each multi-scale garment feature with the conditioning
+         self.combine = nn.ModuleList([
+             nn.Conv2d(base_channels + time_emb_dim, base_channels, kernel_size=1),
+             nn.Conv2d(base_channels*2 + time_emb_dim, base_channels*2, kernel_size=1),
+             nn.Conv2d(base_channels*4 + time_emb_dim, base_channels*4, kernel_size=1),
+             nn.Conv2d(base_channels*8 + time_emb_dim, base_channels*8, kernel_size=1)
+         ])
+
+     def forward(self, x, t, text_embeds):
+         # Get garment features
+         garment_out, garment_features = self.garment_net(x)
+
+         # Process timestep: (B,) -> (B, time_emb_dim, 1, 1)
+         t_emb = self.time_mlp(t.unsqueeze(-1)).unsqueeze(-1).unsqueeze(-1)
+
+         # Process text embeddings: (B, 768) -> (B, time_emb_dim, 1, 1)
+         text_emb = self.text_proj(text_embeds).unsqueeze(-1).unsqueeze(-1)
+
+         # Combine embeddings
+         cond_emb = t_emb + text_emb
+
+         # Fuse each garment feature map with the conditional embedding
+         combined_features = []
+         for feat, comb_layer in zip(garment_features, self.combine):
+             # Broadcast conditional embedding to match the feature map size
+             expanded_cond_emb = cond_emb.expand(-1, -1, feat.size(2), feat.size(3))
+             combined = comb_layer(torch.cat([feat, expanded_cond_emb], dim=1))
+             combined_features.append(combined)
+
+         return garment_out, combined_features
src/tryon_pipeline.py CHANGED
@@ -56,6 +56,8 @@ from diffusers.utils import (
  from diffusers.utils.torch_utils import randn_tensor
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 
+ from enhanced_garment_net import EnhancedGarmentNetWithTimestep
+
 
 
  if is_torch_xla_available():
@@ -398,6 +400,7 @@ class StableDiffusionXLInpaintPipeline(
      force_zeros_for_empty_prompt: bool = True,
  ):
      super().__init__()
+     self.garment_net = EnhancedGarmentNetWithTimestep()
 
      self.register_modules(
          vae=vae,
@@ -1781,7 +1784,8 @@ class StableDiffusionXLInpaintPipeline(
      if ip_adapter_image is not None:
          added_cond_kwargs["image_embeds"] = image_embeds
      # down,reference_features = self.UNet_Encoder(cloth,t, text_embeds_cloth,added_cond_kwargs= {"text_embeds": pooled_prompt_embeds_c, "time_ids": add_time_ids},return_dict=False)
-     down,reference_features = self.unet_encoder(cloth,t, text_embeds_cloth,return_dict=False)
+     # down,reference_features = self.unet_encoder(cloth,t, text_embeds_cloth,return_dict=False)
+     garment_out, reference_features = self.garment_net(cloth, t, text_embeds_cloth)
      # print(type(reference_features))
      # print(reference_features)
      reference_features = list(reference_features)
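One integration caveat worth noting: `DiffusionPipeline.to()` only moves modules registered through `register_modules`, and `garment_net` is assigned as a plain attribute in `__init__`, so it presumably has to be moved to the execution device and dtype by hand. A hedged sketch of one way to do that (the checkpoint id is a placeholder, fp16 inference is an assumption):

```python
import torch
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline

# "checkpoint-id" is a placeholder for the actual model repository.
pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    "checkpoint-id", torch_dtype=torch.float16
)
pipe.to("cuda")

# garment_net is not a registered module, so pipe.to() skips it;
# move it explicitly and match the UNet's dtype.
pipe.garment_net.to(device="cuda", dtype=pipe.unet.dtype)
```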