ChenyangSi commited on
Commit
b1c39d5
1 Parent(s): 64efbd0

Update free_lunch_utils.py

Browse files
Files changed (1) hide show
  1. free_lunch_utils.py +4 -4
free_lunch_utils.py CHANGED
@@ -97,10 +97,10 @@ def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
97
  # Only operate on the first two stages
98
  if hidden_states.shape[1] == 1280:
99
  hidden_states[:,:640] = hidden_states[:,:640] * self.b1
100
- res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
101
  if hidden_states.shape[1] == 640:
102
  hidden_states[:,:320] = hidden_states[:,:320] * self.b2
103
- res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
104
  # ---------------------------------------------------------
105
 
106
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -235,10 +235,10 @@ def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
235
  # Only operate on the first two stages
236
  if hidden_states.shape[1] == 1280:
237
  hidden_states[:,:640] = hidden_states[:,:640] * self.b1
238
- res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
239
  if hidden_states.shape[1] == 640:
240
  hidden_states[:,:320] = hidden_states[:,:320] * self.b2
241
- res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
242
  # ---------------------------------------------------------
243
 
244
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
97
  # Only operate on the first two stages
98
  if hidden_states.shape[1] == 1280:
99
  hidden_states[:,:640] = hidden_states[:,:640] * self.b1
100
+ # # res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
101
  if hidden_states.shape[1] == 640:
102
  hidden_states[:,:320] = hidden_states[:,:320] * self.b2
103
+ # res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
104
  # ---------------------------------------------------------
105
 
106
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
235
  # Only operate on the first two stages
236
  if hidden_states.shape[1] == 1280:
237
  hidden_states[:,:640] = hidden_states[:,:640] * self.b1
238
+ # res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
239
  if hidden_states.shape[1] == 640:
240
  hidden_states[:,:320] = hidden_states[:,:320] * self.b2
241
+ # res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
242
  # ---------------------------------------------------------
243
 
244
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)