subaqua commited on
Commit
425d93d
1 Parent(s): b678a86

Update as_safetensors+fp16_en.ipynb

Browse files
Files changed (1) hide show
  1. as_safetensors+fp16_en.ipynb +78 -96
as_safetensors+fp16_en.ipynb CHANGED
@@ -14,6 +14,17 @@
14
  }
15
  },
16
  "cells": [
 
 
 
 
 
 
 
 
 
 
 
17
  {
18
  "cell_type": "markdown",
19
  "source": [
@@ -39,6 +50,7 @@
39
  "cell_type": "code",
40
  "source": [
41
  "!pip install torch safetensors\n",
 
42
  "!pip install wget"
43
  ],
44
  "metadata": {
@@ -63,27 +75,20 @@
63
  "#@markdown Please specify the model name or download link for Google Drive, separated by commas\n",
64
  "#@markdown - If it is the model name on Google Drive, specify it as a relative path to My Drive\n",
65
  "#@markdown - If it is a download link, copy the link address by right-clicking and paste it in place of the link below\n",
66
- "\n",
67
  "import shutil\n",
68
  "import urllib.parse\n",
69
  "import urllib.request\n",
70
  "import wget\n",
71
  "\n",
72
- "models = \"Please use your own model in place of this example, example.safetensors, https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.ckpt\" #@param {type:\"string\"}\n",
73
- "models = [m.strip() for m in models.split(\",\") if not models == \"\"]\n",
74
  "for model in models:\n",
75
  " if 0 < len(urllib.parse.urlparse(model).scheme): # if model is url\n",
76
  " wget.download(model)\n",
77
- " # once the bug on python 3.8 is fixed, replace the above code with the following code\n",
78
- " ## model_data = urllib.request.urlopen(model).read()\n",
79
- " ## with open(os.path.basename(model), mode=\"wb\") as f:\n",
80
- " ## f.write(model_data)\n",
81
- " elif model.endswith((\".ckpt\", \".safetensors\", \".pt\", \".pth\")):\n",
82
- " from_ = \"/content/drive/MyDrive/\" + model\n",
83
- " to_ = \"/content/\" + model\n",
84
- " shutil.copy(from_, to_)\n",
85
  " else:\n",
86
- " print(f\"\\\"{model}\\\"URLではなく、正しい形式のファイルでもありません\")"
87
  ],
88
  "metadata": {
89
  "cellView": "form",
@@ -92,66 +97,65 @@
92
  "execution_count": null,
93
  "outputs": []
94
  },
95
- {
96
- "cell_type": "markdown",
97
- "source": [
98
- "if you use a relatively newer model such as SD2.1, run the following code"
99
- ],
100
- "metadata": {
101
- "id": "m1mHzOMjcDhz"
102
- }
103
- },
104
- {
105
- "cell_type": "code",
106
- "source": [
107
- "!pip install pytorch-lightning"
108
- ],
109
- "metadata": {
110
- "id": "TkrmByc0aYVN"
111
- },
112
- "execution_count": null,
113
- "outputs": []
114
- },
115
- {
116
- "cell_type": "markdown",
117
- "source": [
118
- "Run either of the following two codes. If you run out of memory and crash, use a smaller model or a paid high-memory runtime"
119
- ],
120
- "metadata": {
121
- "id": "0SUK6Alv2ItS"
122
- }
123
- },
124
  {
125
  "cell_type": "code",
126
  "source": [
127
- "#@title <font size=\"-0\">If you specify the name of the model you want to convert and convert it manually</font>\n",
 
 
128
  "import os\n",
 
129
  "import torch\n",
130
  "import safetensors.torch\n",
 
131
  "\n",
132
- "model = \"v2-1_768-ema-pruned.ckpt\" #@param {type:\"string\"}\n",
133
- "model_name, model_ext = os.path.splitext(model)\n",
134
  "as_fp16 = True #@param {type:\"boolean\"}\n",
135
  "save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n",
 
136
  "\n",
137
- "with torch.no_grad():\n",
138
- " if model_ext == \".safetensors\":\n",
139
- " weights = safetensors.torch.load_file(model_name + model_ext, device=\"cpu\")\n",
140
- " elif model_ext == \".ckpt\":\n",
141
- " weights = torch.load(model_name + model_ext, map_location=torch.device('cpu'))[\"state_dict\"]\n",
142
- " else:\n",
143
- " raise Exception(\"対応形式は.ckptと.safetensorsです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
144
- " if as_fp16:\n",
145
- " model_name = model_name + \"-fp16\"\n",
146
- " for key in weights.keys():\n",
147
- " weights[key] = weights[key].half()\n",
148
  " if save_directly_to_Google_Drive:\n",
149
  " os.chdir(\"/content/drive/MyDrive\")\n",
150
- " safetensors.torch.save_file(weights, model_name + \".safetensors\")\n",
151
- " os.chdir(\"/content\")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  " else:\n",
153
- " safetensors.torch.save_file(weights, model_name + \".safetensors\")\n",
154
- " del weights\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  "\n",
156
  "!reset"
157
  ],
@@ -163,48 +167,26 @@
163
  "outputs": []
164
  },
165
  {
166
- "cell_type": "code",
167
  "source": [
168
- "#@title <font size=\"-0\">If you automatically convert all pre-loaded models</font>\n",
169
- "import os\n",
170
- "import glob\n",
171
- "import torch\n",
172
- "import safetensors.torch\n",
173
- "\n",
174
- "as_fp16 = True #@param {type:\"boolean\"}\n",
175
- "save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n",
176
  "\n",
177
- "with torch.no_grad():\n",
178
- " model_paths = glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.pt\") + glob.glob(r\"/content/*.pth\")\n",
179
- " for model_path in model_paths:\n",
180
- " model_name, model_ext = os.path.splitext(os.path.basename(model_path))\n",
181
- " if model_ext == \".safetensors\":\n",
182
- " weights = safetensors.torch.load_file(model_name + model_ext, device=\"cpu\")\n",
183
- " elif model_ext == \".ckpt\":\n",
184
- " weights = torch.load(model_name + model_ext, map_location=torch.device('cpu'))[\"state_dict\"]\n",
185
- " else:\n",
186
- " print(\"対応形式は.ckpt\tと.safetensorsです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
187
- " break\n",
188
- " if as_fp16:\n",
189
- " model_name = model_name + \"-fp16\"\n",
190
- " for key in weights.keys():\n",
191
- " weights[key] = weights[key].half()\n",
192
- " if save_directly_to_Google_Drive:\n",
193
- " os.chdir(\"/content/drive/MyDrive\")\n",
194
- " safetensors.torch.save_file(weights, model_name + \".safetensors\")\n",
195
- " os.chdir(\"/content\")\n",
196
- " else:\n",
197
- " safetensors.torch.save_file(weights, model_name + \".safetensors\")\n",
198
- " del weights\n",
199
  "\n",
200
- "!reset"
201
  ],
202
  "metadata": {
203
- "id": "5TUvrW5VzLst",
204
- "cellView": "form"
205
- },
206
- "execution_count": null,
207
- "outputs": []
208
  },
209
  {
210
  "cell_type": "markdown",
 
14
  }
15
  },
16
  "cells": [
17
+ {
18
+ "cell_type": "markdown",
19
+ "source": [
20
+ "### This is a script that converts the format of the model (.ckpt/.safetensors)\n",
21
+ "#### It also converts the .yaml file included with the SD2.x series\n",
22
+ "#### It can also be saved as fp16 as an option"
23
+ ],
24
+ "metadata": {
25
+ "id": "fAIY_GORNEYa"
26
+ }
27
+ },
28
  {
29
  "cell_type": "markdown",
30
  "source": [
 
50
  "cell_type": "code",
51
  "source": [
52
  "!pip install torch safetensors\n",
53
+ "!pip install pytorch-lightning\n",
54
  "!pip install wget"
55
  ],
56
  "metadata": {
 
75
  "#@markdown Please specify the model name or download link for Google Drive, separated by commas\n",
76
  "#@markdown - If it is the model name on Google Drive, specify it as a relative path to My Drive\n",
77
  "#@markdown - If it is a download link, copy the link address by right-clicking and paste it in place of the link below\n",
 
78
  "import shutil\n",
79
  "import urllib.parse\n",
80
  "import urllib.request\n",
81
  "import wget\n",
82
  "\n",
83
+ "models = \"Specify_the_model_in_this_way_if_the_model_is_on_My_Drive.safetensors, https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt, https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
84
+ "models = [m.strip() for m in models.split(\",\")]\n",
85
  "for model in models:\n",
86
  " if 0 < len(urllib.parse.urlparse(model).scheme): # if model is url\n",
87
  " wget.download(model)\n",
88
+ " elif model.endswith((\".ckpt\", \".safetensors\")):\n",
89
+ " shutil.copy(\"/content/drive/MyDrive/\" + model, \"/content/\" + model) # get the model from mydrive\n",
 
 
 
 
 
 
90
  " else:\n",
91
+ " print(f\"\\\"{model}\\\" is not a URL and is also not a file with a proper extension\")"
92
  ],
93
  "metadata": {
94
  "cellView": "form",
 
97
  "execution_count": null,
98
  "outputs": []
99
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  {
101
  "cell_type": "code",
102
  "source": [
103
+ "#@title <font size=\"-0\">Convert the Models</font>\n",
104
+ "#@markdown Specify the models to be converted, separated by commas<br>\n",
105
+ "#@markdown If nothing is inputted, all loaded models will be converted\n",
106
  "import os\n",
107
+ "import glob\n",
108
  "import torch\n",
109
  "import safetensors.torch\n",
110
+ "from functools import partial\n",
111
  "\n",
112
+ "models = \"wd-1-4-anime_e1.ckpt, wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
 
113
  "as_fp16 = True #@param {type:\"boolean\"}\n",
114
  "save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n",
115
+ "save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n",
116
  "\n",
117
+ "def convert_yaml(file_name):\n",
118
+ " with open(file_name) as f:\n",
119
+ " yaml = f.read()\n",
 
 
 
 
 
 
 
 
120
  " if save_directly_to_Google_Drive:\n",
121
  " os.chdir(\"/content/drive/MyDrive\")\n",
122
+ " is_safe = save_type == \".safetensors\"\n",
123
+ " yaml = yaml.replace(f\"use_checkpoint: {is_safe}\", f\"use_checkpoint: {not is_safe}\")\n",
124
+ " if as_fp16:\n",
125
+ " yaml = yaml.replace(\"use_fp16: False\", \"use_fp16: True\")\n",
126
+ " file_name = os.path.splitext(file_name)[0] + \"-fp16.yaml\"\n",
127
+ " with open(file_name, mode=\"w\") as f:\n",
128
+ " f.write(yaml)\n",
129
+ " os.chdir(\"/content\")\n",
130
+ "\n",
131
+ "if models == \"\":\n",
132
+ " models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\")]\n",
133
+ "else:\n",
134
+ " models = [m.strip() for m in models.split(\",\")]\n",
135
+ "\n",
136
+ "for model in models:\n",
137
+ " model_name, model_ext = os.path.splitext(model)\n",
138
+ " if model_ext == \".yaml\":\n",
139
+ " convert_yaml(model)\n",
140
+ " elif (model_ext != \".safetensors\") & (model_ext != \".ckpt\"):\n",
141
+ " print(\"The supported formats are only .ckpt, .safetensors, and .yaml\\n\" + f\"\\\"{model}\\\" is not a supported format\")\n",
142
  " else:\n",
143
+ " load_model = partial(safetensors.torch.load_file, device=\"cpu\") if model_ext == \".safetensors\" else partial(torch.load, map_location=torch.device(\"cpu\"))\n",
144
+ " save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
145
+ " # convert model\n",
146
+ " with torch.no_grad():\n",
147
+ " weights = load_model(model)\n",
148
+ " if \"state_dict\" in weights:\n",
149
+ " weights = weights[\"state_dict\"]\n",
150
+ " if as_fp16:\n",
151
+ " model_name = model_name + \"-fp16\"\n",
152
+ " for key in weights.keys():\n",
153
+ " weights[key] = weights[key].half()\n",
154
+ " if save_directly_to_Google_Drive:\n",
155
+ " os.chdir(\"/content/drive/MyDrive\")\n",
156
+ " save_model(weights, model_name + save_type)\n",
157
+ " os.chdir(\"/content\")\n",
158
+ " del weights\n",
159
  "\n",
160
  "!reset"
161
  ],
 
167
  "outputs": []
168
  },
169
  {
170
+ "cell_type": "markdown",
171
  "source": [
172
+ "If you are converting SD2.x series models, etc., be sure to download/convert the accompanying configuration file (a .yaml file with the same name as the model) at the same time.\n",
 
 
 
 
 
 
 
173
  "\n",
174
+ "It can be converted in the same way as the model."
175
+ ],
176
+ "metadata": {
177
+ "id": "SWTFKmGFLec6"
178
+ }
179
+ },
180
+ {
181
+ "cell_type": "markdown",
182
+ "source": [
183
+ "If you run out of memory and crash, you can use a smaller model or a paid high memory runtime.\n",
 
 
 
 
 
 
 
 
 
 
 
 
184
  "\n",
185
+ "With the free ~12GB runtime, you can convert models up to ~10GB."
186
  ],
187
  "metadata": {
188
+ "id": "0SUK6Alv2ItS"
189
+ }
 
 
 
190
  },
191
  {
192
  "cell_type": "markdown",