ZhengPeng7 committed
Commit 6643d46 · 1 Parent(s): ab9a192

Add the deployment option. Update the model codes.
Files changed (2):
1. birefnet.py +29 -26
2. handler.py +9 -4
birefnet.py CHANGED
@@ -615,6 +615,7 @@ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 
 # config = Config()
 
+
 class Mlp(nn.Module):
     """ Multilayer perceptron."""
 
@@ -739,7 +740,8 @@ class WindowAttention(nn.Module):
         attn = (q @ k.transpose(-2, -1))
 
         relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
-            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+        )  # Wh*Ww, Wh*Ww, nH
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
         attn = attn + relative_position_bias.unsqueeze(0)
 
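For orientation, the re-wrapped lookup above is pure shape bookkeeping: a learnable table holds one bias per relative offset per head, and a precomputed index gathers it into a per-head matrix that is added to the attention logits. A self-contained sketch with hypothetical sizes (not the model's actual config):

import torch

window_size = (7, 7)  # (Wh, Ww), hypothetical
num_heads = 4         # nH, hypothetical
# One learnable bias per relative offset per head: ((2*Wh-1)*(2*Ww-1), nH).
table = torch.randn((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
# Precomputed map from each (query, key) pair in the window to a table row: (Wh*Ww, Wh*Ww).
index = torch.randint(table.shape[0], (window_size[0] * window_size[1], window_size[0] * window_size[1]))

bias = table[index.view(-1)].view(
    window_size[0] * window_size[1], window_size[0] * window_size[1], -1
)  # Wh*Ww, Wh*Ww, nH
bias = bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww, broadcastable over attn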
@@ -974,8 +976,9 @@ class BasicLayer(nn.Module):
         """
 
         # calculate attention mask for SW-MSA
-        Hp = int(np.ceil(H / self.window_size)) * self.window_size
-        Wp = int(np.ceil(W / self.window_size)) * self.window_size
+        # Turn int into torch.tensor for compatibility with torch.compile in PyTorch 2.5.
+        Hp = torch.ceil(torch.tensor(H) / self.window_size).to(torch.int64) * self.window_size
+        Wp = torch.ceil(torch.tensor(W) / self.window_size).to(torch.int64) * self.window_size
         img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
         h_slices = (slice(0, -self.window_size),
                     slice(-self.window_size, -self.shift_size),
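Per the new comment, the motivation is torch.compile in PyTorch 2.5: int(np.ceil(...)) pins the padded size to a Python int, which does not trace well, whereas the tensor-based ceil keeps the computation expressible in the compiled graph. A minimal check that the two formulations agree numerically (H and window size are hypothetical):

import numpy as np
import torch

H, window_size = 50, 7
Hp_old = int(np.ceil(H / window_size)) * window_size                              # old: Python int
Hp_new = torch.ceil(torch.tensor(H) / window_size).to(torch.int64) * window_size  # new: 0-dim int64 tensor
assert Hp_old == Hp_new.item() == 56  # both round H up to a multiple of window_size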
@@ -1961,6 +1964,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from kornia.filters import laplacian
 from transformers import PreTrainedModel
+from einops import rearrange
 
 # from config import Config
 # from dataset import class_labels_TR_sorted
@@ -1974,13 +1978,24 @@ from transformers import PreTrainedModel
 from .BiRefNet_config import BiRefNetConfig
 
 
+def image2patches(image, grid_h=2, grid_w=2, patch_ref=None, transformation='b c (hg h) (wg w) -> (b hg wg) c h w'):
+    if patch_ref is not None:
+        grid_h, grid_w = image.shape[-2] // patch_ref.shape[-2], image.shape[-1] // patch_ref.shape[-1]
+    patches = rearrange(image, transformation, hg=grid_h, wg=grid_w)
+    return patches
+
+def patches2image(patches, grid_h=2, grid_w=2, patch_ref=None, transformation='(b hg wg) c h w -> b c (hg h) (wg w)'):
+    if patch_ref is not None:
+        grid_h, grid_w = patch_ref.shape[-2] // patches[0].shape[-2], patch_ref.shape[-1] // patches[0].shape[-1]
+    image = rearrange(patches, transformation, hg=grid_h, wg=grid_w)
+    return image
+
 class BiRefNet(
     PreTrainedModel
 ):
     config_class = BiRefNetConfig
     def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
-        super(BiRefNet, self).__init__(config)
-        bb_pretrained = config.bb_pretrained
+        super(BiRefNet, self).__init__()
         self.config = Config()
         self.epoch = 1
         self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)
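These einops helpers replace the hand-rolled get_patches_batch loop removed in the next hunk. The decoder calls image2patches with a transformation that stacks the patch grid along channels, the same role the old loop's torch.cat(..., dim=1) played (the channel ordering differs, which only permutes the input channels of the following conv). A self-contained sketch of the rearrangements, with illustrative shapes:

import torch
from einops import rearrange

x = torch.randn(2, 3, 64, 64)  # b c H W

# Default image2patches pattern: a 2x2 grid of patches stacked on the batch axis.
patches = rearrange(x, 'b c (hg h) (wg w) -> (b hg wg) c h w', hg=2, wg=2)  # [8, 3, 32, 32]

# patches2image applies the exact inverse rearrangement.
restored = rearrange(patches, '(b hg wg) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
assert torch.equal(restored, x)

# The decoder's variant stacks the grid along channels instead of batch.
stacked = rearrange(x, 'b c (hg h) (wg w) -> b (c hg wg) h w', hg=2, wg=2)  # [2, 12, 32, 32]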
@@ -2124,18 +2139,6 @@ class Decoder(nn.Module):
         self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
         self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
 
-    def get_patches_batch(self, x, p):
-        _size_h, _size_w = p.shape[2:]
-        patches_batch = []
-        for idx in range(x.shape[0]):
-            columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
-            patches_x = []
-            for column_x in columns_x:
-                patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
-            patch_sample = torch.cat(patches_x, dim=1)
-            patches_batch.append(patch_sample)
-        return torch.cat(patches_batch, dim=0)
-
     def forward(self, features):
         if self.training and self.config.out_ref:
             outs_gdt_pred = []
@@ -2146,10 +2149,10 @@ class Decoder(nn.Module):
         outs = []
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, x4) if self.split else x
+            patches_batch = image2patches(x, patch_ref=x4, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
         p4 = self.decoder_block4(x4)
-        m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision else None
+        m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p4_gdt = self.gdt_convs_4(p4)
             if self.training:
@@ -2167,10 +2170,10 @@ class Decoder(nn.Module):
         _p3 = _p4 + self.lateral_block4(x3)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p3) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p3, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
         p3 = self.decoder_block3(_p3)
-        m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
+        m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p3_gdt = self.gdt_convs_3(p3)
             if self.training:
@@ -2193,10 +2196,10 @@ class Decoder(nn.Module):
         _p2 = _p3 + self.lateral_block3(x2)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p2) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p2, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
         p2 = self.decoder_block2(_p2)
-        m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
+        m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p2_gdt = self.gdt_convs_2(p2)
             if self.training:
@@ -2214,17 +2217,17 @@ class Decoder(nn.Module):
         _p1 = _p2 + self.lateral_block2(x1)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p1) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
         _p1 = self.decoder_block1(_p1)
         _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p1) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
         p1_out = self.conv_out1(_p1)
 
-        if self.config.ms_supervision:
+        if self.config.ms_supervision and self.training:
             outs.append(m4)
             outs.append(m3)
             outs.append(m2)
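The other repeated change above gates the multi-scale side outputs m4/m3/m2 on self.training, so the conv_ms_spvn_* heads are no longer evaluated at inference. A hypothetical miniature of the pattern (not the real Decoder):

import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    def __init__(self, ms_supervision=True):
        super().__init__()
        self.ms_supervision = ms_supervision
        self.side_head = nn.Conv2d(8, 1, 1)  # stand-in for conv_ms_spvn_*
        self.out_head = nn.Conv2d(8, 1, 1)   # stand-in for conv_out1

    def forward(self, p):
        outs = []
        # Side outputs exist only while training; eval() skips the heads entirely.
        m = self.side_head(p) if self.ms_supervision and self.training else None
        if self.ms_supervision and self.training:
            outs.append(m)
        outs.append(self.out_head(p))
        return outs

dec = TinyDecoder().eval()
assert len(dec(torch.randn(1, 8, 16, 16))) == 1  # only the final map at inference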
 
handler.py CHANGED
@@ -1,7 +1,11 @@
 # These HF deployment codes refer to https://huggingface.co/not-lain/BiRefNet/raw/main/handler.py.
 from typing import Dict, List, Any, Tuple
-import base64
+import os
+import requests
 from io import BytesIO
+import cv2
+import numpy as np
+from PIL import Image
 import torch
 from torchvision import transforms
 from transformers import AutoModelForImageSegmentation
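base64 gives way to imports for fetching and decoding images (os, requests, cv2, numpy, PIL); the call sites are outside this hunk, so the following helper is only a plausible sketch of what the new imports support, not code from the handler:

import os
import requests
from io import BytesIO
from PIL import Image

def load_image(src: str) -> Image.Image:
    # Hypothetical: accept a local file path or a URL and return an RGB PIL image.
    if os.path.exists(src):
        return Image.open(src).convert('RGB')
    response = requests.get(src, timeout=30)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert('RGB')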
@@ -70,14 +74,15 @@ usage_to_weights_file = {
     'General-legacy': 'BiRefNet-legacy'
 }
 
-birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
+usage = 'General'
+birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file[usage])), trust_remote_code=True)
 birefnet.to(device)
 birefnet.eval()
 
 # Set resolution
-if weights_file in ['General-Lite-2K']:
+if usage in ['General-Lite-2K']:
     resolution = (2560, 1440)
-elif weights_file in ['General-reso_512']:
+elif usage in ['General-reso_512']:
     resolution = (512, 512)
 else:
     resolution = (1024, 1024)
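The hard-coded 'General' key becomes a usage variable (also fixing the undefined weights_file in the resolution branch); this is the deployment option the commit message referss to: choosing a different key of usage_to_weights_file selects both the checkpoint and, via the branch above, the working resolution. A hypothetical variant selection, reusing the handler's imports and dict:

usage = 'General-Lite-2K'  # any key of usage_to_weights_file
birefnet = AutoModelForImageSegmentation.from_pretrained(
    '/'.join(('zhengpeng7', usage_to_weights_file[usage])), trust_remote_code=True
)
# With this key the branch above would pick resolution = (2560, 1440).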