ishworrsubedii commited on
Commit
5ec57bf
·
1 Parent(s): 61c0885

update: inference time, added nto_cto combined

Browse files
Files changed (2) hide show
  1. src/api/batch_api.py +64 -62
  2. src/api/nto_api.py +130 -2
src/api/batch_api.py CHANGED
@@ -39,13 +39,14 @@ async def rt_cto(
39
  logger.info("-" * 50)
40
  logger.info(">>> REAL-TIME CTO STARTED <<<")
41
  logger.info(f"Parameters: clothing_list={c_list}")
 
 
42
  try:
43
  clothing_list = [item.strip() for item in c_list.split(",")]
44
- logger.info(f">>> CLOTHING LIST: {clothing_list} <<<")
45
-
46
  image_bytes = await image.read()
47
  pil_image = Image.open(BytesIO(image_bytes)).convert("RGB")
48
- logger.info(">>> IMAGE LOADED SUCCESSFULLY <<<")
 
49
  except Exception as e:
50
  logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
51
  return {"error": "Error reading image", "code": 500}
@@ -53,17 +54,21 @@ async def rt_cto(
53
  async def generate():
54
  logger.info("-" * 50)
55
  logger.info(">>> CLOTHING TRY ON V2 STARTED <<<")
56
- start_time = time.time()
57
 
 
 
58
  try:
59
  mask, _, _ = await pipeline.shoulderPointMaskGeneration_(image=pil_image)
60
- logger.info(">>> MASK GENERATION COMPLETED <<<")
 
61
  except Exception as e:
62
  logger.error(f">>> MASK GENERATION ERROR: {str(e)} <<<")
63
  yield json.dumps({"error": "Error generating mask", "code": 500}) + "\n"
64
  await asyncio.sleep(0.1)
65
  return
66
 
 
 
67
  try:
68
  mask_img_base_64, act_img_base_64 = BytesIO(), BytesIO()
69
  mask.save(mask_img_base_64, format="WEBP")
@@ -73,7 +78,8 @@ async def rt_cto(
73
 
74
  mask_data_uri = f"data:image/webp;base64,{mask_bytes_}"
75
  image_data_uri = f"data:image/webp;base64,{image_bytes_}"
76
- logger.info(">>> IMAGE ENCODING COMPLETED <<<")
 
77
  except Exception as e:
78
  logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
79
  yield json.dumps({"error": "Error converting images to base64", "code": 500}) + "\n"
@@ -84,25 +90,32 @@ async def rt_cto(
84
  if not clothing_type:
85
  continue
86
 
87
- input = {
88
- "mask": mask_data_uri,
89
- "image": image_data_uri,
90
- "prompt": f"Dull {clothing_type}, non-reflective clothing, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple, perfect eyes, perfect face, perfect body, high quality, realistic, photorealistic, high resolution,traditional full sleeve blouse",
91
- "negative_prompt": "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly",
92
- "num_inference_steps": 25
93
- }
94
-
95
  try:
96
- output = replicate_run_cto(input)
97
- logger.info(f">>> REPLICATE PROCESSING COMPLETED FOR {clothing_type} <<<")
 
 
 
 
 
 
 
 
98
 
99
  output_url = str(output[0]) if output and output[0] else None
100
- total_inference_time = round((time.time() - start_time), 2)
101
 
102
  result = {
103
  "code": 200,
104
  "output": output_url,
105
- "inference_time": total_inference_time,
 
 
 
 
 
 
106
  "clothing_type": clothing_type,
107
  "progress": f"{idx + 1}/{len(clothing_list)}"
108
  }
@@ -166,48 +179,51 @@ async def rt_nto(
166
  )
167
 
168
  async def generate():
169
- start_time = time.time()
 
 
 
 
170
 
171
  for idx, (necklace_id, category) in enumerate(zip(necklace_ids, categories)):
 
172
  try:
 
 
173
  jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{storename}/{category}/image/{necklace_id}.png"
174
  jewellery = Image.open(returnBytesData(url=jewellery_url))
175
- logger.info(f">>> JEWELLERY IMAGE {necklace_id} LOADED SUCCESSFULLY <<<")
 
176
 
177
- # Process the necklace try-on
 
178
  result, headetText, mask = await pipeline.necklaceTryOn_(
179
  image=source_image,
180
  jewellery=jewellery,
181
  storename=storename
182
  )
 
183
 
184
- if result is None:
185
- error_result = {
186
- "error": "No face detected in the image",
187
- "code": 400,
188
- "necklace_id": necklace_id,
189
- "category": category,
190
- "progress": f"{idx + 1}/{len(necklace_ids)}"
191
- }
192
- yield json.dumps(error_result) + "\n"
193
- continue
194
-
195
- # Upload results concurrently
196
- logger.info(">>> UPLOADING RESULTS <<<")
197
  upload_tasks = [
198
  supabase_upload_and_return_url(prefix="necklace_try_on", image=result),
199
  supabase_upload_and_return_url(prefix="necklace_try_on_mask", image=mask)
200
  ]
201
  result_url, mask_url = await asyncio.gather(*upload_tasks)
202
-
203
- total_inference_time = round((time.time() - start_time), 2)
204
- logger.info(f">>> UPLOADING COMPLETED FOR {necklace_id} <<<")
205
 
206
  result = {
207
  "code": 200,
208
  "output": result_url,
209
  "mask": mask_url,
210
- "inference_time": total_inference_time,
 
 
 
 
 
 
211
  "necklace_id": necklace_id,
212
  "category": category,
213
  "progress": f"{idx + 1}/{len(necklace_ids)}"
@@ -273,31 +289,15 @@ async def rt_cto_nto(
273
  )
274
 
275
  async def generate():
276
- setup_time = time.time() # Track setup time separately
277
 
278
- try:
279
- mask, _, _ = await pipeline.shoulderPointMaskGeneration_(image=source_image)
280
- mask_generation_time = round(time.time() - setup_time, 2)
281
- logger.info(f">>> MASK GENERATION COMPLETED in {mask_generation_time}s <<<")
282
- except Exception as e:
283
- logger.error(f">>> MASK GENERATION ERROR: {str(e)} <<<")
284
- yield json.dumps({"error": "Error generating mask", "code": 500}) + "\n"
285
- return
286
-
287
- try:
288
- mask_img_base_64, act_img_base_64 = BytesIO(), BytesIO()
289
- mask.save(mask_img_base_64, format="WEBP")
290
- source_image.save(act_img_base_64, format="WEBP")
291
- mask_bytes_ = base64.b64encode(mask_img_base_64.getvalue()).decode("utf-8")
292
- image_bytes_ = base64.b64encode(act_img_base_64.getvalue()).decode("utf-8")
293
-
294
- mask_data_uri = f"data:image/webp;base64,{mask_bytes_}"
295
- image_data_uri = f"data:image/webp;base64,{image_bytes_}"
296
- logger.info(">>> IMAGE ENCODING COMPLETED <<<")
297
- except Exception as e:
298
- logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
299
- yield json.dumps({"error": "Error converting images to base64", "code": 500}) + "\n"
300
- return
301
 
302
  for idx, clothing_type in enumerate(clothing_list):
303
  iteration_start_time = time.time()
@@ -345,6 +345,8 @@ async def rt_cto_nto(
345
  "code": 200,
346
  "output": result_url,
347
  "timing": {
 
 
348
  "cto_inference": cto_time,
349
  "nto_inference": nto_time,
350
  "upload": upload_time,
 
39
  logger.info("-" * 50)
40
  logger.info(">>> REAL-TIME CTO STARTED <<<")
41
  logger.info(f"Parameters: clothing_list={c_list}")
42
+
43
+ setup_start_time = time.time()
44
  try:
45
  clothing_list = [item.strip() for item in c_list.split(",")]
 
 
46
  image_bytes = await image.read()
47
  pil_image = Image.open(BytesIO(image_bytes)).convert("RGB")
48
+ setup_time = round(time.time() - setup_start_time, 2)
49
+ logger.info(f">>> IMAGE LOADED SUCCESSFULLY in {setup_time}s <<<")
50
  except Exception as e:
51
  logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
52
  return {"error": "Error reading image", "code": 500}
 
54
  async def generate():
55
  logger.info("-" * 50)
56
  logger.info(">>> CLOTHING TRY ON V2 STARTED <<<")
 
57
 
58
+ # Mask generation timing
59
+ mask_start_time = time.time()
60
  try:
61
  mask, _, _ = await pipeline.shoulderPointMaskGeneration_(image=pil_image)
62
+ mask_time = round(time.time() - mask_start_time, 2)
63
+ logger.info(f">>> MASK GENERATION COMPLETED in {mask_time}s <<<")
64
  except Exception as e:
65
  logger.error(f">>> MASK GENERATION ERROR: {str(e)} <<<")
66
  yield json.dumps({"error": "Error generating mask", "code": 500}) + "\n"
67
  await asyncio.sleep(0.1)
68
  return
69
 
70
+ # Encoding timing
71
+ encoding_start_time = time.time()
72
  try:
73
  mask_img_base_64, act_img_base_64 = BytesIO(), BytesIO()
74
  mask.save(mask_img_base_64, format="WEBP")
 
78
 
79
  mask_data_uri = f"data:image/webp;base64,{mask_bytes_}"
80
  image_data_uri = f"data:image/webp;base64,{image_bytes_}"
81
+ encoding_time = round(time.time() - encoding_start_time, 2)
82
+ logger.info(f">>> IMAGE ENCODING COMPLETED in {encoding_time}s <<<")
83
  except Exception as e:
84
  logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
85
  yield json.dumps({"error": "Error converting images to base64", "code": 500}) + "\n"
 
90
  if not clothing_type:
91
  continue
92
 
93
+ iteration_start_time = time.time()
 
 
 
 
 
 
 
94
  try:
95
+ inference_start_time = time.time()
96
+ output = replicate_run_cto({
97
+ "mask": mask_data_uri,
98
+ "image": image_data_uri,
99
+ "prompt": f"Dull {clothing_type}, non-reflective clothing, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple, perfect eyes, perfect face, perfect body, high quality, realistic, photorealistic, high resolution,traditional full sleeve blouse",
100
+ "negative_prompt": "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly",
101
+ "num_inference_steps": 25
102
+ })
103
+ inference_time = round(time.time() - inference_start_time, 2)
104
+ logger.info(f">>> REPLICATE PROCESSING COMPLETED FOR {clothing_type} in {inference_time}s <<<")
105
 
106
  output_url = str(output[0]) if output and output[0] else None
107
+ iteration_time = round(time.time() - iteration_start_time, 2)
108
 
109
  result = {
110
  "code": 200,
111
  "output": output_url,
112
+ "timing": {
113
+ "setup": setup_time,
114
+ "mask_generation": mask_time,
115
+ "encoding": encoding_time,
116
+ "inference": inference_time,
117
+ "iteration": iteration_time
118
+ },
119
  "clothing_type": clothing_type,
120
  "progress": f"{idx + 1}/{len(clothing_list)}"
121
  }
 
179
  )
180
 
181
  async def generate():
182
+ setup_start_time = time.time() # Add setup timing
183
+
184
+ # After loading images
185
+ setup_time = round(time.time() - setup_start_time, 2)
186
+ logger.info(f">>> SETUP COMPLETED in {setup_time}s <<<")
187
 
188
  for idx, (necklace_id, category) in enumerate(zip(necklace_ids, categories)):
189
+ iteration_start_time = time.time()
190
  try:
191
+ # Load jewellery timing
192
+ jewellery_load_start = time.time()
193
  jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{storename}/{category}/image/{necklace_id}.png"
194
  jewellery = Image.open(returnBytesData(url=jewellery_url))
195
+ jewellery_time = round(time.time() - jewellery_load_start, 2)
196
+ logger.info(f">>> JEWELLERY LOADED in {jewellery_time}s <<<")
197
 
198
+ # NTO timing
199
+ nto_start_time = time.time()
200
  result, headetText, mask = await pipeline.necklaceTryOn_(
201
  image=source_image,
202
  jewellery=jewellery,
203
  storename=storename
204
  )
205
+ nto_time = round(time.time() - nto_start_time, 2)
206
 
207
+ # Upload timing
208
+ upload_start_time = time.time()
 
 
 
 
 
 
 
 
 
 
 
209
  upload_tasks = [
210
  supabase_upload_and_return_url(prefix="necklace_try_on", image=result),
211
  supabase_upload_and_return_url(prefix="necklace_try_on_mask", image=mask)
212
  ]
213
  result_url, mask_url = await asyncio.gather(*upload_tasks)
214
+ upload_time = round(time.time() - upload_start_time, 2)
 
 
215
 
216
  result = {
217
  "code": 200,
218
  "output": result_url,
219
  "mask": mask_url,
220
+ "timing": {
221
+ "setup": setup_time,
222
+ "jewellery_load": jewellery_time,
223
+ "nto_inference": nto_time,
224
+ "upload": upload_time,
225
+ "total_iteration": round(time.time() - iteration_start_time, 2)
226
+ },
227
  "necklace_id": necklace_id,
228
  "category": category,
229
  "progress": f"{idx + 1}/{len(necklace_ids)}"
 
289
  )
290
 
291
  async def generate():
292
+ setup_start_time = time.time()
293
 
294
+ # After mask generation
295
+ mask_time = round(time.time() - setup_start_time, 2)
296
+
297
+ # Encoding timing
298
+ encoding_start_time = time.time()
299
+ # After encoding
300
+ encoding_time = round(time.time() - encoding_start_time, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
  for idx, clothing_type in enumerate(clothing_list):
303
  iteration_start_time = time.time()
 
345
  "code": 200,
346
  "output": result_url,
347
  "timing": {
348
+ "setup": mask_time, # Include setup time
349
+ "encoding": encoding_time,
350
  "cto_inference": cto_time,
351
  "nto_inference": nto_time,
352
  "upload": upload_time,
src/api/nto_api.py CHANGED
@@ -505,7 +505,6 @@ async def necklace_try_on_id(necklace_try_on_id: NecklaceTryOnIDEntity = Depends
505
 
506
  finally:
507
  if 'result' in locals(): del result
508
- if 'mask' in locals(): del mask
509
  gc.collect()
510
 
511
 
@@ -749,7 +748,7 @@ async def mannequin_nto(necklace_try_on_id: NecklaceTryOnIDEntity = Depends(pars
749
  "code": 404}, status_code=404)
750
 
751
  try:
752
- result,resized_img = await pipeline.necklaceTryOnMannequin_(image=image, jewellery=jewellery)
753
 
754
  if result is None:
755
  logger.error(">>> NO FACE DETECTED IN THE IMAGE <<<")
@@ -802,3 +801,132 @@ async def mannequin_nto(necklace_try_on_id: NecklaceTryOnIDEntity = Depends(pars
802
  finally:
803
  if 'result' in locals(): del result
804
  gc.collect()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
 
506
  finally:
507
  if 'result' in locals(): del result
 
508
  gc.collect()
509
 
510
 
 
748
  "code": 404}, status_code=404)
749
 
750
  try:
751
+ result, resized_img = await pipeline.necklaceTryOnMannequin_(image=image, jewellery=jewellery)
752
 
753
  if result is None:
754
  logger.error(">>> NO FACE DETECTED IN THE IMAGE <<<")
 
801
  finally:
802
  if 'result' in locals(): del result
803
  gc.collect()
804
+
805
+
806
+ @nto_cto_router.post("/nto_mto_combined")
807
+ async def combined_cto_nto(
808
+ image: UploadFile = File(...),
809
+ clothing_type: str = Form(...),
810
+ necklace_id: str = Form(...),
811
+ necklace_category: str = Form(...),
812
+ storename: str = Form(...)
813
+ ):
814
+ logger.info("-" * 50)
815
+ logger.info(">>> COMBINED CTO-NTO STARTED <<<")
816
+ logger.info(f"Parameters: storename={storename}, necklace_category={necklace_category}, "
817
+ f"necklace_id={necklace_id}, clothing_type={clothing_type}")
818
+ start_time = time.time()
819
+
820
+ def image_to_base64(img: Image.Image) -> str:
821
+ buffer = BytesIO()
822
+ img.save(buffer, format="WEBP", quality=85, optimize=True)
823
+ return f"data:image/webp;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
824
+
825
+ try:
826
+ # Load source image and necklace
827
+ image_bytes = await image.read()
828
+ source_image = Image.open(BytesIO(image_bytes)).convert("RGB").resize((512, 512))
829
+
830
+ jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{storename}/{necklace_category}/image/{necklace_id}.png"
831
+ necklace_image = Image.open(returnBytesData(url=jewellery_url)).convert("RGBA")
832
+ logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
833
+ except Exception as e:
834
+ logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
835
+ return JSONResponse(content={
836
+ "error": "Error loading images. Please verify the image and necklace availability.",
837
+ "code": 404
838
+ }, status_code=404)
839
+
840
+ try:
841
+ # Generate mask and shoulder points
842
+ mask_start_time = time.time()
843
+ mask, _, _ = await pipeline.shoulderPointMaskGeneration_(image=source_image)
844
+ mask_time = round(time.time() - mask_start_time, 2)
845
+ logger.info(f">>> MASK GENERATION COMPLETED in {mask_time}s <<<")
846
+
847
+ # Convert images to base64
848
+ encoding_start_time = time.time()
849
+ mask_data_uri, image_data_uri = await asyncio.gather(
850
+ asyncio.to_thread(image_to_base64, mask),
851
+ asyncio.to_thread(image_to_base64, source_image)
852
+ )
853
+ encoding_time = round(time.time() - encoding_start_time, 2)
854
+ logger.info(f">>> IMAGE ENCODING COMPLETED in {encoding_time}s <<<")
855
+
856
+ # Perform CTO
857
+ cto_start_time = time.time()
858
+ cto_output = replicate_run_cto({
859
+ "mask": mask_data_uri,
860
+ "image": image_data_uri,
861
+ "prompt": f"Dull {clothing_type}, non-reflective clothing, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple, perfect eyes, perfect face, perfect body, high quality, realistic, photorealistic, high resolution,traditional full sleeve blouse",
862
+ "negative_prompt": "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly",
863
+ "num_inference_steps": 20
864
+ })
865
+ cto_time = round(time.time() - cto_start_time, 2)
866
+ logger.info(f">>> CTO COMPLETED in {cto_time}s <<<")
867
+
868
+ if not cto_output or not isinstance(cto_output, (list, tuple)) or not cto_output[0]:
869
+ raise ValueError("Invalid output from clothing try-on")
870
+
871
+ # Get CTO result image
872
+ async with aiohttp.ClientSession() as session:
873
+ async with session.get(str(cto_output[0])) as response:
874
+ if response.status != 200:
875
+ raise HTTPException(status_code=response.status, detail="Failed to fetch CTO output")
876
+ cto_result_bytes = await response.read()
877
+
878
+ # Perform NTO
879
+ nto_start_time = time.time()
880
+ with BytesIO(cto_result_bytes) as buf:
881
+ cto_result_image = Image.open(buf).convert("RGB")
882
+ result, headerText, _ = await pipeline.necklaceTryOn_(
883
+ image=cto_result_image,
884
+ jewellery=necklace_image,
885
+ storename=storename
886
+ )
887
+ nto_time = round(time.time() - nto_start_time, 2)
888
+ logger.info(f">>> NTO COMPLETED in {nto_time}s <<<")
889
+
890
+ if result is None:
891
+ raise ValueError("Failed to process necklace try-on")
892
+
893
+ upload_start_time = time.time()
894
+ result_url = await supabase_upload_and_return_url(
895
+ prefix="combined_cto_nto",
896
+ image=result
897
+ )
898
+ upload_time = round(time.time() - upload_start_time, 2)
899
+ logger.info(f">>> RESULT UPLOADED in {upload_time}s <<<")
900
+
901
+ if not result_url:
902
+ raise ValueError("Failed to upload result image")
903
+
904
+ total_time = round(time.time() - start_time, 2)
905
+ response = {
906
+ "code": 200,
907
+ "output": result_url,
908
+ "timing": {
909
+ "mask_generation": mask_time,
910
+ "encoding": encoding_time,
911
+ "cto_inference": cto_time,
912
+ "nto_inference": nto_time,
913
+ "upload": upload_time,
914
+ "total": total_time
915
+ }
916
+ }
917
+
918
+ except ValueError as ve:
919
+ logger.error(f">>> PROCESSING ERROR: {str(ve)} <<<")
920
+ return JSONResponse(status_code=400, content={"error": str(ve), "code": 400})
921
+ except Exception as e:
922
+ logger.error(f">>> PROCESSING ERROR: {str(e)} <<<")
923
+ return JSONResponse(status_code=500, content={"error": "Error during image processing", "code": 500})
924
+ finally:
925
+ if 'result' in locals(): del result
926
+ gc.collect()
927
+
928
+ logger.info(f">>> TOTAL PROCESSING TIME: {total_time}s <<<")
929
+ logger.info(">>> COMBINED CTO-NTO COMPLETED SUCCESSFULLY <<<")
930
+ logger.info("-" * 50)
931
+
932
+ return JSONResponse(content=response, status_code=200)