ishworrsubedii commited on
Commit
6104dd3
·
verified ·
1 Parent(s): 85fd07f

Initial commit

Browse files
.github/workflows/dockerhub.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# GitHub Actions workflow: build the Docker image and push it to Docker Hub
# on every push to the main branch.
name: Publish Docker image

on:
  push:
    branches: [main]

jobs:
  push_to_registry:
    name: Push Docker image to Docker Hub
    runs-on: ubuntu-latest
    permissions:
      packages: write
      contents: read
      attestations: write
    steps:
      - name: Check out the repo
        uses: actions/checkout@v4

      # Credentials come from repository secrets; the third-party actions are
      # pinned to commit SHAs rather than tags.
      - name: Log in to Docker Hub
        uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_PASSWORD }}

      # Derives image tags/labels from the git ref and commit metadata.
      - name: Extract metadata (tags, labels) for Docker
        id: meta
        uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
        with:
          images: ishworrsubedii/web_plugin_api

      - name: Build and push Docker image
        id: push
        uses: docker/build-push-action@3b5e8027fcad23fda98b2e3ac259d8d67585f671
        with:
          context: .
          file: ./Dockerfile
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# JetBrains IDE project settings
.idea/

# Runtime log files
*.log
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Runtime image for the Magical Mirror FastAPI service (served by uvicorn).
FROM python:3.10-slim

WORKDIR /api

COPY . /api

# Install the native libraries OpenCV needs (libGL, libsm6, libxext6) plus
# ffmpeg in a single layer, then drop the apt cache to keep the image small.
# The original ran a bare `apt-get install -y` that installed nothing, and
# split the installs across RUN layers from a possibly stale package index.
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        libgl1-mesa-glx \
        ffmpeg \
        libsm6 \
        libxext6 && \
    rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir -r requirements.txt

# NOTE: the original `RUN ulimit -s 2000` had no effect — each RUN executes
# in its own shell, so the limit never reached the container runtime.
# Set it at run time instead: `docker run --ulimit stack=2048000 ...`.

EXPOSE 8000

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
app.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi.encoders import jsonable_encoder
2
+ from src.utils import supabaseGetPublicURL, deductAndTrackCredit, returnBytesData
3
+ from fastapi import FastAPI, File, UploadFile, Header, HTTPException, Form, Depends
4
+ from src.pipelines.completePipeline import Pipeline
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from fastapi.responses import JSONResponse
7
+ from supabase import create_client, Client
8
+ from typing import Dict, Union, List
9
+ from io import BytesIO
10
+ from PIL import Image
11
+ import pandas as pd
12
+ import base64
13
+ import os
14
+ from pydantic import BaseModel
15
+
16
# Shared try-on pipeline instance, constructed once at import/startup.
pipeline = Pipeline()
app = FastAPI(title="Magical Mirror Web Plugin")

# Allow browser clients from any origin to call the plugin API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive — confirm this is intended for a public widget.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Supabase client configured from the environment; raises KeyError at import
# time when SUPABASE_URL / SUPABASE_KEY are not set.
url: str = os.environ["SUPABASE_URL"]
key: str = os.environ["SUPABASE_KEY"]
supabase: Client = create_client(url, key)
29
+
30
+
31
def _first_or_none(attrs: Dict[str, list], key: str):
    """Return the first collected value for *key*, or None when absent."""
    values = attrs.get(key)
    return values[0] if values else None


