ishworrsubedii commited on
Commit
61c0885
·
1 Parent(s): 1b92f63

fix: inference time

Browse files
Files changed (1) hide show
  1. src/api/batch_api.py +24 -16
src/api/batch_api.py CHANGED
@@ -273,11 +273,12 @@ async def rt_cto_nto(
273
  )
274
 
275
  async def generate():
276
- start_time = time.time()
277
 
278
  try:
279
  mask, _, _ = await pipeline.shoulderPointMaskGeneration_(image=source_image)
280
- logger.info(">>> MASK GENERATION COMPLETED <<<")
 
281
  except Exception as e:
282
  logger.error(f">>> MASK GENERATION ERROR: {str(e)} <<<")
283
  yield json.dumps({"error": "Error generating mask", "code": 500}) + "\n"
@@ -299,8 +300,10 @@ async def rt_cto_nto(
299
  return
300
 
301
  for idx, clothing_type in enumerate(clothing_list):
 
302
  try:
303
  # Perform CTO
 
304
  cto_output = replicate_run_cto({
305
  "mask": mask_data_uri,
306
  "image": image_data_uri,
@@ -308,18 +311,17 @@ async def rt_cto_nto(
308
  "negative_prompt": "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly",
309
  "num_inference_steps": 25
310
  })
311
-
312
- if not cto_output or not isinstance(cto_output, (list, tuple)) or not cto_output[0]:
313
- raise ValueError("Invalid output from clothing try-on")
314
 
315
- # Get CTO result image
 
316
  async with aiohttp.ClientSession() as session:
317
  async with session.get(str(cto_output[0])) as response:
318
  if response.status != 200:
319
  raise ValueError("Failed to fetch CTO output")
320
  cto_result_bytes = await response.read()
321
 
322
- # Process NTO
323
  with BytesIO(cto_result_bytes) as buf:
324
  cto_result_image = Image.open(buf).convert("RGB")
325
  result, headerText, mask = await pipeline.necklaceTryOn_(
@@ -327,27 +329,33 @@ async def rt_cto_nto(
327
  jewellery=jewellery,
328
  storename=storename
329
  )
330
-
331
- if result is None:
332
- raise ValueError("Failed to process necklace try-on")
333
 
334
  # Upload result
335
- result_url = await supabase_upload_and_return_url(prefix="clothing_necklace_try_on", image=result)
336
- if not result_url:
337
- raise ValueError("Failed to upload result image")
 
 
 
338
 
339
- # Stream result
340
  output_result = {
341
  "code": 200,
342
  "output": result_url,
343
- "inference_time": round((time.time() - start_time), 2),
 
 
 
 
 
344
  "clothing_type": clothing_type,
345
  "progress": f"{idx + 1}/{len(clothing_list)}"
346
  }
347
  yield json.dumps(output_result) + "\n"
348
  await asyncio.sleep(0.1)
349
 
350
- # Clean up
351
  del result
352
  gc.collect()
353
 
 
273
  )
274
 
275
  async def generate():
276
+ setup_time = time.time() # Track setup time separately
277
 
278
  try:
279
  mask, _, _ = await pipeline.shoulderPointMaskGeneration_(image=source_image)
280
+ mask_generation_time = round(time.time() - setup_time, 2)
281
+ logger.info(f">>> MASK GENERATION COMPLETED in {mask_generation_time}s <<<")
282
  except Exception as e:
283
  logger.error(f">>> MASK GENERATION ERROR: {str(e)} <<<")
284
  yield json.dumps({"error": "Error generating mask", "code": 500}) + "\n"
 
300
  return
301
 
302
  for idx, clothing_type in enumerate(clothing_list):
303
+ iteration_start_time = time.time()
304
  try:
305
  # Perform CTO
306
+ cto_start_time = time.time()
307
  cto_output = replicate_run_cto({
308
  "mask": mask_data_uri,
309
  "image": image_data_uri,
 
311
  "negative_prompt": "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly",
312
  "num_inference_steps": 25
313
  })
314
+ cto_time = round(time.time() - cto_start_time, 2)
315
+ logger.info(f">>> CTO COMPLETED for {clothing_type} in {cto_time}s <<<")
 
316
 
317
+ # Get CTO result and process NTO
318
+ nto_start_time = time.time()
319
  async with aiohttp.ClientSession() as session:
320
  async with session.get(str(cto_output[0])) as response:
321
  if response.status != 200:
322
  raise ValueError("Failed to fetch CTO output")
323
  cto_result_bytes = await response.read()
324
 
 
325
  with BytesIO(cto_result_bytes) as buf:
326
  cto_result_image = Image.open(buf).convert("RGB")
327
  result, headerText, mask = await pipeline.necklaceTryOn_(
 
329
  jewellery=jewellery,
330
  storename=storename
331
  )
332
+ nto_time = round(time.time() - nto_start_time, 2)
333
+ logger.info(f">>> NTO COMPLETED for {clothing_type} in {nto_time}s <<<")
 
334
 
335
  # Upload result
336
+ upload_start_time = time.time()
337
+ result_url = await supabase_upload_and_return_url(
338
+ prefix="clothing_necklace_try_on",
339
+ image=result
340
+ )
341
+ upload_time = round(time.time() - upload_start_time, 2)
342
 
343
+ # Stream result with detailed timing
344
  output_result = {
345
  "code": 200,
346
  "output": result_url,
347
+ "timing": {
348
+ "cto_inference": cto_time,
349
+ "nto_inference": nto_time,
350
+ "upload": upload_time,
351
+ "total_iteration": round(time.time() - iteration_start_time, 2)
352
+ },
353
  "clothing_type": clothing_type,
354
  "progress": f"{idx + 1}/{len(clothing_list)}"
355
  }
356
  yield json.dumps(output_result) + "\n"
357
  await asyncio.sleep(0.1)
358
 
 
359
  del result
360
  gc.collect()
361