Does this work and how did you achieve this?
#1
by
belisarius
- opened
I would like to try scaling down the Union ControlNet, but what I tried didn't work. What did you do to convert it to fp8, and is the source available?
I would like to try scaling down the Union ControlNet, but what I tried didn't work. What did you do to convert it to fp8, and is the source available?
import torch, os
from safetensors.torch import load_file, save_file
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
    """Return the stochastically rounded fractional mantissa of each magnitude.

    Entries where ``normal_mask`` is true are treated as normal numbers whose
    biased exponent is ``exponent``. The implicit leading 1 is removed and the
    remainder is counted in steps of ``2**-MANTISSA_BITS``. All other entries
    are treated as subnormals and counted in steps of
    ``2**(1 - EXPONENT_BIAS - MANTISSA_BITS)``.

    Uniform noise in [0, 1) from ``torch.rand`` (drawn with ``generator``) is
    added before flooring. Each count therefore rounds up with probability
    equal to its fractional part (stochastic rounding).

    Args:
        abs_x: Non-negative magnitudes to encode.
        exponent: Biased exponent per element (same shape as ``abs_x``).
        normal_mask: Boolean tensor, true where the element is a normal number.
        MANTISSA_BITS: Number of explicit mantissa bits of the target format.
        EXPONENT_BIAS: Exponent bias of the target format.
        generator: Optional ``torch.Generator`` for reproducible noise.

    Returns:
        Tensor of the same shape: the rounded step count divided by
        ``2**MANTISSA_BITS``.
    """
    steps_per_unit = 2 ** MANTISSA_BITS
    normal_scale = 2.0 ** (exponent - EXPONENT_BIAS)
    subnormal_step = 2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)

    # Both branches are evaluated; torch.where picks per element.
    step_count = torch.where(
        normal_mask,
        (abs_x / normal_scale - 1.0) * steps_per_unit,
        abs_x / subnormal_step,
    )
    noise = torch.rand(
        step_count.size(),
        dtype=step_count.dtype,
        layout=step_count.layout,
        device=step_count.device,
        generator=generator,
    )
    return torch.floor(step_count + noise) / steps_per_unit
def manual_stochastic_round_to_float8(x, dtype, generator=None):
if dtype == torch.float8_e4m3fn:
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
elif dtype == torch.float8_e5m2:
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
else:
raise ValueError("Unsupported dtype")
x = x.half()
sign = torch.sign(x)
abs_x = x.abs()
sign = torch.where(abs_x == 0, 0, sign)
# Combine exponent calculation and clamping
exponent = torch.clamp(
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
0, 2**EXPONENT_BITS - 1
)
# Combine mantissa calculation and rounding
normal_mask = ~(exponent == 0)
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
sign *= torch.where(
normal_mask,
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
)
return sign
def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.float32:
return value.to(dtype=torch.float32)
if dtype == torch.float16:
return value.to(dtype=torch.float16)
if dtype == torch.bfloat16:
return value.to(dtype=torch.bfloat16)
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
generator = torch.Generator(device=value.device)
generator.manual_seed(seed)
output = torch.empty_like(value, dtype=dtype)
num_slices = max(1, (value.numel() / (4096 * 4096)))
slice_size = max(1, round(value.shape[0] / num_slices))
for i in range(0, value.shape[0], slice_size):
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
return output
return value.to(dtype=dtype)
def process_model(input_file, full_weights=None):
    """Convert a safetensors checkpoint to float8_e4m3fn and save it alongside.

    The result is written to ``fp8-<filename>`` in the same directory as
    ``input_file``. Each converted tensor is cast to float32 and then
    stochastically rounded to ``torch.float8_e4m3fn`` with seed 42.
    Tensors that are not converted are saved unchanged.

    Args:
        input_file: Path to the source ``.safetensors`` file.
        full_weights: If True, convert every floating-point tensor. If False,
            convert only tensors whose key contains ``'weight'``. If None
            (the default), use ``args.full_weights`` from the module-level
            ``args`` set by the command-line entry point, or False when no
            such ``args`` exists (e.g. when this module is imported).
    """
    if full_weights is None:
        cli_args = globals().get("args")
        full_weights = bool(getattr(cli_args, "full_weights", False))

    directory, filename = os.path.split(input_file)
    output_file = os.path.join(directory, f"fp8-{filename}")
    # Prefer the GPU when present, but still work on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    state_dict = load_file(input_file)
    for key, tensor in state_dict.items():
        if not full_weights and "weight" not in key:
            continue
        # Integer/bool tensors must keep their exact values; casting them
        # to fp8 would corrupt them.
        if not tensor.is_floating_point():
            continue
        tensor = tensor.to(device).to(torch.float32)
        # Move each result back to CPU so converted tensors do not
        # accumulate in GPU memory until the final save.
        state_dict[key] = stochastic_rounding(tensor, torch.float8_e4m3fn, seed=42).cpu()
    save_file(state_dict, output_file)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Process some model paths.')
    parser.add_argument('file', type=str, help='model weights path')
    parser.add_argument('--full_weights', action='store_true', help='convert all weights if set')
    # `args` must stay a module-level name: process_model reads
    # args.full_weights from it.
    args = parser.parse_args()
    model_file = args.file
    print(f"Processing model: {model_file}")
    process_model(model_file)
    print(f"Processed and saved: fp8-{os.path.basename(model_file)}")
    try:
        # Keep a console window open (e.g. when launched by double-click).
        input('Press Enter for exit...')
    except EOFError:
        # stdin is closed or non-interactive (pipelines, CI). The conversion
        # has already succeeded, so exit normally instead of with a traceback.
        pass