@app.post("/productData/{storeId}")
async def product_data(
        storeId: str,
        filterattributes: List[Dict[str, Union[str, int, float]]],
        storename: str = Header(default="default")
):
    """Return the store's products grouped by category after filtering/sorting.

    Args:
        storeId: store whose rows are kept (matched against the ``StoreName``
            column — NOTE(review): confirm id and name are the same value).
        filterattributes: list of single-entry ``{attribute: value}`` dicts.
            Repeated attributes are OR-ed together. The reserved keys
            ``priceFrom``/``priceTo``/``weightFrom``/``weightTo`` set range
            filters and ``priceAscending``/``weightAscending``/``idAscending``/
            ``dateAscending`` (1 = ascending) set sort order.
        storename: informational header; not used by the filtering logic.

    Returns:
        JSONResponse mapping each category name to its list of product rows
        (with added ``image_url`` / ``thumbnail_url`` columns).

    Raises:
        HTTPException: 500 when fetching or processing the data fails.
    """
    try:
        response = supabase.table('MagicMirror').select("*").execute()
        df = pd.DataFrame(response.dict()["data"])

        df = df[df["StoreName"] == storeId]

        # Collapse the list of single-entry dicts into {attribute: [values]}
        # so a duplicated attribute becomes an OR over its values.
        attribute_dict: Dict[str, list] = {}
        for attr in filterattributes:
            key, value = next(iter(attr.items()))
            attribute_dict.setdefault(key, []).append(value)

        # Reserved keys: range bounds and sort flags (first value wins).
        priceFrom = _first_or_none(attribute_dict, "priceFrom")
        priceTo = _first_or_none(attribute_dict, "priceTo")
        weightFrom = _first_or_none(attribute_dict, "weightFrom")
        weightTo = _first_or_none(attribute_dict, "weightTo")
        priceAscending = _first_or_none(attribute_dict, "priceAscending")
        weightAscending = _first_or_none(attribute_dict, "weightAscending")
        idAscending = _first_or_none(attribute_dict, "idAscending")
        dateAscending = _first_or_none(attribute_dict, "dateAscending")

        # Public URLs for the full-size image and its thumbnail.
        df["image_url"] = df.apply(
            lambda row: supabaseGetPublicURL(f"{row['StoreName']}/{row['Category']}/image/{row['Id']}.png"),
            axis=1)
        df["thumbnail_url"] = df.apply(
            lambda row: supabaseGetPublicURL(f"{row['StoreName']}/{row['Category']}/thumbnail/{row['Id']}.png"),
            axis=1)

        df.reset_index(drop=True, inplace=True)

        # Column filters: keep rows whose column value is among the requested
        # values. Reserved keys are not DataFrame columns, so the lookup
        # raises KeyError and is skipped (the original used a bare
        # `except: pass`, which also swallowed unrelated errors).
        for key, values in attribute_dict.items():
            try:
                df = df[df[key].isin(values)]
            except KeyError:
                pass

        # Range filters for price and weight.
        if priceFrom is not None:
            df = df[df["Price"] >= priceFrom]
        if priceTo is not None:
            df = df[df["Price"] <= priceTo]
        if weightFrom is not None:
            df = df[df["Weight"] >= weightFrom]
        if weightTo is not None:
            df = df[df["Weight"] <= weightTo]

        # Sort flags: 1 means ascending, anything else descending.
        for flag, column in ((priceAscending, "Price"),
                             (weightAscending, "Weight"),
                             (idAscending, "Id"),
                             (dateAscending, "UpdatedAt")):
            if flag is not None:
                df = df.sort_values(by=column, ascending=(flag == 1))

        df = df.drop(["CreatedAt", "EstimatedPrice"], axis=1)

        # Group the remaining rows by category for the response payload.
        result: Dict[str, list] = {}
        for _, row in df.iterrows():
            result.setdefault(row["Category"], []).append(row.to_dict())

        return JSONResponse(content=jsonable_encoder(result))

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to fetch or process data: {e}")
160
+
161
+
162
class NecklaceTryOnIDEntity(BaseModel):
    """Form payload accepted by the /necklaceTryOnID endpoint."""

    # Id of the necklace image under the store's category folder.
    necklaceImageId: str
    # Category folder the necklace image lives in.
    necklaceCategory: str
    # Store whose assets (necklace image, watermark logo) are used.
    storename: str
    # API key checked against the APIKeyList table.
    api_token: str
167
+
168
+
169
async def parse_necklace_try_on_id(necklaceImageId: str = Form(...),
                                   necklaceCategory: str = Form(...),
                                   storename: str = Form(...),
                                   api_token: str = Form(...)) -> NecklaceTryOnIDEntity:
    """Collect the multipart form fields into a NecklaceTryOnIDEntity."""
    fields = {
        "necklaceImageId": necklaceImageId,
        "necklaceCategory": necklaceCategory,
        "storename": storename,
        "api_token": api_token,
    }
    return NecklaceTryOnIDEntity(**fields)
179
+
180
+
181
@app.post("/necklaceTryOnID")
async def necklace_try_on_id(necklace_try_on_id: NecklaceTryOnIDEntity = Depends(parse_necklace_try_on_id),
                             image: UploadFile = File(...)):
    """Overlay the requested necklace on the uploaded photo.

    Validates the caller's API key, fetches the necklace asset from Supabase
    storage, runs the try-on pipeline and returns the rendered image as a
    base64 WEBP data URI. One credit is deducted from the store per call.

    Returns:
        200 with {"output": <data URI>} on success,
        200 with {"error": "No Credits Remaining"} when credits are exhausted,
        401 when the API key is unknown,
        404 when the person image or necklace asset cannot be loaded.
    """
    data, _ = supabase.table("APIKeyList").select("*").filter("API_KEY", "eq",
                                                              necklace_try_on_id.api_token).execute()

    # BUGFIX: the filter returns an empty list for an unknown key; the
    # original indexed data[1][0] unconditionally, crashing with a 500
    # (IndexError) instead of rejecting the request with a 401.
    if not data[1] or data[1][0]['API_KEY'] != necklace_try_on_id.api_token:
        return JSONResponse(content={"error": "Invalid API Key"}, status_code=401)

    imageBytes = await image.read()

    jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{necklace_try_on_id.storename}/{necklace_try_on_id.necklaceCategory}/image/{necklace_try_on_id.necklaceImageId}.png"

    try:
        image, jewellery = Image.open(BytesIO(imageBytes)), Image.open(returnBytesData(url=jewellery_url))
    except Exception:
        error_message = {
            "error": "The requested resource (Image, necklace category, or store) is not available. Please verify the availability and try again."
        }
        return JSONResponse(content=error_message, status_code=404)

    result, headetText = await pipeline.necklaceTryOn(image=image, jewellery=jewellery,
                                                      storename=necklace_try_on_id.storename)

    # Encode the rendered image as a base64 WEBP data URI.
    inMemFile = BytesIO()
    result.save(inMemFile, format="WEBP", quality=85)
    outputBytes = inMemFile.getvalue()
    response = {
        "output": f"data:image/WEBP;base64,{base64.b64encode(outputBytes).decode('utf-8')}"
    }

    # Deduct a credit only after a successful render; report exhaustion.
    creditResponse = deductAndTrackCredit(storename=necklace_try_on_id.storename, endpoint="/necklaceTryOnID")
    if creditResponse == "No Credits Available":
        response = {
            "error": "No Credits Remaining"
        }

    return JSONResponse(content=response)
