Spaces:
Sleeping
Sleeping
Initial commit
Browse files- .github/workflows/dockerhub.yaml +39 -0
- .gitignore +7 -0
- Dockerfile +21 -0
- app.py +224 -0
- requirements.txt +80 -0
- setup.py +17 -0
- src/components/__init__.py +0 -0
- src/components/__pycache__/__init__.cpython-310.pyc +0 -0
- src/components/__pycache__/necklaceTryOn.cpython-310.pyc +0 -0
- src/components/necklaceTryOn.py +235 -0
- src/pipelines/__init__.py +0 -0
- src/pipelines/__pycache__/__init__.cpython-310.pyc +0 -0
- src/pipelines/__pycache__/completePipeline.cpython-310.pyc +0 -0
- src/pipelines/completePipeline.py +13 -0
- src/utils/__init__.py +45 -0
- src/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- src/utils/__pycache__/backgroundEnhancerArchitecture.cpython-310.pyc +0 -0
- src/utils/__pycache__/exceptions.cpython-310.pyc +0 -0
- src/utils/__pycache__/logger.cpython-310.pyc +0 -0
- src/utils/backgroundEnhancerArchitecture.py +454 -0
- src/utils/exceptions.py +16 -0
- src/utils/logger.py +22 -0
.github/workflows/dockerhub.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Publish Docker image
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches: [main]
|
6 |
+
|
7 |
+
jobs:
|
8 |
+
push_to_registry:
|
9 |
+
name: Push Docker image to Docker Hub
|
10 |
+
runs-on: ubuntu-latest
|
11 |
+
permissions:
|
12 |
+
packages: write
|
13 |
+
contents: read
|
14 |
+
attestations: write
|
15 |
+
steps:
|
16 |
+
- name: Check out the repo
|
17 |
+
uses: actions/checkout@v4
|
18 |
+
|
19 |
+
- name: Log in to Docker Hub
|
20 |
+
uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a
|
21 |
+
with:
|
22 |
+
username: ${{ secrets.DOCKER_USERNAME }}
|
23 |
+
password: ${{ secrets.DOCKER_PASSWORD }}
|
24 |
+
|
25 |
+
- name: Extract metadata (tags, labels) for Docker
|
26 |
+
id: meta
|
27 |
+
uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
|
28 |
+
with:
|
29 |
+
images: ishworrsubedii/web_plugin_api
|
30 |
+
|
31 |
+
- name: Build and push Docker image
|
32 |
+
id: push
|
33 |
+
uses: docker/build-push-action@3b5e8027fcad23fda98b2e3ac259d8d67585f671
|
34 |
+
with:
|
35 |
+
context: .
|
36 |
+
file: ./Dockerfile
|
37 |
+
push: true
|
38 |
+
tags: ${{ steps.meta.outputs.tags }}
|
39 |
+
labels: ${{ steps.meta.outputs.labels }}
|
.gitignore
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
.idea/
|
6 |
+
|
7 |
+
*.log
|
Dockerfile
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.10-slim
|
2 |
+
|
3 |
+
WORKDIR /api
|
4 |
+
|
5 |
+
COPY . /api
|
6 |
+
|
7 |
+
RUN apt-get update && apt-get install -y
|
8 |
+
|
9 |
+
RUN apt install libgl1-mesa-glx -y
|
10 |
+
|
11 |
+
RUN apt-get install 'ffmpeg'\
|
12 |
+
'libsm6'\
|
13 |
+
'libxext6' -y
|
14 |
+
|
15 |
+
RUN pip install -r requirements.txt
|
16 |
+
|
17 |
+
RUN ulimit -s 2000
|
18 |
+
|
19 |
+
EXPOSE 8000
|
20 |
+
|
21 |
+
CMD ["uvicorn", "app:app", "--host","0.0.0.0","--port","8000"]
|
app.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi.encoders import jsonable_encoder
|
2 |
+
from src.utils import supabaseGetPublicURL, deductAndTrackCredit, returnBytesData
|
3 |
+
from fastapi import FastAPI, File, UploadFile, Header, HTTPException, Form, Depends
|
4 |
+
from src.pipelines.completePipeline import Pipeline
|
5 |
+
from fastapi.middleware.cors import CORSMiddleware
|
6 |
+
from fastapi.responses import JSONResponse
|
7 |
+
from supabase import create_client, Client
|
8 |
+
from typing import Dict, Union, List
|
9 |
+
from io import BytesIO
|
10 |
+
from PIL import Image
|
11 |
+
import pandas as pd
|
12 |
+
import base64
|
13 |
+
import os
|
14 |
+
from pydantic import BaseModel
|
15 |
+
|
16 |
+
pipeline = Pipeline()
|
17 |
+
app = FastAPI(title="Magical Mirror Web Plugin")
|
18 |
+
|
19 |
+
app.add_middleware(
|
20 |
+
CORSMiddleware,
|
21 |
+
allow_origins=["*"],
|
22 |
+
allow_credentials=True,
|
23 |
+
allow_methods=["*"],
|
24 |
+
allow_headers=["*"],
|
25 |
+
)
|
26 |
+
url: str = os.environ["SUPABASE_URL"]
|
27 |
+
key: str = os.environ["SUPABASE_KEY"]
|
28 |
+
supabase: Client = create_client(url, key)
|
29 |
+
|
30 |
+
|
31 |
+
@app.post("/productData/{storeId}")
|
32 |
+
async def product_data(
|
33 |
+
storeId: str,
|
34 |
+
filterattributes: List[Dict[str, Union[str, int, float]]],
|
35 |
+
storename: str = Header(default="default")
|
36 |
+
):
|
37 |
+
"""Filters product data based on the provided attributes and store ID."""
|
38 |
+
|
39 |
+
try:
|
40 |
+
response = supabase.table('MagicMirror').select("*").execute()
|
41 |
+
df = pd.DataFrame(response.dict()["data"])
|
42 |
+
|
43 |
+
df = df[df["StoreName"] == storeId]
|
44 |
+
|
45 |
+
# Preprocess filterattributes to handle multiple or duplicated attributes
|
46 |
+
attribute_dict = {}
|
47 |
+
for attr in filterattributes:
|
48 |
+
key, value = list(attr.items())[
|
49 |
+
0] # This will convert the dictionary into a list and get the key and value.
|
50 |
+
if key in attribute_dict: # This will check if the key is already present in the dictionary.
|
51 |
+
if isinstance(attribute_dict[key],
|
52 |
+
list): # This will create a list if there are multiple values for the same key and we are doing or operation.
|
53 |
+
attribute_dict[key].append(value) # This will append the value to the list.
|
54 |
+
else:
|
55 |
+
attribute_dict[key] = [attribute_dict[key], value]
|
56 |
+
else:
|
57 |
+
attribute_dict[key] = [value] # This will create a list if there is only one value for the key.
|
58 |
+
|
59 |
+
priceFrom = None
|
60 |
+
priceTo = None
|
61 |
+
weightFrom = None
|
62 |
+
weightTo = None
|
63 |
+
weightAscending = None
|
64 |
+
priceAscending = None
|
65 |
+
idAscending = None
|
66 |
+
dateAscending = None
|
67 |
+
|
68 |
+
for key, value in attribute_dict.items():
|
69 |
+
if key == 'priceFrom':
|
70 |
+
priceFrom = value[0]
|
71 |
+
|
72 |
+
elif key == "priceTo":
|
73 |
+
priceTo = value[0]
|
74 |
+
|
75 |
+
elif key == "priceAscending":
|
76 |
+
priceAscending = value[0]
|
77 |
+
|
78 |
+
elif key == "weightFrom":
|
79 |
+
weightFrom = value[0]
|
80 |
+
|
81 |
+
elif key == "weightTo":
|
82 |
+
weightTo = value[0]
|
83 |
+
|
84 |
+
elif key == "weightAscending":
|
85 |
+
weightAscending = value[0]
|
86 |
+
|
87 |
+
elif key == "idAscending":
|
88 |
+
idAscending = value[0]
|
89 |
+
|
90 |
+
elif key == "dateAscending":
|
91 |
+
dateAscending = value[0]
|
92 |
+
|
93 |
+
df["image_url"] = df.apply(
|
94 |
+
lambda row: supabaseGetPublicURL(f"{row['StoreName']}/{row['Category']}/image/{row['Id']}.png"),
|
95 |
+
axis=1)
|
96 |
+
df["thumbnail_url"] = df.apply(
|
97 |
+
lambda row: supabaseGetPublicURL(f"{row['StoreName']}/{row['Category']}/thumbnail/{row['Id']}.png"),
|
98 |
+
axis=1)
|
99 |
+
|
100 |
+
df.reset_index(drop=True, inplace=True)
|
101 |
+
for key, values in attribute_dict.items():
|
102 |
+
try:
|
103 |
+
df = df[df[key].isin(values)]
|
104 |
+
|
105 |
+
except:
|
106 |
+
pass
|
107 |
+
|
108 |
+
# applying filter for price and weight
|
109 |
+
if priceFrom is not None:
|
110 |
+
df = df[df["Price"] >= priceFrom]
|
111 |
+
if priceTo is not None:
|
112 |
+
df = df[df["Price"] <= priceTo]
|
113 |
+
if weightFrom is not None:
|
114 |
+
df = df[df["Weight"] >= weightFrom]
|
115 |
+
if weightTo is not None:
|
116 |
+
df = df[df["Weight"] <= weightTo]
|
117 |
+
|
118 |
+
if priceAscending is not None:
|
119 |
+
if priceAscending == 1:
|
120 |
+
value = True
|
121 |
+
|
122 |
+
else:
|
123 |
+
value = False
|
124 |
+
df = df.sort_values(by="Price", ascending=value)
|
125 |
+
if weightAscending is not None:
|
126 |
+
if weightAscending == 1:
|
127 |
+
value = True
|
128 |
+
|
129 |
+
else:
|
130 |
+
value = False
|
131 |
+
df = df.sort_values(by="Weight", ascending=value)
|
132 |
+
|
133 |
+
if idAscending is not None:
|
134 |
+
if idAscending == 1:
|
135 |
+
value = True
|
136 |
+
else:
|
137 |
+
value = False
|
138 |
+
df = df.sort_values(by="Id", ascending=value)
|
139 |
+
|
140 |
+
if dateAscending is not None:
|
141 |
+
if dateAscending == 1:
|
142 |
+
value = True
|
143 |
+
else:
|
144 |
+
value = False
|
145 |
+
df = df.sort_values(by="UpdatedAt", ascending=value)
|
146 |
+
|
147 |
+
df = df.drop(["CreatedAt", "EstimatedPrice"], axis=1)
|
148 |
+
|
149 |
+
result = {}
|
150 |
+
for _, row in df.iterrows():
|
151 |
+
category = row["Category"]
|
152 |
+
if category not in result: # this is for checking duplicate category
|
153 |
+
result[category] = []
|
154 |
+
result[category].append(row.to_dict())
|
155 |
+
|
156 |
+
return JSONResponse(content=jsonable_encoder(result)) # this will convert the result into json format.
|
157 |
+
|
158 |
+
except Exception as e:
|
159 |
+
raise HTTPException(status_code=500, detail=f"Failed to fetch or process data: {e}")
|
160 |
+
|
161 |
+
|
162 |
+
class NecklaceTryOnIDEntity(BaseModel):
|
163 |
+
necklaceImageId: str
|
164 |
+
necklaceCategory: str
|
165 |
+
storename: str
|
166 |
+
api_token: str
|
167 |
+
|
168 |
+
|
169 |
+
async def parse_necklace_try_on_id(necklaceImageId: str = Form(...),
|
170 |
+
necklaceCategory: str = Form(...),
|
171 |
+
storename: str = Form(...),
|
172 |
+
api_token: str = Form(...)) -> NecklaceTryOnIDEntity:
|
173 |
+
return NecklaceTryOnIDEntity(
|
174 |
+
necklaceImageId=necklaceImageId,
|
175 |
+
necklaceCategory=necklaceCategory,
|
176 |
+
storename=storename,
|
177 |
+
api_token=api_token
|
178 |
+
)
|
179 |
+
|
180 |
+
|
181 |
+
@app.post("/necklaceTryOnID")
|
182 |
+
async def necklace_try_on_id(necklace_try_on_id: NecklaceTryOnIDEntity = Depends(parse_necklace_try_on_id),
|
183 |
+
image: UploadFile = File(...)):
|
184 |
+
data, _ = supabase.table("APIKeyList").select("*").filter("API_KEY", "eq",
|
185 |
+
necklace_try_on_id.api_token).execute()
|
186 |
+
|
187 |
+
api_key_actual = data[1][0]['API_KEY']
|
188 |
+
if api_key_actual != necklace_try_on_id.api_token:
|
189 |
+
return JSONResponse(content={"error": "Invalid API Key"}, status_code=401)
|
190 |
+
|
191 |
+
else:
|
192 |
+
imageBytes = await image.read()
|
193 |
+
|
194 |
+
jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{necklace_try_on_id.storename}/{necklace_try_on_id.necklaceCategory}/image/{necklace_try_on_id.necklaceImageId}.png"
|
195 |
+
|
196 |
+
try:
|
197 |
+
image, jewellery = Image.open(BytesIO(imageBytes)), Image.open(returnBytesData(url=jewellery_url))
|
198 |
+
|
199 |
+
except:
|
200 |
+
error_message = {
|
201 |
+
"error": "The requested resource (Image, necklace category, or store) is not available. Please verify the availability and try again."
|
202 |
+
}
|
203 |
+
|
204 |
+
return JSONResponse(content=error_message, status_code=404)
|
205 |
+
|
206 |
+
result, headetText = await pipeline.necklaceTryOn(image=image, jewellery=jewellery,
|
207 |
+
storename=necklace_try_on_id.storename)
|
208 |
+
|
209 |
+
inMemFile = BytesIO()
|
210 |
+
result.save(inMemFile, format="WEBP", quality=85)
|
211 |
+
outputBytes = inMemFile.getvalue()
|
212 |
+
response = {
|
213 |
+
"output": f"data:image/WEBP;base64,{base64.b64encode(outputBytes).decode('utf-8')}"
|
214 |
+
}
|
215 |
+
creditResponse = deductAndTrackCredit(storename=necklace_try_on_id.storename, endpoint="/necklaceTryOnID")
|
216 |
+
if creditResponse == "No Credits Available":
|
217 |
+
response = {
|
218 |
+
"error": "No Credits Remaining"
|
219 |
+
}
|
220 |
+
|
221 |
+
return JSONResponse(content=response)
|
222 |
+
|
223 |
+
else:
|
224 |
+
return JSONResponse(content=response)
|
requirements.txt
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.1.0
|
2 |
+
annotated-types==0.7.0
|
3 |
+
anyio==4.4.0
|
4 |
+
attrs==23.2.0
|
5 |
+
certifi==2024.7.4
|
6 |
+
cffi==1.16.0
|
7 |
+
charset-normalizer==3.3.2
|
8 |
+
click==8.1.7
|
9 |
+
contourpy==1.2.1
|
10 |
+
cvzone==1.6.1
|
11 |
+
cycler==0.12.1
|
12 |
+
deprecation==2.1.0
|
13 |
+
dnspython==2.6.1
|
14 |
+
email_validator==2.2.0
|
15 |
+
exceptiongroup==1.2.1
|
16 |
+
fastapi==0.111.0
|
17 |
+
fastapi-cli==0.0.4
|
18 |
+
flatbuffers==24.3.25
|
19 |
+
fonttools==4.53.1
|
20 |
+
gotrue==2.5.4
|
21 |
+
h11==0.14.0
|
22 |
+
httpcore==1.0.5
|
23 |
+
httptools==0.6.1
|
24 |
+
httpx==0.27.0
|
25 |
+
idna==3.7
|
26 |
+
jax==0.4.30
|
27 |
+
jaxlib==0.4.30
|
28 |
+
Jinja2==3.1.4
|
29 |
+
kiwisolver==1.4.5
|
30 |
+
markdown-it-py==3.0.0
|
31 |
+
MarkupSafe==2.1.5
|
32 |
+
matplotlib==3.9.1
|
33 |
+
mdurl==0.1.2
|
34 |
+
mediapipe==0.10.14
|
35 |
+
ml-dtypes==0.4.0
|
36 |
+
numpy==2.0.0
|
37 |
+
opencv-contrib-python==4.10.0.84
|
38 |
+
opencv-python==4.10.0.84
|
39 |
+
opt-einsum==3.3.0
|
40 |
+
orjson==3.10.6
|
41 |
+
packaging==24.1
|
42 |
+
pandas==2.2.2
|
43 |
+
pillow==10.4.0
|
44 |
+
postgrest==0.16.8
|
45 |
+
protobuf==4.25.3
|
46 |
+
pycparser==2.22
|
47 |
+
pydantic==2.8.2
|
48 |
+
pydantic_core==2.20.1
|
49 |
+
Pygments==2.18.0
|
50 |
+
pyparsing==3.1.2
|
51 |
+
python-dateutil==2.9.0.post0
|
52 |
+
python-dotenv==1.0.1
|
53 |
+
python-multipart==0.0.9
|
54 |
+
pytz==2024.1
|
55 |
+
PyYAML==6.0.1
|
56 |
+
realtime==1.0.6
|
57 |
+
requests==2.32.3
|
58 |
+
rich==13.7.1
|
59 |
+
scikit-build==0.18.0
|
60 |
+
scipy==1.14.0
|
61 |
+
shellingham==1.5.4
|
62 |
+
six==1.16.0
|
63 |
+
sniffio==1.3.1
|
64 |
+
sounddevice==0.4.7
|
65 |
+
starlette==0.37.2
|
66 |
+
storage3==0.7.6
|
67 |
+
StrEnum==0.4.15
|
68 |
+
supabase==2.5.1
|
69 |
+
supafunc==0.4.6
|
70 |
+
tomli==2.0.1
|
71 |
+
typer==0.12.3
|
72 |
+
typing_extensions==4.12.2
|
73 |
+
tzdata==2024.1
|
74 |
+
ujson==5.10.0
|
75 |
+
urllib3==2.2.2
|
76 |
+
uvicorn==0.30.1
|
77 |
+
uvloop==0.19.0
|
78 |
+
watchfiles==0.22.0
|
79 |
+
websockets==12.0
|
80 |
+
-e .
|
setup.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
HYPER_E_DOT = "-e ."
|
4 |
+
def getRequirements(requirementsPath: str) -> list[str]:
|
5 |
+
with open(requirementsPath) as file:
|
6 |
+
requirements = file.read().split("\n")
|
7 |
+
requirements.remove(HYPER_E_DOT)
|
8 |
+
return requirements
|
9 |
+
|
10 |
+
setup(
|
11 |
+
name = "Magical-Mirror",
|
12 |
+
author = "Subramani Sivakumar",
|
13 |
+
author_email = "[email protected]",
|
14 |
+
version = "0.1",
|
15 |
+
packages = find_packages(),
|
16 |
+
install_requires = getRequirements(requirementsPath = "./requirements.txt")
|
17 |
+
)
|
src/components/__init__.py
ADDED
File without changes
|
src/components/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (152 Bytes). View file
|
|
src/components/__pycache__/necklaceTryOn.cpython-310.pyc
ADDED
Binary file (6.35 kB). View file
|
|
src/components/necklaceTryOn.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from cvzone.FaceMeshModule import FaceMeshDetector
|
2 |
+
from src.utils import addWatermark, returnBytesData
|
3 |
+
from src.utils.exceptions import CustomException
|
4 |
+
from cvzone.PoseModule import PoseDetector
|
5 |
+
from src.utils.logger import logger
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from typing import Union
|
8 |
+
from PIL import Image
|
9 |
+
import numpy as np
|
10 |
+
import cvzone
|
11 |
+
import math
|
12 |
+
import cv2
|
13 |
+
import gc
|
14 |
+
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class NecklaceTryOnConfig:
|
18 |
+
logoURL: str = "https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/MagicMirror/FullImages/{}.png"
|
19 |
+
|
20 |
+
|
21 |
+
class NecklaceTryOn:
|
22 |
+
def __init__(self) -> None:
|
23 |
+
self.detector = PoseDetector()
|
24 |
+
self.necklaceTryOnConfig = NecklaceTryOnConfig()
|
25 |
+
self.meshDetector = FaceMeshDetector(staticMode=True, maxFaces=1)
|
26 |
+
|
27 |
+
def necklaceTryOn(self, image: Image.Image, jewellery: Image.Image, storename: str) -> list[
|
28 |
+
Union[Image.Image, str]]:
|
29 |
+
try:
|
30 |
+
logger.info(f">>> NECKLACE TRY ON STARTED :: {storename} <<<")
|
31 |
+
|
32 |
+
# reading the images
|
33 |
+
image, jewellery = image.convert("RGB").resize((3000, 3000)), jewellery.convert("RGBA")
|
34 |
+
image = np.array(image)
|
35 |
+
copy_image = image.copy()
|
36 |
+
jewellery = np.array(jewellery)
|
37 |
+
|
38 |
+
logger.info(f"NECKLACE TRY ON :: detecting pose and landmarks :: {storename}")
|
39 |
+
|
40 |
+
image = self.detector.findPose(image)
|
41 |
+
lmList, _ = self.detector.findPosition(image, bboxWithHands=False, draw=False)
|
42 |
+
|
43 |
+
img, faces = self.meshDetector.findFaceMesh(image, draw=False)
|
44 |
+
leftLandmarkIndex = 172
|
45 |
+
rightLandmarkIndex = 397
|
46 |
+
|
47 |
+
leftLandmark, rightLandmark = faces[0][leftLandmarkIndex], faces[0][rightLandmarkIndex]
|
48 |
+
landmarksDistance = int(
|
49 |
+
((leftLandmark[0] - rightLandmark[0]) ** 2 + (leftLandmark[1] - rightLandmark[1]) ** 2) ** 0.5)
|
50 |
+
|
51 |
+
logger.info(f"NECKLACE TRY ON :: estimating neck points :: {storename}")
|
52 |
+
|
53 |
+
# avg_x1 = int(leftLandmark[0] - landmarksDistance * 0.15) -> V2.1
|
54 |
+
avg_x1 = int(leftLandmark[0] - landmarksDistance * 0.12)
|
55 |
+
avg_x2 = int(rightLandmark[0] + landmarksDistance * 0.12)
|
56 |
+
|
57 |
+
avg_y1 = int(leftLandmark[1] + landmarksDistance * 0.5)
|
58 |
+
avg_y2 = int(rightLandmark[1] + landmarksDistance * 0.5)
|
59 |
+
|
60 |
+
logger.info(f"NECKLACE TRY ON :: scaling the necklace image :: {storename}")
|
61 |
+
|
62 |
+
if avg_y2 < avg_y1:
|
63 |
+
angle = math.ceil(
|
64 |
+
self.detector.findAngle(
|
65 |
+
p1=(avg_x2, avg_y2), p2=(avg_x1, avg_y1), p3=(avg_x2, avg_y1)
|
66 |
+
)[0]
|
67 |
+
)
|
68 |
+
else:
|
69 |
+
angle = math.ceil(
|
70 |
+
self.detector.findAngle(
|
71 |
+
p1=(avg_x2, avg_y2), p2=(avg_x1, avg_y1), p3=(avg_x2, avg_y1)
|
72 |
+
)[0]
|
73 |
+
)
|
74 |
+
angle = angle * -1
|
75 |
+
|
76 |
+
xdist = avg_x2 - avg_x1
|
77 |
+
origImgRatio = xdist / jewellery.shape[1]
|
78 |
+
ydist = jewellery.shape[0] * origImgRatio
|
79 |
+
|
80 |
+
logger.info(f"NECKLACE TRY ON :: adding offset based on the necklace shape :: {storename}")
|
81 |
+
|
82 |
+
image_gray = cv2.cvtColor(jewellery, cv2.COLOR_BGRA2GRAY)
|
83 |
+
for offset_orig in range(image_gray.shape[1]):
|
84 |
+
pixel_value = image_gray[0, :][offset_orig]
|
85 |
+
if (pixel_value != 255) & (pixel_value != 0):
|
86 |
+
break
|
87 |
+
else:
|
88 |
+
continue
|
89 |
+
offset = int(0.8 * xdist * (offset_orig / jewellery.shape[1]))
|
90 |
+
jewellery = cv2.resize(
|
91 |
+
jewellery, (int(xdist), int(ydist)), interpolation=cv2.INTER_AREA
|
92 |
+
)
|
93 |
+
jewellery = cvzone.rotateImage(jewellery, angle)
|
94 |
+
y_coordinate = avg_y1 - offset
|
95 |
+
available_space = copy_image.shape[0] - y_coordinate
|
96 |
+
extra = jewellery.shape[0] - available_space
|
97 |
+
|
98 |
+
logger.info(f"NECKLACE TRY ON :: generating necklace placement status :: {storename}")
|
99 |
+
|
100 |
+
if extra > 0:
|
101 |
+
headerText = "To see more of the necklace, please step back slightly."
|
102 |
+
else:
|
103 |
+
headerText = "success"
|
104 |
+
|
105 |
+
logger.info(f"NECKLACE TRY ON :: generating output :: {storename}")
|
106 |
+
|
107 |
+
result = cvzone.overlayPNG(copy_image, jewellery, (avg_x1, y_coordinate))
|
108 |
+
image = Image.fromarray(result.astype(np.uint8))
|
109 |
+
logo = Image.open(returnBytesData(url=self.necklaceTryOnConfig.logoURL.format(storename)))
|
110 |
+
result = addWatermark(background=image, logo=logo)
|
111 |
+
|
112 |
+
gc.collect()
|
113 |
+
|
114 |
+
return [result, headerText]
|
115 |
+
|
116 |
+
except Exception as e:
|
117 |
+
logger.error(f"{CustomException(e)}:: {storename}")
|
118 |
+
raise CustomException(e)
|
119 |
+
|
120 |
+
def necklaceTryOnV3(self, image: Image.Image, jewellery: Image.Image, storename: str) -> list[
|
121 |
+
Union[Image.Image, str]]:
|
122 |
+
try:
|
123 |
+
logger.info(f">>> NECKLACE TRY ON STARTED :: {storename} <<<")
|
124 |
+
|
125 |
+
# reading the images
|
126 |
+
image, jewellery = image.convert("RGB"), jewellery.convert("RGBA")
|
127 |
+
image = np.array(image.resize((4000, 4000)))
|
128 |
+
copy_image = image.copy()
|
129 |
+
jewellery = np.array(jewellery)
|
130 |
+
|
131 |
+
logger.info(f"NECKLACE TRY ON :: detecting pose and landmarks :: {storename}")
|
132 |
+
|
133 |
+
image = self.detector.findPose(image)
|
134 |
+
lmList, _ = self.detector.findPosition(image, bboxWithHands=False, draw=False)
|
135 |
+
meshDetector = FaceMeshDetector(staticMode=True, maxFaces=1)
|
136 |
+
img, faces = meshDetector.findFaceMesh(image, draw=False)
|
137 |
+
left_lip_point = faces[0][61]
|
138 |
+
right_lip_point = faces[0][291]
|
139 |
+
pt12, pt11, pt10, pt9 = (
|
140 |
+
lmList[12][:2],
|
141 |
+
lmList[11][:2],
|
142 |
+
lmList[10][:2],
|
143 |
+
lmList[9][:2],
|
144 |
+
)
|
145 |
+
|
146 |
+
mid_lips = (
|
147 |
+
int((left_lip_point[0] + right_lip_point[0]) / 2), int((left_lip_point[1] + right_lip_point[1]) / 2))
|
148 |
+
|
149 |
+
mid_lips_x1 = int(pt12[0] + (mid_lips[0] - pt12[0]) / 2)
|
150 |
+
mid_lips_y1 = int(pt12[1] + (mid_lips[1] - pt12[1]) / 2)
|
151 |
+
|
152 |
+
mid_lips_x2 = int(pt11[0] + (mid_lips[0] - pt11[0]) / 2)
|
153 |
+
mid_lips_y2 = int(pt11[1] + (mid_lips[1] - pt11[1]) / 2)
|
154 |
+
|
155 |
+
# left right lip
|
156 |
+
left_right_lip_org_x11 = int(pt12[0] + (right_lip_point[0] - pt12[0]) / 2)
|
157 |
+
left_right_lip_org_y11 = int(pt12[1] + (right_lip_point[1] - pt12[1]) / 2)
|
158 |
+
|
159 |
+
left_right_lip_org_x12 = int(pt11[0] + (left_lip_point[0] - pt11[0]) / 2)
|
160 |
+
left_right_lip_org_y12 = int(pt11[1] + (left_lip_point[1] - pt11[1]) / 2)
|
161 |
+
|
162 |
+
# left right lip 2
|
163 |
+
left_right_lip_org_x21 = int(pt12[0] + (left_lip_point[0] - pt12[0]) / 2)
|
164 |
+
left_right_lip_org_y21 = int(pt12[1] + (left_lip_point[1] - pt12[1]) / 2)
|
165 |
+
|
166 |
+
left_right_lip_org_x22 = int(pt11[0] + (right_lip_point[0] - pt11[0]) / 2)
|
167 |
+
left_right_lip_org_y22 = int(pt11[1] + (right_lip_point[1] - pt11[1]) / 2)
|
168 |
+
|
169 |
+
logger.info(f"NECKLACE TRY ON :: estimating neck points :: {storename}")
|
170 |
+
|
171 |
+
avg_x1 = int((mid_lips_x1 + left_right_lip_org_x11 + left_right_lip_org_x21) / 3)
|
172 |
+
avg_y1 = int((mid_lips_y1 + left_right_lip_org_y11 + left_right_lip_org_y21) / 3)
|
173 |
+
|
174 |
+
avg_x2 = int((mid_lips_x2 + left_right_lip_org_x12 + left_right_lip_org_x22) / 3)
|
175 |
+
avg_y2 = int((mid_lips_y2 + left_right_lip_org_y12 + left_right_lip_org_y22) / 3)
|
176 |
+
|
177 |
+
logger.info(f"NECKLACE TRY ON :: scaling the necklace image :: {storename}")
|
178 |
+
|
179 |
+
if avg_y2 < avg_y1:
|
180 |
+
angle = math.ceil(
|
181 |
+
self.detector.findAngle(
|
182 |
+
p1=(avg_x2, avg_y2), p2=(avg_x1, avg_y1), p3=(avg_x2, avg_y1)
|
183 |
+
)[0]
|
184 |
+
)
|
185 |
+
else:
|
186 |
+
angle = math.ceil(
|
187 |
+
self.detector.findAngle(
|
188 |
+
p1=(avg_x2, avg_y2), p2=(avg_x1, avg_y1), p3=(avg_x2, avg_y1)
|
189 |
+
)[0]
|
190 |
+
)
|
191 |
+
angle = angle * -1
|
192 |
+
|
193 |
+
xdist = avg_x2 - avg_x1
|
194 |
+
origImgRatio = xdist / jewellery.shape[1]
|
195 |
+
ydist = jewellery.shape[0] * origImgRatio
|
196 |
+
|
197 |
+
logger.info(f"NECKLACE TRY ON :: adding offset based on the necklace shape :: {storename}")
|
198 |
+
|
199 |
+
image_gray = cv2.cvtColor(jewellery, cv2.COLOR_BGRA2GRAY)
|
200 |
+
for offset_orig in range(image_gray.shape[1]):
|
201 |
+
pixel_value = image_gray[0, :][offset_orig]
|
202 |
+
if (pixel_value != 255) & (pixel_value != 0):
|
203 |
+
break
|
204 |
+
else:
|
205 |
+
continue
|
206 |
+
offset = int(0.8 * xdist * (offset_orig / jewellery.shape[1]))
|
207 |
+
jewellery = cv2.resize(
|
208 |
+
jewellery, (int(xdist), int(ydist)), interpolation=cv2.INTER_AREA
|
209 |
+
)
|
210 |
+
jewellery = cvzone.rotateImage(jewellery, angle)
|
211 |
+
y_coordinate = avg_y1 - offset
|
212 |
+
available_space = copy_image.shape[0] - y_coordinate
|
213 |
+
extra = jewellery.shape[0] - available_space
|
214 |
+
|
215 |
+
logger.info(f"NECKLACE TRY ON :: generating necklace placement status :: {storename}")
|
216 |
+
|
217 |
+
if extra > 0:
|
218 |
+
headerText = "To see more of the necklace, please step back slightly."
|
219 |
+
else:
|
220 |
+
headerText = "success"
|
221 |
+
|
222 |
+
logger.info(f"NECKLACE TRY ON :: generating output :: {storename}")
|
223 |
+
|
224 |
+
result = cvzone.overlayPNG(copy_image, jewellery, (avg_x1, y_coordinate))
|
225 |
+
image = Image.fromarray(result.astype(np.uint8))
|
226 |
+
logo = Image.open(returnBytesData(url=self.necklaceTryOnConfig.logoURL.format({storename})))
|
227 |
+
result = addWatermark(background=image, logo=logo)
|
228 |
+
|
229 |
+
gc.collect()
|
230 |
+
|
231 |
+
return [result, headerText]
|
232 |
+
|
233 |
+
except Exception as e:
|
234 |
+
logger.error(f"{CustomException(e)}:: {storename}")
|
235 |
+
raise CustomException(e)
|
src/pipelines/__init__.py
ADDED
File without changes
|
src/pipelines/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (151 Bytes). View file
|
|
src/pipelines/__pycache__/completePipeline.cpython-310.pyc
ADDED
Binary file (925 Bytes). View file
|
|
src/pipelines/completePipeline.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.components.necklaceTryOn import NecklaceTryOn
|
2 |
+
from typing import Union
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
|
6 |
+
class Pipeline:
|
7 |
+
def __init__(self) -> None:
|
8 |
+
self.necklaceTryOnObj = NecklaceTryOn()
|
9 |
+
|
10 |
+
async def necklaceTryOn(self, image: Image.Image, jewellery: Image.Image, storename: str) -> list[
|
11 |
+
Union[Image.Image, str]]:
|
12 |
+
result, headerText = self.necklaceTryOnObj.necklaceTryOn(image=image, jewellery=jewellery, storename=storename)
|
13 |
+
return [result, headerText]
|
src/utils/__init__.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from supabase import create_client, Client
|
2 |
+
from PIL import Image
|
3 |
+
from io import BytesIO
|
4 |
+
import requests
|
5 |
+
import os
|
6 |
+
|
7 |
+
|
8 |
+
# function to add watermark to images
|
9 |
+
def addWatermark(background: Image.Image, logo: Image.Image) -> Image.Image:
|
10 |
+
background = background.convert("RGBA")
|
11 |
+
logo = logo.convert("RGBA").resize((int(0.08 * background.size[0]), int(0.08 * background.size[0])))
|
12 |
+
background.paste(logo, (10, background.size[1] - logo.size[1] - 10), logo)
|
13 |
+
return background
|
14 |
+
|
15 |
+
|
16 |
+
# function to download an image from url and return as bytes objects
|
17 |
+
def returnBytesData(url: str) -> BytesIO:
|
18 |
+
response = requests.get(url)
|
19 |
+
return BytesIO(response.content)
|
20 |
+
|
21 |
+
|
22 |
+
# function to get public URLs of paths
|
23 |
+
def supabaseGetPublicURL(path: str) -> str:
|
24 |
+
url_string = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{path}"
|
25 |
+
return url_string.replace(" ", "%20")
|
26 |
+
|
27 |
+
|
28 |
+
# function to deduct credit
|
29 |
+
def deductAndTrackCredit(storename: str, endpoint: str) -> str:
|
30 |
+
url: str = os.environ["SUPABASE_URL"]
|
31 |
+
key: str = os.environ["SUPABASE_KEY"]
|
32 |
+
supabase: Client = create_client(url, key)
|
33 |
+
current, _ = supabase.table('ClientConfig').select('CreditBalance').eq("StoreName", f"{storename}").execute()
|
34 |
+
if current[1] == []:
|
35 |
+
return "Not Found"
|
36 |
+
else:
|
37 |
+
current = current[1][0]["CreditBalance"]
|
38 |
+
if current > 0:
|
39 |
+
data, _ = supabase.table('ClientConfig').update({'CreditBalance': current - 1}).eq("StoreName",
|
40 |
+
f"{storename}").execute()
|
41 |
+
data, _ = supabase.table('UsageHistory').insert(
|
42 |
+
{'StoreName': f"{storename}", 'APIEndpoint': f"{endpoint}"}).execute()
|
43 |
+
return "Success"
|
44 |
+
else:
|
45 |
+
return "No Credits Available"
|
src/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.7 kB). View file
|
|
src/utils/__pycache__/backgroundEnhancerArchitecture.cpython-310.pyc
ADDED
Binary file (9.92 kB). View file
|
|
src/utils/__pycache__/exceptions.cpython-310.pyc
ADDED
Binary file (1.03 kB). View file
|
|
src/utils/__pycache__/logger.cpython-310.pyc
ADDED
Binary file (683 Bytes). View file
|
|
src/utils/backgroundEnhancerArchitecture.py
ADDED
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from huggingface_hub import PyTorchModelHubMixin
|
5 |
+
|
6 |
+
|
7 |
+
class REBNCONV(nn.Module):
|
8 |
+
def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
|
9 |
+
super(REBNCONV, self).__init__()
|
10 |
+
|
11 |
+
self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
|
12 |
+
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
13 |
+
self.relu_s1 = nn.ReLU(inplace=True)
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
hx = x
|
17 |
+
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
18 |
+
|
19 |
+
return xout
|
20 |
+
|
21 |
+
|
22 |
+
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
23 |
+
def _upsample_like(src, tar):
|
24 |
+
src = F.interpolate(src, size=tar.shape[2:], mode='bilinear')
|
25 |
+
|
26 |
+
return src
|
27 |
+
|
28 |
+
|
29 |
+
### RSU-7 ###
|
30 |
+
class RSU7(nn.Module):
|
31 |
+
|
32 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
|
33 |
+
super(RSU7, self).__init__()
|
34 |
+
|
35 |
+
self.in_ch = in_ch
|
36 |
+
self.mid_ch = mid_ch
|
37 |
+
self.out_ch = out_ch
|
38 |
+
|
39 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
|
40 |
+
|
41 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
42 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
43 |
+
|
44 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
45 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
46 |
+
|
47 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
48 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
49 |
+
|
50 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
51 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
52 |
+
|
53 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
54 |
+
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
55 |
+
|
56 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
57 |
+
|
58 |
+
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
59 |
+
|
60 |
+
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
61 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
62 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
63 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
64 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
65 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
b, c, h, w = x.shape
|
69 |
+
|
70 |
+
hx = x
|
71 |
+
hxin = self.rebnconvin(hx)
|
72 |
+
|
73 |
+
hx1 = self.rebnconv1(hxin)
|
74 |
+
hx = self.pool1(hx1)
|
75 |
+
|
76 |
+
hx2 = self.rebnconv2(hx)
|
77 |
+
hx = self.pool2(hx2)
|
78 |
+
|
79 |
+
hx3 = self.rebnconv3(hx)
|
80 |
+
hx = self.pool3(hx3)
|
81 |
+
|
82 |
+
hx4 = self.rebnconv4(hx)
|
83 |
+
hx = self.pool4(hx4)
|
84 |
+
|
85 |
+
hx5 = self.rebnconv5(hx)
|
86 |
+
hx = self.pool5(hx5)
|
87 |
+
|
88 |
+
hx6 = self.rebnconv6(hx)
|
89 |
+
|
90 |
+
hx7 = self.rebnconv7(hx6)
|
91 |
+
|
92 |
+
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
|
93 |
+
hx6dup = _upsample_like(hx6d, hx5)
|
94 |
+
|
95 |
+
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
|
96 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
97 |
+
|
98 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
99 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
100 |
+
|
101 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
102 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
103 |
+
|
104 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
105 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
106 |
+
|
107 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
108 |
+
|
109 |
+
return hx1d + hxin
|
110 |
+
|
111 |
+
|
112 |
+
### RSU-6 ###
|
113 |
+
class RSU6(nn.Module):
|
114 |
+
|
115 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
116 |
+
super(RSU6, self).__init__()
|
117 |
+
|
118 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
119 |
+
|
120 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
121 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
122 |
+
|
123 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
124 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
125 |
+
|
126 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
127 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
128 |
+
|
129 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
130 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
131 |
+
|
132 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
133 |
+
|
134 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
135 |
+
|
136 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
137 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
138 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
139 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
140 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
141 |
+
|
142 |
+
def forward(self, x):
|
143 |
+
hx = x
|
144 |
+
|
145 |
+
hxin = self.rebnconvin(hx)
|
146 |
+
|
147 |
+
hx1 = self.rebnconv1(hxin)
|
148 |
+
hx = self.pool1(hx1)
|
149 |
+
|
150 |
+
hx2 = self.rebnconv2(hx)
|
151 |
+
hx = self.pool2(hx2)
|
152 |
+
|
153 |
+
hx3 = self.rebnconv3(hx)
|
154 |
+
hx = self.pool3(hx3)
|
155 |
+
|
156 |
+
hx4 = self.rebnconv4(hx)
|
157 |
+
hx = self.pool4(hx4)
|
158 |
+
|
159 |
+
hx5 = self.rebnconv5(hx)
|
160 |
+
|
161 |
+
hx6 = self.rebnconv6(hx5)
|
162 |
+
|
163 |
+
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
|
164 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
165 |
+
|
166 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
167 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
168 |
+
|
169 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
170 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
171 |
+
|
172 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
173 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
174 |
+
|
175 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
176 |
+
|
177 |
+
return hx1d + hxin
|
178 |
+
|
179 |
+
|
180 |
+
### RSU-5 ###
|
181 |
+
class RSU5(nn.Module):
|
182 |
+
|
183 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
184 |
+
super(RSU5, self).__init__()
|
185 |
+
|
186 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
187 |
+
|
188 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
189 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
190 |
+
|
191 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
192 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
193 |
+
|
194 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
195 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
196 |
+
|
197 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
198 |
+
|
199 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
200 |
+
|
201 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
202 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
203 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
204 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
205 |
+
|
206 |
+
def forward(self, x):
|
207 |
+
hx = x
|
208 |
+
|
209 |
+
hxin = self.rebnconvin(hx)
|
210 |
+
|
211 |
+
hx1 = self.rebnconv1(hxin)
|
212 |
+
hx = self.pool1(hx1)
|
213 |
+
|
214 |
+
hx2 = self.rebnconv2(hx)
|
215 |
+
hx = self.pool2(hx2)
|
216 |
+
|
217 |
+
hx3 = self.rebnconv3(hx)
|
218 |
+
hx = self.pool3(hx3)
|
219 |
+
|
220 |
+
hx4 = self.rebnconv4(hx)
|
221 |
+
|
222 |
+
hx5 = self.rebnconv5(hx4)
|
223 |
+
|
224 |
+
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
|
225 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
226 |
+
|
227 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
228 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
229 |
+
|
230 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
231 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
232 |
+
|
233 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
234 |
+
|
235 |
+
return hx1d + hxin
|
236 |
+
|
237 |
+
|
238 |
+
### RSU-4 ###
|
239 |
+
class RSU4(nn.Module):
|
240 |
+
|
241 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
242 |
+
super(RSU4, self).__init__()
|
243 |
+
|
244 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
245 |
+
|
246 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
247 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
248 |
+
|
249 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
250 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
251 |
+
|
252 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
253 |
+
|
254 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
255 |
+
|
256 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
257 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
258 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
259 |
+
|
260 |
+
def forward(self, x):
|
261 |
+
hx = x
|
262 |
+
|
263 |
+
hxin = self.rebnconvin(hx)
|
264 |
+
|
265 |
+
hx1 = self.rebnconv1(hxin)
|
266 |
+
hx = self.pool1(hx1)
|
267 |
+
|
268 |
+
hx2 = self.rebnconv2(hx)
|
269 |
+
hx = self.pool2(hx2)
|
270 |
+
|
271 |
+
hx3 = self.rebnconv3(hx)
|
272 |
+
|
273 |
+
hx4 = self.rebnconv4(hx3)
|
274 |
+
|
275 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
276 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
277 |
+
|
278 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
279 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
280 |
+
|
281 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
282 |
+
|
283 |
+
return hx1d + hxin
|
284 |
+
|
285 |
+
|
286 |
+
### RSU-4F ###
|
287 |
+
class RSU4F(nn.Module):
|
288 |
+
|
289 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
290 |
+
super(RSU4F, self).__init__()
|
291 |
+
|
292 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
293 |
+
|
294 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
295 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
296 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
|
297 |
+
|
298 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
|
299 |
+
|
300 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
|
301 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
|
302 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
303 |
+
|
304 |
+
def forward(self, x):
|
305 |
+
hx = x
|
306 |
+
|
307 |
+
hxin = self.rebnconvin(hx)
|
308 |
+
|
309 |
+
hx1 = self.rebnconv1(hxin)
|
310 |
+
hx2 = self.rebnconv2(hx1)
|
311 |
+
hx3 = self.rebnconv3(hx2)
|
312 |
+
|
313 |
+
hx4 = self.rebnconv4(hx3)
|
314 |
+
|
315 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
316 |
+
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
|
317 |
+
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
|
318 |
+
|
319 |
+
return hx1d + hxin
|
320 |
+
|
321 |
+
|
322 |
+
class myrebnconv(nn.Module):
|
323 |
+
def __init__(self, in_ch=3,
|
324 |
+
out_ch=1,
|
325 |
+
kernel_size=3,
|
326 |
+
stride=1,
|
327 |
+
padding=1,
|
328 |
+
dilation=1,
|
329 |
+
groups=1):
|
330 |
+
super(myrebnconv, self).__init__()
|
331 |
+
|
332 |
+
self.conv = nn.Conv2d(in_ch,
|
333 |
+
out_ch,
|
334 |
+
kernel_size=kernel_size,
|
335 |
+
stride=stride,
|
336 |
+
padding=padding,
|
337 |
+
dilation=dilation,
|
338 |
+
groups=groups)
|
339 |
+
self.bn = nn.BatchNorm2d(out_ch)
|
340 |
+
self.rl = nn.ReLU(inplace=True)
|
341 |
+
|
342 |
+
def forward(self, x):
|
343 |
+
return self.rl(self.bn(self.conv(x)))
|
344 |
+
|
345 |
+
|
346 |
+
class BackgroundEnhancerArchitecture(nn.Module, PyTorchModelHubMixin):
|
347 |
+
|
348 |
+
def __init__(self, config: dict = {"in_ch": 3, "out_ch": 1}):
|
349 |
+
super(BackgroundEnhancerArchitecture, self).__init__()
|
350 |
+
in_ch = config["in_ch"]
|
351 |
+
out_ch = config["out_ch"]
|
352 |
+
self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
|
353 |
+
self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
354 |
+
|
355 |
+
self.stage1 = RSU7(64, 32, 64)
|
356 |
+
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
357 |
+
|
358 |
+
self.stage2 = RSU6(64, 32, 128)
|
359 |
+
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
360 |
+
|
361 |
+
self.stage3 = RSU5(128, 64, 256)
|
362 |
+
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
363 |
+
|
364 |
+
self.stage4 = RSU4(256, 128, 512)
|
365 |
+
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
366 |
+
|
367 |
+
self.stage5 = RSU4F(512, 256, 512)
|
368 |
+
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
369 |
+
|
370 |
+
self.stage6 = RSU4F(512, 256, 512)
|
371 |
+
|
372 |
+
# decoder
|
373 |
+
self.stage5d = RSU4F(1024, 256, 512)
|
374 |
+
self.stage4d = RSU4(1024, 128, 256)
|
375 |
+
self.stage3d = RSU5(512, 64, 128)
|
376 |
+
self.stage2d = RSU6(256, 32, 64)
|
377 |
+
self.stage1d = RSU7(128, 16, 64)
|
378 |
+
|
379 |
+
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
380 |
+
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
|
381 |
+
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
|
382 |
+
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
|
383 |
+
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
|
384 |
+
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
|
385 |
+
|
386 |
+
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
|
387 |
+
|
388 |
+
def forward(self, x):
|
389 |
+
hx = x
|
390 |
+
|
391 |
+
hxin = self.conv_in(hx)
|
392 |
+
# hx = self.pool_in(hxin)
|
393 |
+
|
394 |
+
# stage 1
|
395 |
+
hx1 = self.stage1(hxin)
|
396 |
+
hx = self.pool12(hx1)
|
397 |
+
|
398 |
+
# stage 2
|
399 |
+
hx2 = self.stage2(hx)
|
400 |
+
hx = self.pool23(hx2)
|
401 |
+
|
402 |
+
# stage 3
|
403 |
+
hx3 = self.stage3(hx)
|
404 |
+
hx = self.pool34(hx3)
|
405 |
+
|
406 |
+
# stage 4
|
407 |
+
hx4 = self.stage4(hx)
|
408 |
+
hx = self.pool45(hx4)
|
409 |
+
|
410 |
+
# stage 5
|
411 |
+
hx5 = self.stage5(hx)
|
412 |
+
hx = self.pool56(hx5)
|
413 |
+
|
414 |
+
# stage 6
|
415 |
+
hx6 = self.stage6(hx)
|
416 |
+
hx6up = _upsample_like(hx6, hx5)
|
417 |
+
|
418 |
+
# -------------------- decoder --------------------
|
419 |
+
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
|
420 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
421 |
+
|
422 |
+
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
|
423 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
424 |
+
|
425 |
+
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
|
426 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
427 |
+
|
428 |
+
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
|
429 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
430 |
+
|
431 |
+
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
|
432 |
+
|
433 |
+
# side output
|
434 |
+
d1 = self.side1(hx1d)
|
435 |
+
d1 = _upsample_like(d1, x)
|
436 |
+
|
437 |
+
d2 = self.side2(hx2d)
|
438 |
+
d2 = _upsample_like(d2, x)
|
439 |
+
|
440 |
+
d3 = self.side3(hx3d)
|
441 |
+
d3 = _upsample_like(d3, x)
|
442 |
+
|
443 |
+
d4 = self.side4(hx4d)
|
444 |
+
d4 = _upsample_like(d4, x)
|
445 |
+
|
446 |
+
d5 = self.side5(hx5d)
|
447 |
+
d5 = _upsample_like(d5, x)
|
448 |
+
|
449 |
+
d6 = self.side6(hx6)
|
450 |
+
d6 = _upsample_like(d6, x)
|
451 |
+
|
452 |
+
return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [hx1d, hx2d,
|
453 |
+
hx3d, hx4d,
|
454 |
+
hx5d, hx6]
|
src/utils/exceptions.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
def error_message_detail(error):
|
4 |
+
_, _, exc_info = sys.exc_info()
|
5 |
+
filename = exc_info.tb_frame.f_code.co_filename
|
6 |
+
lineno = exc_info.tb_lineno
|
7 |
+
error_message = "Error encountered in line no [{}], filename : [{}], saying [{}]".format(lineno, filename, error)
|
8 |
+
return error_message
|
9 |
+
|
10 |
+
class CustomException(Exception):
|
11 |
+
def __init__(self, error_message):
|
12 |
+
super().__init__(error_message)
|
13 |
+
self.error_message = error_message_detail(error_message)
|
14 |
+
|
15 |
+
def __str__(self) -> str:
|
16 |
+
return self.error_message
|
src/utils/logger.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
|
4 |
+
logger = logging.getLogger(__name__)
|
5 |
+
logger.setLevel(logging.INFO)
|
6 |
+
|
7 |
+
log_dir = os.path.join(os.getcwd(), "logs")
|
8 |
+
os.makedirs(log_dir, exist_ok = True)
|
9 |
+
|
10 |
+
LOG_FILE = os.path.join(log_dir, "running_logs.log")
|
11 |
+
|
12 |
+
logFormat = "[%(asctime)s: %(levelname)s: %(module)s: %(message)s]"
|
13 |
+
logFormatter = logging.Formatter(fmt = logFormat, style = "%")
|
14 |
+
|
15 |
+
streamHandler = logging.StreamHandler()
|
16 |
+
streamHandler.setFormatter(logFormatter)
|
17 |
+
|
18 |
+
fileHandler = logging.FileHandler(filename = LOG_FILE)
|
19 |
+
fileHandler.setFormatter(logFormatter)
|
20 |
+
|
21 |
+
logger.addHandler(streamHandler)
|
22 |
+
logger.addHandler(fileHandler)
|