Commit ba99fa4 · committed by zww
Parent(s): 4929721

functional handler with gcs support

Files changed:
- Dockerfile +30 -0
- replace_handler.ipynb +143 -0
- stable_diffusion_handler.py +6 -5
Dockerfile ADDED
@@ -0,0 +1,30 @@
+FROM pytorch/torchserve:latest-gpu
+
+# set user root
+USER root
+
+RUN pip install --upgrade pip
+
+# Install dependencies
+RUN pip install diffusers transformers accelerate invisible-watermark nvgpu google-cloud-storage tensorrt
+
+# Copying model files
+COPY ./config.properties /home/model-server/config.properties
+COPY ./sketch-model-3.mar /home/model-server/sketch-model-3.mar
+
+USER model-server
+
+# Expose health and prediction listener ports from the image
+EXPOSE 7080
+EXPOSE 7081
+
+# # Generate MAR file
+
+CMD ["torchserve", \
+     "--start", \
+     "--ts-config=/home/model-server/config.properties", \
+     "--models", \
+     "sketch-model-3.mar", \
+     "--model-store", \
+     "/home/model-server"]
+
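The image exposes 7080 and 7081 as the prediction and management ports. A minimal sketch of calling the served model once the container is running, assuming the default TorchServe /predictions/<model_name> route, that 7080 is the inference port configured in config.properties (which is not part of this commit), and a simple text-prompt payload:

import requests

# Assumed endpoint: TorchServe serves registered models at /predictions/<model_name>;
# the host, port mapping, and payload shape here are illustrative assumptions.
PREDICT_URL = "http://localhost:7080/predictions/sketch-model-3"

response = requests.post(PREDICT_URL, json={"data": "a pencil sketch of a lighthouse"})
response.raise_for_status()

# With the handler in this commit, postprocess() returns public GCS URLs
# for the generated images rather than raw image bytes.
print(response.json())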
replace_handler.ipynb ADDED
@@ -0,0 +1,143 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "db06489e-2e2e-4d7f-bfab-13cf42261688",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "im here\n"
+     ]
+    }
+   ],
+   "source": [
+    "print('im here')\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "9217fed7-ed85-4592-a480-ef15f0632501",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/home/jupyter\n"
+     ]
+    }
+   ],
+   "source": [
+    "cd .."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "502b64af-59f1-4ae9-89f3-378458ba52be",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "replacing old file\n",
+      "Untitled.ipynb\n",
+      "README.md\n",
+      "config.properties\n",
+      ".gitattributes\n",
+      "model_index.json\n",
+      "Untitled1.ipynb\n",
+      "Dockerfile\n",
+      "stable_diffusion_handler.py\n",
+      "replacing stable_diffusion_handler.py\n",
+      "tokenizer_2/tokenizer_config.json\n",
+      "tokenizer_2/special_tokens_map.json\n",
+      "tokenizer_2/vocab.json\n",
+      "tokenizer_2/merges.txt\n",
+      "text_encoder/config.json\n",
+      "text_encoder/pytorch_model.bin\n",
+      "text_encoder_2/config.json\n",
+      "text_encoder_2/pytorch_model.bin\n",
+      "unet/config.json\n",
+      "unet/diffusion_pytorch_model.bin\n",
+      "tokenizer/tokenizer_config.json\n",
+      "tokenizer/special_tokens_map.json\n",
+      "tokenizer/vocab.json\n",
+      "tokenizer/merges.txt\n",
+      "scheduler/scheduler_config.json\n",
+      "vae/config.json\n",
+      "vae/diffusion_pytorch_model.bin\n",
+      "MAR-INF/MANIFEST.json\n"
+     ]
+    }
+   ],
+   "source": [
+    "import zipfile\n",
+    "\n",
+    "# Path to the existing .mar file and the file to be replaced\n",
+    "mar_path = \"export/sketch-model-3.mar\"\n",
+    "file_to_replace = \"stable_diffusion_handler.py\"\n",
+    "new_file_path = \"sketch-model-3/stable_diffusion_handler.py\"\n",
+    "\n",
+    "# # Create a temporary .mar file\n",
+    "temp_mar_path = \"sketch-model-3-updated.mar\"\n",
+    "\n",
+    "print(\"replacing old file\")\n",
+    "# Open the existing .mar and the temporary .mar\n",
+    "with zipfile.ZipFile(mar_path, 'r') as zip_ref, zipfile.ZipFile(temp_mar_path, 'w', zipfile.ZIP_STORED) as new_zip:\n",
+    "    # Loop through existing files\n",
+    "    for item in zip_ref.infolist():\n",
+    "        print(item.filename)\n",
+    "        buffer = zip_ref.read(item.filename)\n",
+    "\n",
+    "        # Replace the file if it matches the target file name\n",
+    "        if item.filename == file_to_replace:\n",
+    "            print('replacing ', item.filename)\n",
+    "            with open(new_file_path, \"rb\") as f:\n",
+    "                new_buffer = f.read()\n",
+    "            new_zip.writestr(item, new_buffer)\n",
+    "        else:\n",
+    "            new_zip.writestr(item, buffer)\n",
+    "\n",
+    "print('done')\n",
+    "# Remove the original .mar file and replace it with the new one\n",
+    "import os\n",
+    "os.remove(mar_path)\n",
+    "os.rename(temp_mar_path, mar_path)"
+   ]
+  }
+ ],
+ "metadata": {
+  "environment": {
+   "kernel": "conda-root-py",
+   "name": "workbench-notebooks.m110",
+   "type": "gcloud",
+   "uri": "gcr.io/deeplearning-platform-release/workbench-notebooks:m110"
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "conda-root-py"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
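Because Python's zipfile module cannot replace a member of an existing archive in place, the notebook rebuilds the whole .mar into a temporary archive and then swaps it over the original. A small sketch, reusing the paths from the notebook cell above, to check that the updated handler really ended up inside the rebuilt archive:

import zipfile

# Paths assumed to match the notebook cell above.
mar_path = "export/sketch-model-3.mar"
new_file_path = "sketch-model-3/stable_diffusion_handler.py"

with zipfile.ZipFile(mar_path, "r") as mar, open(new_file_path, "rb") as f:
    # The archive member should now be byte-identical to the local handler file.
    assert mar.read("stable_diffusion_handler.py") == f.read()

print("handler inside the .mar matches the updated file")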
stable_diffusion_handler.py CHANGED
@@ -36,6 +36,7 @@ class DiffusersHandler(BaseHandler, ABC):
         """
 
         logger.info("Loading diffusion model")
+        logger.info("I'm totally new and updated")
 
         self.manifest = ctx.manifest
         properties = ctx.system_properties
@@ -97,10 +98,10 @@ class DiffusersHandler(BaseHandler, ABC):
         logger.info("Generated image: '%s'", inferences)
         return inferences
 
-    def postprocess(self,
+    def postprocess(self, inference_outputs):
         """Post Process Function converts the generated image into Torchserve readable format.
         Args:
-
+            inference_outputs (list): It contains the generated image of the input text.
         Returns:
             (list): Returns a list of the images.
         """
@@ -108,7 +109,7 @@ class DiffusersHandler(BaseHandler, ABC):
         client = storage.Client()
         bucket = client.get_bucket(bucket_name)
         outputs = []
-        for image in
+        for image in inference_outputs:
             image_name = str(uuid.uuid4())
 
             blob = bucket.blob(image_name + '.png')
@@ -119,8 +120,8 @@ class DiffusersHandler(BaseHandler, ABC):
                 blob.upload_from_file(tmp, content_type='image/png')
 
             # generate txt file with the image name and the prompt inside
-            blob = bucket.blob(image_name + '.txt')
-            blob.upload_from_string(self.prompt)
+            # blob = bucket.blob(image_name + '.txt')
+            # blob.upload_from_string(self.prompt)
 
             outputs.append('https://storage.googleapis.com/' + bucket_name + '/' + image_name + '.png')
         return outputs
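With this change, postprocess() takes the list of generated images, uploads each one to a GCS bucket, and returns public object URLs instead of image bytes; the extra .txt upload of the prompt is disabled. A standalone sketch of the same upload pattern (the function name, the in-memory buffer, and the bucket argument are illustrative assumptions; the handler itself writes the PNG to a temporary file before calling upload_from_file):

import io
import uuid

from google.cloud import storage
from PIL import Image


def upload_png_to_gcs(image: Image.Image, bucket_name: str) -> str:
    # Mirrors the handler's approach: one random object name per generated image.
    client = storage.Client()
    bucket = client.get_bucket(bucket_name)
    image_name = str(uuid.uuid4())
    blob = bucket.blob(image_name + '.png')

    # Serialize the PIL image to PNG bytes in memory and upload them.
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    buffer.seek(0)
    blob.upload_from_file(buffer, content_type='image/png')

    # Same public URL scheme the handler appends to its outputs list.
    return 'https://storage.googleapis.com/' + bucket_name + '/' + image_name + '.png'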