requirements.txt ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ annotated-types==0.7.0
3
+ anyio==4.4.0
4
+ attrs==23.2.0
5
+ certifi==2024.7.4
6
+ cffi==1.16.0
7
+ charset-normalizer==3.3.2
8
+ click==8.1.7
9
+ contourpy==1.2.1
10
+ cvzone==1.6.1
11
+ cycler==0.12.1
12
+ deprecation==2.1.0
13
+ dnspython==2.6.1
14
+ email_validator==2.2.0
15
+ exceptiongroup==1.2.1
16
+ fastapi==0.111.0
17
+ fastapi-cli==0.0.4
18
+ flatbuffers==24.3.25
19
+ fonttools==4.53.1
20
+ gotrue==2.5.4
21
+ h11==0.14.0
22
+ httpcore==1.0.5
23
+ httptools==0.6.1
24
+ httpx==0.27.0
25
+ idna==3.7
26
+ jax==0.4.30
27
+ jaxlib==0.4.30
28
+ Jinja2==3.1.4
29
+ kiwisolver==1.4.5
30
+ markdown-it-py==3.0.0
31
+ MarkupSafe==2.1.5
32
+ matplotlib==3.9.1
33
+ mdurl==0.1.2
34
+ mediapipe==0.10.14
35
+ ml-dtypes==0.4.0
36
+ numpy==2.0.0
37
+ opencv-contrib-python==4.10.0.84
38
+ opencv-python==4.10.0.84
39
+ opt-einsum==3.3.0
40
+ orjson==3.10.6
41
+ packaging==24.1
42
+ pandas==2.2.2
43
+ pillow==10.4.0
44
+ postgrest==0.16.8
45
+ protobuf==4.25.3
46
+ pycparser==2.22
47
+ pydantic==2.8.2
48
+ pydantic_core==2.20.1
49
+ Pygments==2.18.0
50
+ pyparsing==3.1.2
51
+ python-dateutil==2.9.0.post0
52
+ python-dotenv==1.0.1
53
+ python-multipart==0.0.9
54
+ pytz==2024.1
55
+ PyYAML==6.0.1
56
+ realtime==1.0.6
57
+ requests==2.32.3
58
+ rich==13.7.1
59
+ scikit-build==0.18.0
60
+ scipy==1.14.0
61
+ shellingham==1.5.4
62
+ six==1.16.0
63
+ sniffio==1.3.1
64
+ sounddevice==0.4.7
65
+ starlette==0.37.2
66
+ storage3==0.7.6
67
+ StrEnum==0.4.15
68
+ supabase==2.5.1
69
+ supafunc==0.4.6
70
+ tomli==2.0.1
71
+ typer==0.12.3
72
+ typing_extensions==4.12.2
73
+ tzdata==2024.1
74
+ ujson==5.10.0
75
+ urllib3==2.2.2
76
+ uvicorn==0.30.1
77
+ uvloop==0.19.0
78
+ watchfiles==0.22.0
79
+ websockets==12.0
80
+ -e .
setup.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from setuptools import setup, find_packages

# Editable-install marker pip understands but install_requires must not contain.
HYPER_E_DOT = "-e ."


def getRequirements(requirementsPath: str) -> list[str]:
    """Read the pinned requirements file for install_requires.

    Drops blank lines and the '-e .' self-install marker. The original
    called ``requirements.remove(HYPER_E_DOT)``, which raises ValueError
    when the marker is absent, and kept empty strings produced by
    ``split("\\n")``; filtering handles both.
    """
    with open(requirementsPath, encoding="utf-8") as file:
        return [
            line.strip()
            for line in file.read().split("\n")
            if line.strip() and line.strip() != HYPER_E_DOT
        ]


setup(
    name="Magical-Mirror",
    author="Subramani Sivakumar",
    author_email="[email protected]",
    version="0.1",
    packages=find_packages(),
    install_requires=getRequirements(requirementsPath="./requirements.txt")
)
src/components/__init__.py ADDED
File without changes
src/components/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (152 Bytes). View file
 
src/components/__pycache__/necklaceTryOn.cpython-310.pyc ADDED
Binary file (6.35 kB). View file
 
