eagle0504 commited on
Commit
ce44247
1 Parent(s): d6769dc

Update utils/helper_functions.py

Browse files
Files changed (1) hide show
  1. utils/helper_functions.py +2 -1
utils/helper_functions.py CHANGED
@@ -277,11 +277,12 @@ def quantized_influence(arr1: np.ndarray, arr2: np.ndarray, k: int = 16, use_dag
277
  unique_values = np.unique(arr1_quantized)
278
 
279
  # Compute the global average of quantized arr2
 
280
  y_bar_global = np.mean(arr2_quantized)
281
 
282
  # Compute weighted local averages and normalize
283
  weighted_local_averages = [(np.mean(arr2_quantized[arr1_quantized == val]) - y_bar_global)**2 * len(arr2_quantized[arr1_quantized == val])**2 for val in unique_values]
284
- qim = np.mean(weighted_local_averages) / np.std(arr2_quantized) # Calculate the quantized influence measure
285
 
286
  if use_dagger:
287
  # If use_dagger is True, compute local estimates and map them to unique quantized values
 
277
  unique_values = np.unique(arr1_quantized)
278
 
279
  # Compute the global average of quantized arr2
280
+ total_samples = len(arr2_quantized)
281
  y_bar_global = np.mean(arr2_quantized)
282
 
283
  # Compute weighted local averages and normalize
284
  weighted_local_averages = [(np.mean(arr2_quantized[arr1_quantized == val]) - y_bar_global)**2 * len(arr2_quantized[arr1_quantized == val])**2 for val in unique_values]
285
+ qim = np.sum(weighted_local_averages) / (total_samples * np.std(arr2_quantized)) # Calculate the quantized influence measure
286
 
287
  if use_dagger:
288
  # If use_dagger is True, compute local estimates and map them to unique quantized values