ishworrsubedii commited on
Commit
1b92f63
·
1 Parent(s): dfee28c

add: rt_ntocto combined

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. src/api/batch_api.py +132 -0
app.py CHANGED
@@ -45,7 +45,7 @@ async def verify_login_token(credentials: HTTPAuthorizationCredentials = Depends
45
  raise HTTPException(status_code=401, detail="Invalid token")
46
 
47
 
48
- app = FastAPI(dependencies=[Depends(verify_login_token)])
49
  app.include_router(nto_cto_router, tags=["NTO-CTO"])
50
  app.include_router(preprocessing_router, tags=["Image-Preprocessing"])
51
  app.include_router(image_regeneration_router, tags=["Image-Regeneration"])
@@ -60,4 +60,4 @@ app.add_middleware(
60
  allow_credentials=True,
61
  allow_methods=["*"],
62
  allow_headers=["*"],
63
- )
 
45
  raise HTTPException(status_code=401, detail="Invalid token")
46
 
47
 
48
+ app = FastAPI()
49
  app.include_router(nto_cto_router, tags=["NTO-CTO"])
50
  app.include_router(preprocessing_router, tags=["Image-Preprocessing"])
51
  app.include_router(image_regeneration_router, tags=["Image-Regeneration"])
 
60
  allow_credentials=True,
61
  allow_methods=["*"],
62
  allow_headers=["*"],
63
+ )
src/api/batch_api.py CHANGED
@@ -9,6 +9,7 @@ import time
9
  from io import BytesIO
10
  import json
11
  import asyncio
 
12
 
13
  from PIL import Image
14
  from fastapi import File, UploadFile, Form
@@ -241,3 +242,134 @@ async def rt_nto(
241
  "Transfer-Encoding": "chunked"
242
  }
243
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from io import BytesIO
10
  import json
11
  import asyncio
12
+ import aiohttp
13
 
14
  from PIL import Image
15
  from fastapi import File, UploadFile, Form
 
242
  "Transfer-Encoding": "chunked"
243
  }
244
  )