src/components/necklaceTryOn.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cvzone.FaceMeshModule import FaceMeshDetector
2
+ from src.utils import addWatermark, returnBytesData
3
+ from src.utils.exceptions import CustomException
4
+ from cvzone.PoseModule import PoseDetector
5
+ from src.utils.logger import logger
6
+ from dataclasses import dataclass
7
+ from typing import Union
8
+ from PIL import Image
9
+ import numpy as np
10
+ import cvzone
11
+ import math
12
+ import cv2
13
+ import gc
14
+
15
+
16
@dataclass
class NecklaceTryOnConfig:
    """Static configuration for the necklace try-on component."""

    # URL template for the store's watermark logo; formatted with the store name.
    logoURL: str = "https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/MagicMirror/FullImages/{}.png"
19
+
20
+
21
class NecklaceTryOn:
    """Overlays a necklace image onto a person photo.

    Uses cvzone's PoseDetector and FaceMeshDetector to locate the jawline /
    shoulders, estimates the neck line, scales and rotates the necklace to
    fit, composites it, and stamps the store's watermark logo.
    """

    def __init__(self) -> None:
        self.detector = PoseDetector()
        self.necklaceTryOnConfig = NecklaceTryOnConfig()
        self.meshDetector = FaceMeshDetector(staticMode=True, maxFaces=1)

    def necklaceTryOn(self, image: Image.Image, jewellery: Image.Image, storename: str) -> list[
        Union[Image.Image, str]]:
        """Place *jewellery* on *image* using jawline landmarks (172/397).

        Returns [watermarked PIL image, status text]; raises CustomException
        on any failure.
        """
        try:
            logger.info(f">>> NECKLACE TRY ON STARTED :: {storename} <<<")

            # reading the images (person photo normalised to 3000x3000 RGB)
            image, jewellery = image.convert("RGB").resize((3000, 3000)), jewellery.convert("RGBA")
            image = np.array(image)
            copy_image = image.copy()
            jewellery = np.array(jewellery)

            logger.info(f"NECKLACE TRY ON :: detecting pose and landmarks :: {storename}")

            image = self.detector.findPose(image)
            lmList, _ = self.detector.findPosition(image, bboxWithHands=False, draw=False)

            # Face-mesh landmarks 172 (left jaw) and 397 (right jaw) anchor the neck line.
            img, faces = self.meshDetector.findFaceMesh(image, draw=False)
            leftLandmarkIndex = 172
            rightLandmarkIndex = 397

            leftLandmark, rightLandmark = faces[0][leftLandmarkIndex], faces[0][rightLandmarkIndex]
            landmarksDistance = int(
                ((leftLandmark[0] - rightLandmark[0]) ** 2 + (leftLandmark[1] - rightLandmark[1]) ** 2) ** 0.5)

            logger.info(f"NECKLACE TRY ON :: estimating neck points :: {storename}")

            # Widen the jaw span by 12% per side and drop half the jaw width down.
            # avg_x1 = int(leftLandmark[0] - landmarksDistance * 0.15) -> V2.1
            avg_x1 = int(leftLandmark[0] - landmarksDistance * 0.12)
            avg_x2 = int(rightLandmark[0] + landmarksDistance * 0.12)

            avg_y1 = int(leftLandmark[1] + landmarksDistance * 0.5)
            avg_y2 = int(rightLandmark[1] + landmarksDistance * 0.5)

            logger.info(f"NECKLACE TRY ON :: scaling the necklace image :: {storename}")

            # Tilt angle of the neck line; negated when the head leans the other way.
            if avg_y2 < avg_y1:
                angle = math.ceil(
                    self.detector.findAngle(
                        p1=(avg_x2, avg_y2), p2=(avg_x1, avg_y1), p3=(avg_x2, avg_y1)
                    )[0]
                )
            else:
                angle = math.ceil(
                    self.detector.findAngle(
                        p1=(avg_x2, avg_y2), p2=(avg_x1, avg_y1), p3=(avg_x2, avg_y1)
                    )[0]
                )
                angle = angle * -1

            xdist = avg_x2 - avg_x1
            origImgRatio = xdist / jewellery.shape[1]
            ydist = jewellery.shape[0] * origImgRatio

            logger.info(f"NECKLACE TRY ON :: adding offset based on the necklace shape :: {storename}")

            # Scan the necklace's top row for the first non-background pixel to
            # find how far the visible art is inset; shift upward accordingly.
            image_gray = cv2.cvtColor(jewellery, cv2.COLOR_BGRA2GRAY)
            for offset_orig in range(image_gray.shape[1]):
                pixel_value = image_gray[0, :][offset_orig]
                if (pixel_value != 255) & (pixel_value != 0):
                    break
                else:
                    continue
            offset = int(0.8 * xdist * (offset_orig / jewellery.shape[1]))
            jewellery = cv2.resize(
                jewellery, (int(xdist), int(ydist)), interpolation=cv2.INTER_AREA
            )
            jewellery = cvzone.rotateImage(jewellery, angle)
            y_coordinate = avg_y1 - offset
            available_space = copy_image.shape[0] - y_coordinate
            extra = jewellery.shape[0] - available_space

            logger.info(f"NECKLACE TRY ON :: generating necklace placement status :: {storename}")

            if extra > 0:
                headerText = "To see more of the necklace, please step back slightly."
            else:
                headerText = "success"

            logger.info(f"NECKLACE TRY ON :: generating output :: {storename}")

            result = cvzone.overlayPNG(copy_image, jewellery, (avg_x1, y_coordinate))
            image = Image.fromarray(result.astype(np.uint8))
            logo = Image.open(returnBytesData(url=self.necklaceTryOnConfig.logoURL.format(storename)))
            result = addWatermark(background=image, logo=logo)

            gc.collect()

            return [result, headerText]

        except Exception as e:
            logger.error(f"{CustomException(e)}:: {storename}")
            raise CustomException(e)

    def necklaceTryOnV3(self, image: Image.Image, jewellery: Image.Image, storename: str) -> list[
        Union[Image.Image, str]]:
        """Variant of necklaceTryOn that derives the neck line from shoulder
        landmarks (11/12) and lip corners (61/291) instead of the jawline.

        Returns [watermarked PIL image, status text]; raises CustomException
        on any failure.
        """
        try:
            logger.info(f">>> NECKLACE TRY ON STARTED :: {storename} <<<")

            # reading the images (person photo normalised to 4000x4000 RGB)
            image, jewellery = image.convert("RGB"), jewellery.convert("RGBA")
            image = np.array(image.resize((4000, 4000)))
            copy_image = image.copy()
            jewellery = np.array(jewellery)

            logger.info(f"NECKLACE TRY ON :: detecting pose and landmarks :: {storename}")

            image = self.detector.findPose(image)
            lmList, _ = self.detector.findPosition(image, bboxWithHands=False, draw=False)
            meshDetector = FaceMeshDetector(staticMode=True, maxFaces=1)
            img, faces = meshDetector.findFaceMesh(image, draw=False)
            left_lip_point = faces[0][61]
            right_lip_point = faces[0][291]
            # Pose landmarks: 12/11 = shoulders, 10/9 = mouth corners (10/9 unused below).
            pt12, pt11, pt10, pt9 = (
                lmList[12][:2],
                lmList[11][:2],
                lmList[10][:2],
                lmList[9][:2],
            )

            mid_lips = (
                int((left_lip_point[0] + right_lip_point[0]) / 2), int((left_lip_point[1] + right_lip_point[1]) / 2))

            # Midpoints between each shoulder and the lip centre / lip corners;
            # the three candidates per side are averaged below.
            mid_lips_x1 = int(pt12[0] + (mid_lips[0] - pt12[0]) / 2)
            mid_lips_y1 = int(pt12[1] + (mid_lips[1] - pt12[1]) / 2)

            mid_lips_x2 = int(pt11[0] + (mid_lips[0] - pt11[0]) / 2)
            mid_lips_y2 = int(pt11[1] + (mid_lips[1] - pt11[1]) / 2)

            # left right lip
            left_right_lip_org_x11 = int(pt12[0] + (right_lip_point[0] - pt12[0]) / 2)
            left_right_lip_org_y11 = int(pt12[1] + (right_lip_point[1] - pt12[1]) / 2)

            left_right_lip_org_x12 = int(pt11[0] + (left_lip_point[0] - pt11[0]) / 2)
            left_right_lip_org_y12 = int(pt11[1] + (left_lip_point[1] - pt11[1]) / 2)

            # left right lip 2
            left_right_lip_org_x21 = int(pt12[0] + (left_lip_point[0] - pt12[0]) / 2)
            left_right_lip_org_y21 = int(pt12[1] + (left_lip_point[1] - pt12[1]) / 2)

            left_right_lip_org_x22 = int(pt11[0] + (right_lip_point[0] - pt11[0]) / 2)
            left_right_lip_org_y22 = int(pt11[1] + (right_lip_point[1] - pt11[1]) / 2)

            logger.info(f"NECKLACE TRY ON :: estimating neck points :: {storename}")

            avg_x1 = int((mid_lips_x1 + left_right_lip_org_x11 + left_right_lip_org_x21) / 3)
            avg_y1 = int((mid_lips_y1 + left_right_lip_org_y11 + left_right_lip_org_y21) / 3)

            avg_x2 = int((mid_lips_x2 + left_right_lip_org_x12 + left_right_lip_org_x22) / 3)
            avg_y2 = int((mid_lips_y2 + left_right_lip_org_y12 + left_right_lip_org_y22) / 3)

            logger.info(f"NECKLACE TRY ON :: scaling the necklace image :: {storename}")

            # Tilt angle of the neck line; negated when the head leans the other way.
            if avg_y2 < avg_y1:
                angle = math.ceil(
                    self.detector.findAngle(
                        p1=(avg_x2, avg_y2), p2=(avg_x1, avg_y1), p3=(avg_x2, avg_y1)
                    )[0]
                )
            else:
                angle = math.ceil(
                    self.detector.findAngle(
                        p1=(avg_x2, avg_y2), p2=(avg_x1, avg_y1), p3=(avg_x2, avg_y1)
                    )[0]
                )
                angle = angle * -1

            xdist = avg_x2 - avg_x1
            origImgRatio = xdist / jewellery.shape[1]
            ydist = jewellery.shape[0] * origImgRatio

            logger.info(f"NECKLACE TRY ON :: adding offset based on the necklace shape :: {storename}")

            image_gray = cv2.cvtColor(jewellery, cv2.COLOR_BGRA2GRAY)
            for offset_orig in range(image_gray.shape[1]):
                pixel_value = image_gray[0, :][offset_orig]
                if (pixel_value != 255) & (pixel_value != 0):
                    break
                else:
                    continue
            offset = int(0.8 * xdist * (offset_orig / jewellery.shape[1]))
            jewellery = cv2.resize(
                jewellery, (int(xdist), int(ydist)), interpolation=cv2.INTER_AREA
            )
            jewellery = cvzone.rotateImage(jewellery, angle)
            y_coordinate = avg_y1 - offset
            available_space = copy_image.shape[0] - y_coordinate
            extra = jewellery.shape[0] - available_space

            logger.info(f"NECKLACE TRY ON :: generating necklace placement status :: {storename}")

            if extra > 0:
                headerText = "To see more of the necklace, please step back slightly."
            else:
                headerText = "success"

            logger.info(f"NECKLACE TRY ON :: generating output :: {storename}")

            result = cvzone.overlayPNG(copy_image, jewellery, (avg_x1, y_coordinate))
            image = Image.fromarray(result.astype(np.uint8))
            # BUGFIX: the original called logoURL.format({storename}) — a one-element
            # SET — producing a URL like .../{'store'}.png and a broken logo fetch.
            logo = Image.open(returnBytesData(url=self.necklaceTryOnConfig.logoURL.format(storename)))
            result = addWatermark(background=image, logo=logo)

            gc.collect()

            return [result, headerText]

        except Exception as e:
            logger.error(f"{CustomException(e)}:: {storename}")
            raise CustomException(e)
