ritutweets46 and Aditibaheti committed
Commit c167a04 · verified · 1 Parent(s): 4ac8a6d

8 bit onxx (#5)


- 8 bit onxx (a35a0a82d6202fd081cef55fdf17fadf23a43395)


Co-authored-by: Aditi Baheti <[email protected]>

Files changed (1)
  1. app.py +20 -1
app.py CHANGED
@@ -5,6 +5,10 @@ from diffusers import DiffusionPipeline
 import torch
 from huggingface_hub import login
 import os
+import bitsandbytes as bnb
+import onnx
+import onnxruntime as ort
+from onnxruntime.quantization import quantize_dynamic, QuantType
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -16,17 +20,32 @@ login(token=HUGGINGFACE_TOKEN)
 base_model_repo = "stabilityai/stable-diffusion-3-medium-diffusers"
 lora_weights_path = "./pytorch_lora_weights.safetensors"
 
-# Load the base model
+# Load the base model with 8-bit precision
 pipeline = DiffusionPipeline.from_pretrained(
     base_model_repo,
     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
     use_auth_token=HUGGINGFACE_TOKEN
 )
+bnb.optim.load_int8_model(pipeline.model, device=device)
+
 pipeline.load_lora_weights(lora_weights_path)
 pipeline.enable_sequential_cpu_offload()  # Efficient memory usage
 pipeline.enable_xformers_memory_efficient_attention()  # Enable xformers memory efficient attention
 pipeline = pipeline.to(device)
 
+# Export to ONNX
+onnx_model_path = "model.onnx"
+pipeline.model.eval()
+dummy_input = torch.randn(1, 3, 512, 512, device=device)
+torch.onnx.export(pipeline.model, dummy_input, onnx_model_path, export_params=True, opset_version=11, do_constant_folding=True, input_names=['input'], output_names=['output'])
+
+# Quantize ONNX model to 8-bit
+quantized_model_path = "model_quantized.onnx"
+quantize_dynamic(onnx_model_path, quantized_model_path, weight_type=QuantType.QUInt8)
+
+# Load quantized ONNX model
+session = ort.InferenceSession(quantized_model_path)
+
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 768  # Reduce max image size to fit within memory constraints
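
For context on the 8-bit step: bitsandbytes does not expose a documented bnb.optim.load_int8_model helper (bnb.optim holds its 8-bit optimizers); the pattern its authors describe for 8-bit weights is swapping nn.Linear modules for bnb.nn.Linear8bitLt, with quantization happening when the module is moved to the GPU. A minimal sketch of that module-swap technique, assuming an SD3-style pipeline whose denoiser lives at pipeline.transformer (the helper name and the target module are illustrative assumptions, not code from this repo):

import torch.nn as nn
import bitsandbytes as bnb

def swap_linear_for_int8(module: nn.Module, threshold: float = 6.0) -> None:
    # Recursively replace every nn.Linear with an 8-bit Linear8bitLt.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            int8_linear = bnb.nn.Linear8bitLt(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                has_fp16_weights=False,  # store weights in int8, not fp16
                threshold=threshold,     # outlier threshold from the LLM.int8() paper
            )
            # Copy the fp weights; they are quantized on the move to CUDA.
            int8_linear.load_state_dict(child.state_dict())
            setattr(module, name, int8_linear)
        else:
            swap_linear_for_int8(child, threshold)

# Hypothetical usage -- quantization triggers when the module moves to CUDA:
#   swap_linear_for_int8(pipeline.transformer)
#   pipeline.transformer.to("cuda")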
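
The diff also creates an ONNX Runtime session but never runs it. A sketch of checking and exercising the quantized graph, assuming the export really produced a model with one float32 input named 'input' of shape 1x3x512x512, as the dummy export input suggests:

import numpy as np
import onnx
import onnxruntime as ort

quantized_model_path = "model_quantized.onnx"

# Structural sanity check of the quantized graph (the diff's otherwise
# unused `import onnx` covers exactly this kind of step).
onnx.checker.check_model(onnx.load(quantized_model_path))

# Run one inference through the quantized model.
session = ort.InferenceSession(quantized_model_path)
dummy = np.random.randn(1, 3, 512, 512).astype(np.float32)
outputs = session.run(None, {"input": dummy})  # None = return all outputs
print(outputs[0].shape)

Note that quantize_dynamic converts weights to 8-bit ahead of time and derives activation scales at run time, so it needs no calibration data; that is what makes it the lightest-touch quantization entry point in onnxruntime.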