khulnasoft commited on
Commit
1092afe
·
verified ·
1 Parent(s): 3578d39

Create models_server.py

Browse files
Files changed (1) hide show
  1. models_server.py +258 -0
models_server.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ start = time.time()
2
+
3
+ pipe = pipes[model_id]["model"]
4
+
5
+ if "device" in pipes[model_id]:
6
+ try:
7
+ pipe.to(pipes[model_id]["device"])
8
+ except:
9
+ pipe.device = torch.device(pipes[model_id]["device"])
10
+ pipe.model.to(pipes[model_id]["device"])
11
+
12
+ result = None
13
+ try:
14
+ # text to video
15
+ if model_id == "damo-vilab/text-to-video-ms-1.7b":
16
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
17
+ # pipe.enable_model_cpu_offload()
18
+ prompt = data["text"]
19
+ video_frames = pipe(prompt, num_inference_steps=50, num_frames=40).frames
20
+ file_name = str(uuid.uuid4())[:4]
21
+ video_path = export_to_video(video_frames, f"public/videos/{file_name}.mp4")
22
+
23
+ new_file_name = str(uuid.uuid4())[:4]
24
+ os.system(f"ffmpeg -i {video_path} -vcodec libx264 public/videos/{new_file_name}.mp4")
25
+
26
+ if os.path.exists(f"public/videos/{new_file_name}.mp4"):
27
+ result = {"path": f"/videos/{new_file_name}.mp4"}
28
+ else:
29
+ result = {"path": f"/videos/{file_name}.mp4"}
30
+
31
+ # controlnet
32
+ if model_id.startswith("lllyasviel/sd-controlnet-"):
33
+ pipe.controlnet.to('cpu')
34
+ pipe.controlnet = pipes[model_id]["control"].to(pipes[model_id]["device"])
35
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
36
+ control_image = load_image(data["img_url"])
37
+ # generator = torch.manual_seed(66)
38
+ out_image: Image = pipe(data["text"], num_inference_steps=20, image=control_image).images[0]
39
+ file_name = str(uuid.uuid4())[:4]
40
+ out_image.save(f"public/images/{file_name}.png")
41
+ result = {"path": f"/images/{file_name}.png"}
42
+
43
+ if model_id.endswith("-control"):
44
+ image = load_image(data["img_url"])
45
+ if "scribble" in model_id:
46
+ control = pipe(image, scribble = True)
47
+ elif "canny" in model_id:
48
+ control = pipe(image, low_threshold=100, high_threshold=200)
49
+ else:
50
+ control = pipe(image)
51
+ file_name = str(uuid.uuid4())[:4]
52
+ control.save(f"public/images/{file_name}.png")
53
+ result = {"path": f"/images/{file_name}.png"}
54
+
55
+ # image to image
56
+ if model_id == "lambdalabs/sd-image-variations-diffusers":
57
+ im = load_image(data["img_url"])
58
+ file_name = str(uuid.uuid4())[:4]
59
+ with open(f"public/images/{file_name}.png", "wb") as f:
60
+ f.write(data)
61
+ tform = transforms.Compose([
62
+ transforms.ToTensor(),
63
+ transforms.Resize(
64
+ (224, 224),
65
+ interpolation=transforms.InterpolationMode.BICUBIC,
66
+ antialias=False,
67
+ ),
68
+ transforms.Normalize(
69
+ [0.48145466, 0.4578275, 0.40821073],
70
+ [0.26862954, 0.26130258, 0.27577711]),
71
+ ])
72
+ inp = tform(im).to(pipes[model_id]["device"]).unsqueeze(0)
73
+ out = pipe(inp, guidance_scale=3)
74
+ out["images"][0].save(f"public/images/{file_name}.jpg")
75
+ result = {"path": f"/images/{file_name}.jpg"}
76
+
77
+ # image to text
78
+ if model_id == "Salesforce/blip-image-captioning-large":
79
+ raw_image = load_image(data["img_url"]).convert('RGB')
80
+ text = data["text"]
81
+ inputs = pipes[model_id]["processor"](raw_image, return_tensors="pt").to(pipes[model_id]["device"])
82
+ out = pipe.generate(**inputs)
83
+ caption = pipes[model_id]["processor"].decode(out[0], skip_special_tokens=True)
84
+ result = {"generated text": caption}
85
+ if model_id == "ydshieh/vit-gpt2-coco-en":
86
+ img_url = data["img_url"]
87
+ generated_text = pipe(img_url)[0]['generated_text']
88
+ result = {"generated text": generated_text}
89
+ if model_id == "nlpconnect/vit-gpt2-image-captioning":
90
+ image = load_image(data["img_url"]).convert("RGB")
91
+ pixel_values = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt").pixel_values
92
+ pixel_values = pixel_values.to(pipes[model_id]["device"])
93
+ generated_ids = pipe.generate(pixel_values, **{"max_length": 200, "num_beams": 1})
94
+ generated_text = pipes[model_id]["tokenizer"].batch_decode(generated_ids, skip_special_tokens=True)[0]
95
+ result = {"generated text": generated_text}
96
+ # image to text: OCR
97
+ if model_id == "microsoft/trocr-base-printed" or model_id == "microsoft/trocr-base-handwritten":
98
+ image = load_image(data["img_url"]).convert("RGB")
99
+ pixel_values = pipes[model_id]["processor"](image, return_tensors="pt").pixel_values
100
+ pixel_values = pixel_values.to(pipes[model_id]["device"])
101
+ generated_ids = pipe.generate(pixel_values)
102
+ generated_text = pipes[model_id]["processor"].batch_decode(generated_ids, skip_special_tokens=True)[0]
103
+ result = {"generated text": generated_text}
104
+
105
+ # text to image
106
+ if model_id == "runwayml/stable-diffusion-v1-5":
107
+ file_name = str(uuid.uuid4())[:4]
108
+ text = data["text"]
109
+ out = pipe(prompt=text)
110
+ out["images"][0].save(f"public/images/{file_name}.jpg")
111
+ result = {"path": f"/images/{file_name}.jpg"}
112
+
113
+ # object detection
114
+ if model_id == "google/owlvit-base-patch32" or model_id == "facebook/detr-resnet-101":
115
+ img_url = data["img_url"]
116
+ open_types = ["cat", "couch", "person", "car", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird"]
117
+ result = pipe(img_url, candidate_labels=open_types)
118
+
119
+ # VQA
120
+ if model_id == "dandelin/vilt-b32-finetuned-vqa":
121
+ question = data["text"]
122
+ img_url = data["img_url"]
123
+ result = pipe(question=question, image=img_url)
124
+
125
+ #DQA
126
+ if model_id == "impira/layoutlm-document-qa":
127
+ question = data["text"]
128
+ img_url = data["img_url"]
129
+ result = pipe(img_url, question)
130
+
131
+ # depth-estimation
132
+ if model_id == "Intel/dpt-large":
133
+ output = pipe(data["img_url"])
134
+ image = output['depth']
135
+ name = str(uuid.uuid4())[:4]
136
+ image.save(f"public/images/{name}.jpg")
137
+ result = {"path": f"/images/{name}.jpg"}
138
+
139
+ if model_id == "Intel/dpt-hybrid-midas" and model_id == "Intel/dpt-large":
140
+ image = load_image(data["img_url"])
141
+ inputs = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt")
142
+ with torch.no_grad():
143
+ outputs = pipe(**inputs)
144
+ predicted_depth = outputs.predicted_depth
145
+ prediction = torch.nn.functional.interpolate(
146
+ predicted_depth.unsqueeze(1),
147
+ size=image.size[::-1],
148
+ mode="bicubic",
149
+ align_corners=False,
150
+ )
151
+ output = prediction.squeeze().cpu().numpy()
152
+ formatted = (output * 255 / np.max(output)).astype("uint8")
153
+ image = Image.fromarray(formatted)
154
+ name = str(uuid.uuid4())[:4]
155
+ image.save(f"public/images/{name}.jpg")
156
+ result = {"path": f"/images/{name}.jpg"}
157
+
158
+ # TTS
159
+ if model_id == "espnet/kan-bayashi_ljspeech_vits":
160
+ text = data["text"]
161
+ wav = pipe(text)["wav"]
162
+ name = str(uuid.uuid4())[:4]
163
+ sf.write(f"public/audios/{name}.wav", wav.cpu().numpy(), pipe.fs, "PCM_16")
164
+ result = {"path": f"/audios/{name}.wav"}
165
+
166
+ if model_id == "microsoft/speecht5_tts":
167
+ text = data["text"]
168
+ inputs = pipes[model_id]["processor"](text=text, return_tensors="pt")
169
+ embeddings_dataset = pipes[model_id]["embeddings_dataset"]
170
+ speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(pipes[model_id]["device"])
171
+ pipes[model_id]["vocoder"].to(pipes[model_id]["device"])
172
+ speech = pipe.generate_speech(inputs["input_ids"].to(pipes[model_id]["device"]), speaker_embeddings, vocoder=pipes[model_id]["vocoder"])
173
+ name = str(uuid.uuid4())[:4]
174
+ sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000)
175
+ result = {"path": f"/audios/{name}.wav"}
176
+
177
+ # ASR
178
+ if model_id == "openai/whisper-base" or model_id == "microsoft/speecht5_asr":
179
+ audio_url = data["audio_url"]
180
+ result = { "text": pipe(audio_url)["text"]}
181
+
182
+ # audio to audio
183
+ if model_id == "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k":
184
+ audio_url = data["audio_url"]
185
+ wav, sr = torchaudio.load(audio_url)
186
+ with torch.no_grad():
187
+ result_wav = pipe(wav.to(pipes[model_id]["device"]))
188
+ name = str(uuid.uuid4())[:4]
189
+ sf.write(f"public/audios/{name}.wav", result_wav.cpu().squeeze().numpy(), sr)
190
+ result = {"path": f"/audios/{name}.wav"}
191
+
192
+ if model_id == "microsoft/speecht5_vc":
193
+ audio_url = data["audio_url"]
194
+ wav, sr = torchaudio.load(audio_url)
195
+ inputs = pipes[model_id]["processor"](audio=wav, sampling_rate=sr, return_tensors="pt")
196
+ embeddings_dataset = pipes[model_id]["embeddings_dataset"]
197
+ speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
198
+ pipes[model_id]["vocoder"].to(pipes[model_id]["device"])
199
+ speech = pipe.generate_speech(inputs["input_ids"].to(pipes[model_id]["device"]), speaker_embeddings, vocoder=pipes[model_id]["vocoder"])
200
+ name = str(uuid.uuid4())[:4]
201
+ sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000)
202
+ result = {"path": f"/audios/{name}.wav"}
203
+
204
+ # segmentation
205
+ if model_id == "facebook/detr-resnet-50-panoptic":
206
+ result = []
207
+ segments = pipe(data["img_url"])
208
+ image = load_image(data["img_url"])
209
+
210
+ colors = []
211
+ for i in range(len(segments)):
212
+ colors.append((random.randint(100, 255), random.randint(100, 255), random.randint(100, 255), 50))
213
+
214
+ for segment in segments:
215
+ mask = segment["mask"]
216
+ mask = mask.convert('L')
217
+ layer = Image.new('RGBA', mask.size, colors[i])
218
+ image.paste(layer, (0, 0), mask)
219
+ name = str(uuid.uuid4())[:4]
220
+ image.save(f"public/images/{name}.jpg")
221
+ result = {"path": f"/images/{name}.jpg"}
222
+
223
+ if model_id == "facebook/maskformer-swin-base-coco" or model_id == "facebook/maskformer-swin-large-ade":
224
+ image = load_image(data["img_url"])
225
+ inputs = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt").to(pipes[model_id]["device"])
226
+ outputs = pipe(**inputs)
227
+ result = pipes[model_id]["feature_extractor"].post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
228
+ predicted_panoptic_map = result["segmentation"].cpu().numpy()
229
+ predicted_panoptic_map = Image.fromarray(predicted_panoptic_map.astype(np.uint8))
230
+ name = str(uuid.uuid4())[:4]
231
+ predicted_panoptic_map.save(f"public/images/{name}.jpg")
232
+ result = {"path": f"/images/{name}.jpg"}
233
+
234
+ except Exception as e:
235
+ print(e)
236
+ traceback.print_exc()
237
+ result = {"error": {"message": "Error when running the model inference."}}
238
+
239
+ if "device" in pipes[model_id]:
240
+ try:
241
+ pipe.to("cpu")
242
+ torch.cuda.empty_cache()
243
+ except:
244
+ pipe.device = torch.device("cpu")
245
+ pipe.model.to("cpu")
246
+ torch.cuda.empty_cache()
247
+
248
+ pipes[model_id]["using"] = False
249
+
250
+ if result is None:
251
+ result = {"error": {"message": "model not found"}}
252
+
253
+ end = time.time()
254
+ during = end - start
255
+ print(f"[ complete {model_id} ] {during}s")
256
+ print(f"[ result {model_id} ] {result}")
257
+
258
+ return result