src/pipelines/__init__.py ADDED
File without changes
src/pipelines/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (151 Bytes). View file
 
src/pipelines/__pycache__/completePipeline.cpython-310.pyc ADDED
Binary file (925 Bytes). View file
 
src/pipelines/completePipeline.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.components.necklaceTryOn import NecklaceTryOn
2
+ from typing import Union
3
+ from PIL import Image
4
+
5
+
6
class Pipeline:
    """Thin async facade over the NecklaceTryOn component."""

    def __init__(self) -> None:
        self.necklaceTryOnObj = NecklaceTryOn()

    async def necklaceTryOn(self, image: Image.Image, jewellery: Image.Image, storename: str) -> list[
        Union[Image.Image, str]]:
        """Run the synchronous try-on and return [rendered image, status text]."""
        outcome = self.necklaceTryOnObj.necklaceTryOn(image=image, jewellery=jewellery, storename=storename)
        rendered, status = outcome
        return [rendered, status]
src/utils/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from supabase import create_client, Client
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ import requests
5
+ import os
6
+
7
+
8
# function to add watermark to images
def addWatermark(background: Image.Image, logo: Image.Image) -> Image.Image:
    """Stamp the logo in the bottom-left corner of the background image.

    The logo is scaled to a square of 8% of the background width and pasted
    using its own alpha channel as the mask.
    """
    canvas = background.convert("RGBA")
    side = int(0.08 * canvas.size[0])
    mark = logo.convert("RGBA").resize((side, side))
    position = (10, canvas.size[1] - mark.size[1] - 10)
    canvas.paste(mark, position, mark)
    return canvas
