atossou committed on
Commit
c09c4de
·
verified ·
1 Parent(s): fd3f0e7

Create convert_to_external_data.py

Browse files
Files changed (1) hide show
  1. onnx/convert_to_external_data.py +71 -0
onnx/convert_to_external_data.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnx
2
+ import pathlib
3
+ import glob
4
+ import os
5
+ import json
6
+ import tempfile
7
+ import shutil
8
+ import typer
9
+ from typing import Annotated
10
+
11
+ # To use:
12
+ # pip install onnx typer
13
+ # python convert_to_external_data.py --base-path "/path/to/directory"
14
+
15
def convert(model_path, save_path):
    """Re-save an ONNX model so that its tensors live in an external data file.

    The model at ``model_path`` is loaded and saved again with
    ``save_as_external_data=True``. The output is the model file named by
    ``save_path`` plus a ``<stem>.onnx_data`` file next to it, where
    ``<stem>`` is the stem of ``save_path``.

    Args:
        model_path: Path of the ONNX model to load.
        save_path: Path the converted model is written to. It may equal
            ``model_path``, in which case the existing model file is
            overwritten by the converted one.
    """
    model = onnx.load(model_path)

    # Name both output files after save_path (not model_path) so that the
    # file the caller asked for is the one that actually gets created when
    # the two paths have different file names.
    destination = pathlib.Path(save_path)
    external_data_name = f"{destination.stem}.onnx_data"

    # Create the new model in a temporary directory and copy all its contents
    # back next to save_path. Doing this because if save_path is the same as
    # model_path and we write directly to model_path, onnx will append to the
    # external data path, which would make it grow more than expected.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_model_path = os.path.join(tmp_dir, destination.name)

        onnx.save_model(
            model,
            tmp_model_path,
            save_as_external_data=True,
            location=external_data_name,
        )

        target_dir = str(destination.parent)
        os.makedirs(target_dir, exist_ok=True)
        # The temporary directory holds the model file and its external data
        # file; copy everything that was produced.
        for file_name in os.listdir(tmp_dir):
            shutil.copy2(os.path.join(tmp_dir, file_name), target_dir)
36
+
37
+
38
def main(base_path: Annotated[str, typer.Option()]):
    """
    Recursively convert all ONNX models under the given directory to the external data format.

    Every config.json found under the same directory also gets
    "use_external_data_format" enabled in its "transformers.js_config" section.
    """
    # Escape glob metacharacters so that a base path containing characters
    # such as "[", "*" or "?" is matched literally instead of as a pattern.
    search_root = glob.escape(base_path)

    # Convert all models in place (load path and save path are the same).
    for model_path in glob.glob(
        os.path.join(search_root, "**/*.onnx"),
        recursive=True,
    ):
        print("Converting", model_path)
        convert(model_path, model_path)

    # Find every config.json and enable use_external_data_format.
    for config_path in glob.glob(
        os.path.join(search_root, "**/config.json"),
        recursive=True,
    ):
        print("Modifying", config_path)

        # Read and write explicitly as UTF-8: the file is dumped with
        # ensure_ascii=False, so relying on the platform default encoding
        # could fail or corrupt non-ASCII text.
        with open(config_path, "r", encoding="utf-8") as infile:
            config_data = json.load(infile)

        # Create the "transformers.js_config" section if it is missing and
        # keep any settings that are already there.
        js_config = config_data.setdefault("transformers.js_config", {})
        js_config["use_external_data_format"] = True

        # Save the JSON file with the additional config.
        with open(config_path, "w", encoding="utf-8") as outfile:
            json.dump(config_data, outfile, indent=4, ensure_ascii=False)
68
+
69
+
70
# Script entry point: hand `main` to Typer only when this file is executed
# directly, so importing the module does not start the command-line run.
if __name__ == "__main__":
    typer.run(main)