Upload AuraBloom.py
AuraBloom.py +134 -0
AuraBloom.py
ADDED
@@ -0,0 +1,134 @@
import streamlit as st
from PIL import Image
import numpy as np
from joblib import load
from skimage.transform import resize
import torch
import os
import sys

# Run these commands in your terminal first:
# pip install git+https://github.com/FacePerceiver/facer.git@main
# pip install timm
# git clone https://github.com/FacePerceiver/facer.git

# Add the cloned 'facer' repository to the module search path
sys.path.append('facer')

import facer

# Load the face detection and face parsing models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
face_detector = facer.face_detector('retinaface/mobilenet', device=device)
face_parser = facer.face_parser('farl/lapa/448', device=device)

# Define the Monk scale colors (RGB)
monk_scale = {
    'Class2': (243, 231, 219),  # f3e7db
    'Class3': (247, 234, 208),  # f7ead0
    'Class4': (234, 218, 186),  # eadaba
    'Class5': (215, 189, 150),  # d7bd96
    'Class6': (160, 126, 86),   # a07e56
    'Class7': (130, 92, 67),    # 825c43
    'Class8': (96, 65, 52),     # 604134
    'Class9': (58, 49, 42),     # 3a312a
    'Class10': (41, 36, 32),    # 292420
}

# Function to convert an RGB tuple to a hex color code
def rgb_to_hex(rgb):
    return '#{:02x}{:02x}{:02x}'.format(*rgb)

# Mapping of Monk classes to swatch colors built from monk_scale
monk_colors = {
    '1': [rgb_to_hex(monk_scale['Class2']), rgb_to_hex(monk_scale['Class3']), rgb_to_hex(monk_scale['Class4'])],
    '2': [rgb_to_hex(monk_scale['Class5']), rgb_to_hex(monk_scale['Class6'])],
    '3': [rgb_to_hex(monk_scale['Class7']), rgb_to_hex(monk_scale['Class8'])],
    '4': [rgb_to_hex(monk_scale['Class9']), rgb_to_hex(monk_scale['Class10'])],
    'default': ['#808080'],  # Default gray swatch for unexpected classes, wrapped in a list so it iterates like the other entries
}

# Mapping of the model's output classes to Monk classes
class_mapping = {
    0: '1',  # Map model class 0 to Monk class 1
    1: '2',  # Map model class 1 to Monk class 2
    2: '3',  # Map model class 2 to Monk class 3
    3: '4',  # Map model class 3 to Monk class 4
    # Add more mappings if needed
}
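
# Example (illustrative only, not executed): a raw model prediction of 2
# resolves through class_mapping to Monk class '3', whose swatches are the
# hex codes of Class7 and Class8 above:
#   class_mapping.get(2)  -> '3'
#   monk_colors['3']      -> ['#825c43', '#604134']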

# Function to load the model
def load_model():
    model_path = r"C:\Users\ramam\svm_model3.joblib"  # Adjust the path to your model
    model = load(model_path)
    return model

# Function to parse a face and extract the skin region
def parse_face(image):
    # Ensure the image has 3 channels (RGB)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    image_data = np.array(image)

    # Check that the image has 3 channels
    if image_data.shape[2] != 3:
        raise ValueError("Image does not have 3 channels (RGB).")

    image_tensor = torch.from_numpy(image_data.astype('float32')).permute(2, 0, 1).unsqueeze(0).to(device)
    faces = face_detector(image_tensor)

    if faces:
        parsed_faces = face_parser(image_tensor, faces)
        if 'seg' in parsed_faces:
            seg_logits = parsed_faces['seg']['logits']
            seg_probs = torch.sigmoid(seg_logits)
            binary_mask = seg_probs[0, 1, :, :] > 0.5
            binary_mask = binary_mask.cpu().numpy()
            binary_mask_3d = np.repeat(binary_mask[:, :, np.newaxis], 3, axis=2)
            skin_region = image_data * binary_mask_3d
            return skin_region.astype(np.uint8)
    return None
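
# Note (assumption): channel 1 of the 'seg' logits is treated as the face-skin
# class of the farl/lapa/448 parser; pixels outside that mask are zeroed out,
# so the classifier only ever sees skin pixels.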

# Function to make predictions
def classify_image(image, model):
    parsed_image = parse_face(image)
    if parsed_image is not None:
        image_resized = resize(parsed_image, (128, 128), anti_aliasing=True)  # Resize to 128x128
        image_reshaped = image_resized.reshape(1, -1)  # Flatten to match the model input
        if image_reshaped.shape[1] == 49152:  # 128 * 128 * 3 features after resizing
            # Zero-pad the feature vector from 49152 to the 65536 features the SVM expects
            image_padded = np.pad(image_reshaped, ((0, 0), (0, 65536 - 49152)), 'constant')
        else:
            raise ValueError("Unexpected number of features after reshaping.")
        prediction = model.predict(image_padded)
        return prediction[0], parsed_image
    else:
        raise ValueError("Face parsing failed.")

# Load the model
model = load_model()

# Function to display the Monk class color swatches
def display_monk_class_color(prediction):
    st.write(f"Prediction: {prediction}")  # Debugging
    monk_class = class_mapping.get(prediction, 'default')
    colors = monk_colors.get(monk_class, monk_colors['default'])  # Default to gray if the class is not found
    st.write(f"Monk Class: {monk_class}")
    for color in colors:
        st.markdown(f"<div style='width:100px; height:50px; background-color:{color};'></div>", unsafe_allow_html=True)

# Streamlit app
st.title('Skin Tone Classification')

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image.', use_column_width=True)

    if st.button('Classify'):
        try:
            prediction, parsed_image = classify_image(image, model)
            display_monk_class_color(prediction)
            st.image(parsed_image, caption='Parsed Image.', use_column_width=True)
        except ValueError as e:
            st.error(f"Error: {e}")
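
The feature handling in classify_image can be sanity-checked outside the app. The short standalone sketch below reproduces the resize-flatten-pad steps and prints the shapes the classifier relies on; the zero array is only a stand-in for a parsed face, not part of the app. The app itself is started with streamlit run AuraBloom.py after cloning the facer repository and adjusting model_path.

# Standalone sanity check (not part of the app) of the feature dimensions used
# in classify_image; the zero array is a stand-in for a parsed face.
import numpy as np
from skimage.transform import resize

dummy_skin = np.zeros((256, 256, 3), dtype=np.uint8)
resized = resize(dummy_skin, (128, 128), anti_aliasing=True)  # shape (128, 128, 3), floats in [0, 1]
flat = resized.reshape(1, -1)
print(flat.shape)    # (1, 49152) == 128 * 128 * 3
padded = np.pad(flat, ((0, 0), (0, 65536 - flat.shape[1])), 'constant')
print(padded.shape)  # (1, 65536), the length passed to model.predict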