14
+
15
+
16
# function to download an image from url and return as bytes objects
def returnBytesData(url: str) -> BytesIO:
    """Download *url* and return the response body wrapped in a BytesIO buffer."""
    return BytesIO(requests.get(url).content)
20
+
21
+
22
# function to get public URLs of paths
def supabaseGetPublicURL(path: str) -> str:
    """Build the public Supabase storage URL for *path*, percent-encoding spaces."""
    base = "https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/"
    return (base + path).replace(" ", "%20")
26
+
27
+
28
# function to deduct credit
def deductAndTrackCredit(storename: str, endpoint: str) -> str:
    """Charge one credit to *storename* and record the call in UsageHistory.

    Returns "Not Found" when the store is unknown, "No Credits Available"
    when its balance is exhausted, and "Success" after deducting a credit.
    """
    client: Client = create_client(os.environ["SUPABASE_URL"], os.environ["SUPABASE_KEY"])
    current, _ = client.table('ClientConfig').select('CreditBalance').eq("StoreName", f"{storename}").execute()
    if current[1] == []:
        return "Not Found"
    balance = current[1][0]["CreditBalance"]
    if balance <= 0:
        return "No Credits Available"
    client.table('ClientConfig').update({'CreditBalance': balance - 1}).eq("StoreName", f"{storename}").execute()
    client.table('UsageHistory').insert({'StoreName': f"{storename}", 'APIEndpoint': f"{endpoint}"}).execute()
    return "Success"
src/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.7 kB). View file
 
src/utils/__pycache__/backgroundEnhancerArchitecture.cpython-310.pyc ADDED
Binary file (9.92 kB). View file
 
src/utils/__pycache__/exceptions.cpython-310.pyc ADDED
Binary file (1.03 kB). View file
 
src/utils/__pycache__/logger.cpython-310.pyc ADDED
Binary file (683 Bytes). View file
 