245
+
246
+
247
+ @batch_router.post("/rt_cto_nto")
248
+ async def rt_cto_nto(
249
+ image: UploadFile = File(...),
250
+ c_list: str = Form(...),
251
+ necklace_id: str = Form(...),
252
+ necklace_category: str = Form(...),
253
+ storename: str = Form(...)
254
+ ):
255
+ logger.info("-" * 50)
256
+ logger.info(">>> REAL-TIME CTO-NTO STARTED <<<")
257
+ logger.info(f"Parameters: storename={storename}, necklace_category={necklace_category}, "
258
+ f"necklace_id={necklace_id}, clothing_list={c_list}")
259
+
260
+ try:
261
+ clothing_list = [item.strip() for item in c_list.split(",")]
262
+ image_bytes = await image.read()
263
+ source_image = Image.open(BytesIO(image_bytes)).convert("RGB")
264
+
265
+ jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{storename}/{necklace_category}/image/{necklace_id}.png"
266
+ jewellery = Image.open(returnBytesData(url=jewellery_url)).convert("RGBA")
267
+ logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
268
+ except Exception as e:
269
+ logger.error(f">>> INITIAL SETUP ERROR: {str(e)} <<<")
270
+ return JSONResponse(
271
+ content={"error": "Error in initial setup", "details": str(e), "code": 500},
272
+ status_code=500
273
+ )
274
+
275
+ async def generate():
276
+ start_time = time.time()
277
+
278
+ try:
279
+ mask, _, _ = await pipeline.shoulderPointMaskGeneration_(image=source_image)
280
+ logger.info(">>> MASK GENERATION COMPLETED <<<")
281
+ except Exception as e:
282
+ logger.error(f">>> MASK GENERATION ERROR: {str(e)} <<<")
283
+ yield json.dumps({"error": "Error generating mask", "code": 500}) + "\n"
284
+ return
285
+
286
+ try:
287
+ mask_img_base_64, act_img_base_64 = BytesIO(), BytesIO()
288
+ mask.save(mask_img_base_64, format="WEBP")
289
+ source_image.save(act_img_base_64, format="WEBP")
290
+ mask_bytes_ = base64.b64encode(mask_img_base_64.getvalue()).decode("utf-8")
291
+ image_bytes_ = base64.b64encode(act_img_base_64.getvalue()).decode("utf-8")
292
+
293
+ mask_data_uri = f"data:image/webp;base64,{mask_bytes_}"
294
+ image_data_uri = f"data:image/webp;base64,{image_bytes_}"
295
+ logger.info(">>> IMAGE ENCODING COMPLETED <<<")
296
+ except Exception as e:
297
+ logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
298
+ yield json.dumps({"error": "Error converting images to base64", "code": 500}) + "\n"
299
+ return
300
+
301
+ for idx, clothing_type in enumerate(clothing_list):
302
+ try:
303
+ # Perform CTO
304
+ cto_output = replicate_run_cto({
305
+ "mask": mask_data_uri,
306
+ "image": image_data_uri,
307
+ "prompt": f"Dull {clothing_type}, non-reflective clothing, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple, perfect eyes, perfect face, perfect body, high quality, realistic, photorealistic, high resolution,traditional full sleeve blouse",
308
+ "negative_prompt": "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly",
309
+ "num_inference_steps": 25
310
+ })
311
+
312
+ if not cto_output or not isinstance(cto_output, (list, tuple)) or not cto_output[0]:
313
+ raise ValueError("Invalid output from clothing try-on")
314
+
315
+ # Get CTO result image
316
+ async with aiohttp.ClientSession() as session:
317
+ async with session.get(str(cto_output[0])) as response:
318
+ if response.status != 200:
319
+ raise ValueError("Failed to fetch CTO output")
320
+ cto_result_bytes = await response.read()
321
+
322
+ # Process NTO
323
+ with BytesIO(cto_result_bytes) as buf:
324
+ cto_result_image = Image.open(buf).convert("RGB")
325
+ result, headerText, mask = await pipeline.necklaceTryOn_(
326
+ image=cto_result_image,
327
+ jewellery=jewellery,
328
+ storename=storename
329
+ )
330
+
331
+ if result is None:
332
+ raise ValueError("Failed to process necklace try-on")
333
+
334
+ # Upload result
335
+ result_url = await supabase_upload_and_return_url(prefix="clothing_necklace_try_on", image=result)
336
+ if not result_url:
337
+ raise ValueError("Failed to upload result image")
338
+
339
+ # Stream result
340
+ output_result = {
341
+ "code": 200,
342
+ "output": result_url,
343
+ "inference_time": round((time.time() - start_time), 2),
344
+ "clothing_type": clothing_type,
345
+ "progress": f"{idx + 1}/{len(clothing_list)}"
346
+ }
347
+ yield json.dumps(output_result) + "\n"
348
+ await asyncio.sleep(0.1)
349
+
350
+ # Clean up
351
+ del result
352
+ gc.collect()
353
+
354
+ except Exception as e:
355
+ logger.error(f">>> PROCESSING ERROR FOR {clothing_type}: {str(e)} <<<")
356
+ error_result = {
357
+ "error": f"Error processing clothing {clothing_type}",
358
+ "details": str(e),
359
+ "code": 500,
360
+ "clothing_type": clothing_type,
361
+ "progress": f"{idx + 1}/{len(clothing_list)}"
362
+ }
363
+ yield json.dumps(error_result) + "\n"
364
+ await asyncio.sleep(0.1)
365
+
366
+ return StreamingResponse(
367
+ generate(),
368
+ media_type="application/x-ndjson",
369
+ headers={
370
+ "Cache-Control": "no-cache",
371
+ "Connection": "keep-alive",
372
+ "X-Accel-Buffering": "no",
373
+ "Transfer-Encoding": "chunked"
374
+ }
375
+ )