src/utils/backgroundEnhancerArchitecture.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+
6
+
7
class REBNCONV(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU block with configurable dilation and stride."""

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super(REBNCONV, self).__init__()
        # Padding equals the dilation rate so stride-1 convolutions preserve size.
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv, batch-norm and ReLU in sequence."""
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))
20
+
21
+
22
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
23
+ def _upsample_like(src, tar):
24
+ src = F.interpolate(src, size=tar.shape[2:], mode='bilinear')
25
+
26
+ return src
27
+
28
+
29
### RSU-7 ###
class RSU7(nn.Module):
    """Residual U-block of depth 7 (U^2-Net style building block).

    Encoder-decoder with five 2x2 max-pool downsamplings and a dilated
    bottleneck; the decoded feature map is added to the input projection
    (residual connection). ``img_size`` is kept for interface compatibility
    but is unused.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super(RSU7, self).__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        # Input projection to out_ch; its output is the residual added at the end.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)  ## 1 -> 1/2

        # Encoder: conv + 2x2 max-pool at each of the first five levels.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottleneck widens the receptive field with dilation 2 instead of pooling.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: each stage consumes the concatenation of the upsampled
        # deeper feature and the same-level encoder feature (2*mid_ch in).
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Encode, decode with skip connections, and add the input projection."""
        # (the original unpacked b, c, h, w = x.shape here but never used them)
        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin
110
+
111
+
112
### RSU-6 ###
class RSU6(nn.Module):
    """Residual U-block of depth 6: four max-pool downsamplings, a dilated
    bottleneck, and a skip-connected decoder; output adds the input
    projection (residual connection)."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        # Input projection; its output is the residual added at the end.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder levels 1-4 (conv + 2x2 max-pool each).
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottleneck uses dilation 2 instead of further pooling.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder stages consume concat(upsampled deeper, same-level encoder).
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Encode, decode with skip connections, and add the input projection."""
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin
178
+
179
+
180
### RSU-5 ###
class RSU5(nn.Module):
    """Residual U-block of depth 5: three max-pool downsamplings, a dilated
    bottleneck, and a skip-connected decoder; output adds the input
    projection (residual connection)."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        # Input projection; its output is the residual added at the end.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder levels 1-3 (conv + 2x2 max-pool each).
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottleneck uses dilation 2 instead of further pooling.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder stages consume concat(upsampled deeper, same-level encoder).
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Encode, decode with skip connections, and add the input projection."""
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin
236
+
237
+
238
+ ### RSU-4 ###
239
class RSU4(nn.Module):
    """Residual U-block of depth 4 (U^2-Net building block).

    Two encoder stages with 2x max-pooling, a third encoder conv, a
    dilated bridge conv, and a mirrored decoder with skip concatenation.
    Output is decoder output + input projection (residual connection).

    Args:
        in_ch: Channels of the input feature map.
        mid_ch: Internal channel width of encoder/decoder stages.
        out_ch: Channels of the output (and of the input projection).
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        # Input projection; its output is also the residual added at the end.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # ---- encoder ----
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bridge: dilation widens receptive field without pooling.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # ---- decoder: each stage takes cat(deeper, skip) -> mid_ch*2 in ----
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run encoder, bridge, decoder; return residual sum with input."""
        hx = x

        hxin = self.rebnconvin(hx)

        # encoder path
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        # decoder path: concat on channel dim, upsample to next skip's size
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # residual connection
        return hx1d + hxin
284
+
285
+
286
+ ### RSU-4F ###
287
class RSU4F(nn.Module):
    """Dilation-only residual U-block ("F" = full-resolution, no pooling).

    Unlike RSU4..RSU7, no max-pooling is used: spatial size is constant
    and the receptive field grows via dilation rates 1, 2, 4, 8.  The
    decoder mirrors the encoder dilations (4, 2, 1) on concatenated
    features.  Output is decoder output + input projection (residual).

    Args:
        in_ch: Channels of the input feature map.
        mid_ch: Internal channel width of encoder/decoder stages.
        out_ch: Channels of the output (and of the input projection).
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        # Input projection; its output is also the residual added at the end.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # ---- encoder: increasing dilation instead of pooling ----
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # ---- decoder: mirrored dilations on cat(deeper, skip) inputs ----
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run dilated encoder/decoder; return residual sum with input.

        No upsampling is needed here because all features share the same
        spatial size (no pooling anywhere in this block).
        """
        hx = x

        hxin = self.rebnconvin(hx)

        # encoder path (resolution constant throughout)
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        # decoder path: concat on channel dim, same resolution
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        # residual connection
        return hx1d + hxin
320
+
321
+
322
class myrebnconv(nn.Module):
    """Convenience Conv2d -> BatchNorm2d -> ReLU block.

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        groups: Convolution groups.
    """

    def __init__(self, in_ch=3,
                 out_ch=1,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1):
        super(myrebnconv, self).__init__()

        # Attribute names (conv/bn/rl) are kept so existing checkpoints
        # continue to load via state_dict.
        self.conv = nn.Conv2d(in_ch, out_ch,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              groups=groups)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, then batch normalization, then ReLU."""
        out = self.conv(x)
        out = self.bn(out)
        return self.rl(out)
344
+
345
+
346
class BackgroundEnhancerArchitecture(nn.Module, PyTorchModelHubMixin):
    """U^2-Net-style segmentation network built from RSU blocks.

    A stride-2 stem is followed by six RSU encoder stages (with 2x
    max-pooling between them), a mirrored RSU decoder with channel-wise
    skip concatenation, and one 3x3 "side" head per decoder level.  Each
    side output is upsampled back to the input resolution.

    Args:
        config: Dict with keys ``"in_ch"`` (input channels) and
            ``"out_ch"`` (channels of every side output map).

    Forward returns:
        Tuple ``(side_maps, features)`` — ``side_maps`` is a list of six
        sigmoid probability maps ``d1..d6`` at the input resolution
        (shallowest to deepest), and ``features`` is the list of raw
        decoder feature maps ``[hx1d, hx2d, hx3d, hx4d, hx5d, hx6]``.
    """

    def __init__(self, config: dict = {"in_ch": 3, "out_ch": 1}):
        super(BackgroundEnhancerArchitecture, self).__init__()
        # NOTE(review): the mutable default dict is only read, never
        # mutated, so the shared default is harmless; kept as-is because
        # PyTorchModelHubMixin may serialize init kwargs — confirm before
        # changing the signature.
        in_ch = config["in_ch"]
        out_ch = config["out_ch"]
        # Stride-2 stem halves the spatial resolution before stage 1.
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        # NOTE(review): pool_in is registered but unused in forward (its
        # call is commented out); kept so existing checkpoints still load.
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # ---- encoder: RSU depth decreases as spatial size shrinks ----
        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # ---- decoder: inputs are cat(deeper upsampled, encoder skip) ----
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Per-level side heads projecting decoder features to out_ch maps.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    def forward(self, x):
        """Run the full encoder-decoder and emit six side outputs.

        Args:
            x: Input image batch of shape ``(N, in_ch, H, W)``.

        Returns:
            ``([sigmoid(d1)..sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d,
            hx5d, hx6])`` — see class docstring.
        """
        hx = x

        hxin = self.conv_in(hx)
        # hx = self.pool_in(hxin)

        # stage 1
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6 (deepest); upsample back to stage-5 resolution
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side outputs, each upsampled to the input resolution
        d1 = self.side1(hx1d)
        d1 = _upsample_like(d1, x)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, x)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, x)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, x)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, x)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, x)

        # torch.sigmoid replaces the deprecated F.sigmoid (identical math).
        return ([torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3),
                 torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)],
                [hx1d, hx2d, hx3d, hx4d, hx5d, hx6])
src/utils/exceptions.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
def error_message_detail(error):
    """Describe *error* with the file name and line number where it rose.

    Reads the traceback of the exception currently being handled via
    ``sys.exc_info()``, so the location info is only available when this
    is called inside an ``except`` block.

    Args:
        error: The exception instance (or message) to describe.

    Returns:
        str: ``"Error encountered in line no [...], filename : [...],
        saying [...]"`` when an exception is active; a plain message
        otherwise.
    """
    _, _, tb = sys.exc_info()
    if tb is None:
        # Outside an active except block the original code crashed with
        # AttributeError on ``None.tb_frame`` — degrade gracefully instead.
        return "Error encountered saying [{}]".format(error)
    filename = tb.tb_frame.f_code.co_filename
    lineno = tb.tb_lineno
    error_message = "Error encountered in line no [{}], filename : [{}], saying [{}]".format(lineno, filename, error)
    return error_message
9
+
10
class CustomException(Exception):
    """Exception whose string form includes file/line context.

    The raw message is passed to ``Exception`` (so ``args`` holds it),
    while ``error_message`` stores the enriched text produced by
    ``error_message_detail`` and is what ``str()`` renders.
    """

    def __init__(self, error_message):
        super().__init__(error_message)
        # NOTE(review): error_message_detail reads sys.exc_info(), so the
        # file/line details are only meaningful when this exception is
        # constructed inside an ``except`` block — confirm call sites.
        self.error_message = error_message_detail(error_message)

    def __str__(self) -> str:
        # Render the enriched message instead of the raw one.
        return self.error_message
src/utils/logger.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
import os

# Package-wide logger; emits INFO and above to both console and file.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Logs are written under <cwd>/logs/running_logs.log.
log_dir = os.path.join(os.getcwd(), "logs")
os.makedirs(log_dir, exist_ok=True)

LOG_FILE = os.path.join(log_dir, "running_logs.log")

logFormat = "[%(asctime)s: %(levelname)s: %(module)s: %(message)s]"
logFormatter = logging.Formatter(fmt=logFormat, style="%")

streamHandler = logging.StreamHandler()
streamHandler.setFormatter(logFormatter)

# Explicit encoding so log text is stable across platforms.
fileHandler = logging.FileHandler(filename=LOG_FILE, encoding="utf-8")
fileHandler.setFormatter(logFormatter)

# Guard against duplicate handlers (and duplicated log lines) if this
# module's top-level code runs more than once, e.g. on reload.
if not logger.handlers:
    logger.addHandler(streamHandler)
    logger.addHandler(fileHandler)