Spaces:
Configuration error
Configuration error
import streamlit as st | |
from PIL import Image | |
import numpy as np | |
from joblib import load | |
from skimage.transform import resize | |
import torch | |
import os | |
import sys | |
# facer must be importable before this section runs. Either install it:
#   pip install git+https://github.com/FacePerceiver/facer.git@main
#   pip install timm
# or clone it (git clone https://github.com/FacePerceiver/facer.git) next to
# this script / into the working directory.
# Prefer a checkout sitting next to this file so the app works no matter which
# directory `streamlit run` is launched from; otherwise fall back to the
# working-directory-relative 'facer' path the app originally used.
_FACER_CHECKOUT = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'facer')
sys.path.append(_FACER_CHECKOUT if os.path.isdir(_FACER_CHECKOUT) else 'facer')
import facer

# Load the face detection and face parsing models, on the GPU when available.
# NOTE(review): Streamlit re-executes the whole script on every interaction, so
# these models are rebuilt each time; consider caching them (e.g.
# st.cache_resource) if the installed Streamlit version supports it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
face_detector = facer.face_detector('retinaface/mobilenet', device=device)
face_parser = facer.face_parser('farl/lapa/448', device=device)
# Monk scale swatches as (R, G, B) integer tuples, keyed 'Class2' .. 'Class10'
# in ascending order. Each tuple is decoded from the hex code listed below.
# NOTE(review): there is no 'Class1' entry; presumably intentional — confirm.
monk_scale = {
    f'Class{index}': tuple(bytes.fromhex(hex_code))
    for index, hex_code in enumerate(
        (
            'f3e7db',
            'f7ead0',
            'eadaba',
            'd7bd96',
            'a07e56',
            '825c43',
            '604134',
            '3a312a',
            '292420',
        ),
        start=2,
    )
}
def rgb_to_hex(rgb):
    """Format an (r, g, b) triple of 0-255 integers as a lowercase '#rrggbb' string."""
    red, green, blue, *_ = rgb
    return f'#{red:02x}{green:02x}{blue:02x}'
# Mapping of coarse Monk classes ('1'..'4') to the list of hex swatch colours,
# taken from monk_scale, that each class spans.
monk_colors = {
    '1': [rgb_to_hex(monk_scale['Class2']), rgb_to_hex(monk_scale['Class3']), rgb_to_hex(monk_scale['Class4'])],
    '2': [rgb_to_hex(monk_scale['Class5']), rgb_to_hex(monk_scale['Class6'])],
    '3': [rgb_to_hex(monk_scale['Class7']), rgb_to_hex(monk_scale['Class8'])],
    '4': [rgb_to_hex(monk_scale['Class9']), rgb_to_hex(monk_scale['Class10'])],
    # Fallback for unexpected classes. This is a one-element list like the other
    # entries: a bare string would be iterated character by character by callers
    # that loop over the colours, producing one bogus swatch per character.
    'default': ['#808080'],
}
# Model output label -> coarse Monk class key used by monk_colors:
# 0 -> '1', 1 -> '2', 2 -> '3', 3 -> '4'. Extend the tuple to map more labels.
class_mapping = dict(enumerate(('1', '2', '3', '4')))
# Function to load the trained classifier from disk.
def load_model(model_path=None):
    """Load the trained skin-tone classifier with joblib.

    Args:
        model_path: Path to the ``.joblib`` model file. When ``None`` (the
            default), the ``SKIN_TONE_MODEL_PATH`` environment variable is used
            if set; otherwise the original hard-coded developer path is used.

    Returns:
        The deserialized model object (the app calls ``.predict`` on it).

    Raises:
        Whatever ``joblib.load`` raises, e.g. ``FileNotFoundError`` when the
        file does not exist.
    """
    if model_path is None:
        # The hard-coded Windows path only exists on the original author's
        # machine; allow overriding it without editing the code.
        model_path = os.environ.get(
            "SKIN_TONE_MODEL_PATH", r"C:\Users\ramam\svm_model3.joblib"
        )
    # joblib files are pickle-based: only load model files from a trusted source.
    return load(model_path)
# Function to parse the face and extract the skin region.
def parse_face(image):
    """Segment the face in a PIL image and keep only its skin pixels.

    The image is converted to RGB and passed through the module-level facer
    ``face_detector`` and ``face_parser``. Every pixel outside the parser's
    channel-1 mask for the first detected face is then zeroed.

    Args:
        image: A ``PIL.Image.Image`` in any mode.

    Returns:
        A uint8 ``numpy.ndarray`` with the same (H, W, 3) shape as the RGB image,
        containing only the masked pixels. Returns ``None`` when no face is
        detected or the parser produced no segmentation.

    Raises:
        ValueError: if the converted image does not have 3 channels.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image_data = np.array(image)
    if image_data.ndim != 3 or image_data.shape[2] != 3:
        raise ValueError("Image does not have 3 channels (RGB).")

    # (H, W, 3) uint8 -> (1, 3, H, W) float32 batch tensor on the model device.
    image_tensor = torch.from_numpy(image_data.astype('float32')).permute(2, 0, 1).unsqueeze(0).to(device)

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        faces = face_detector(image_tensor)
        # facer's detector returns a dict of per-detection tensors. A dict is
        # truthy even when it holds zero detections, so check the detections
        # themselves; otherwise indexing face 0 below raises IndexError.
        if not faces or ('rects' in faces and len(faces['rects']) == 0):
            return None
        parsed_faces = face_parser(image_tensor, faces)
        if 'seg' not in parsed_faces:
            return None
        seg_logits = parsed_faces['seg']['logits']
        if seg_logits.shape[0] == 0:
            return None
        # NOTE(review): channel 1 is presumably the 'face' (skin) class of the
        # LaPa label set — confirm against parsed_faces['seg']['label_names'].
        # The sigmoid > 0.5 test (i.e. logit > 0) is kept for compatibility with
        # the trained classifier; facer's own examples use softmax/argmax.
        skin_mask = (torch.sigmoid(seg_logits[0, 1]) > 0.5).cpu().numpy()

    # Assumes the parser returns logits at the input image's resolution, so the
    # (H, W) mask broadcasts over the colour channels.
    return (image_data * skin_mask[:, :, np.newaxis]).astype(np.uint8)
# Function to make predictions.
def classify_image(image, model, input_size=(128, 128), n_features=65536):
    """Classify the skin tone of the face in ``image``.

    Args:
        image: A ``PIL.Image.Image`` containing a face.
        model: A fitted estimator exposing ``predict`` on a (1, n_features) array.
        input_size: (height, width) the skin region is resized to before
            flattening. Defaults to the original 128x128.
        n_features: Feature-vector length the model expects. The flattened RGB
            vector is zero-padded up to this length. Defaults to the original 65536.

    Returns:
        Tuple ``(prediction, parsed_image)``: the model's first prediction and
        the uint8 skin-only image returned by ``parse_face``.

    Raises:
        ValueError: if face parsing fails, or the flattened vector does not have
            the expected size / is longer than ``n_features``.
    """
    parsed_image = parse_face(image)
    if parsed_image is None:
        raise ValueError("Face parsing failed.")

    # skimage's resize returns floats rescaled to [0, 1] for uint8 input.
    image_resized = resize(parsed_image, input_size, anti_aliasing=True)
    features = image_resized.reshape(1, -1)

    expected = input_size[0] * input_size[1] * 3  # RGB channels
    if features.shape[1] != expected or expected > n_features:
        raise ValueError("Unexpected number of features after reshaping.")
    # NOTE(review): the model expects more features than an RGB image yields
    # (65536 = 128*128*4 by default), so the tail is zero-filled. Confirm this
    # matches how the training features were built.
    features = np.pad(features, ((0, 0), (0, n_features - expected)), 'constant')

    prediction = model.predict(features)
    return prediction[0], parsed_image
# Load the classifier. A missing or unreadable model file is reported in the
# page and halts this script run, instead of surfacing as an uncaught exception.
try:
    model = load_model()
except OSError as err:
    st.error(f"Could not load the classification model: {err}")
    st.stop()
# Function to display the Monk class color.
def display_monk_class_color(prediction):
    """Write the prediction, its Monk class, and one colour swatch per class colour.

    Args:
        prediction: A raw model label, looked up in ``class_mapping``. Unknown
            labels fall back to the ``'default'`` entry of ``monk_colors``.
    """
    st.write(f"Prediction: {prediction}")  # Debugging
    # NOTE(review): lookup relies on the model emitting integer-like labels
    # 0..3; confirm against the classifier's classes_.
    monk_class = class_mapping.get(prediction, 'default')
    colors = monk_colors.get(monk_class, monk_colors['default'])  # Default to gray if class not found
    # A bare colour string would be iterated character by character, drawing
    # one bogus swatch per character; treat it as a single colour.
    if isinstance(colors, str):
        colors = [colors]
    st.write(f"Monk Class: {monk_class}")
    for color in colors:
        st.markdown(f"<div style='width:100px; height:50px; background-color:{color};'></div>", unsafe_allow_html=True)
# Streamlit app: upload an image, show it, and classify it on demand.
st.title('Skin Tone Classification')
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; force decoding so corrupt/truncated files fail here.
        image.load()
    except OSError as e:  # includes PIL.UnidentifiedImageError
        st.error(f"Error: could not read the uploaded image ({e})")
        st.stop()
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    if st.button('Classify'):
        try:
            prediction, parsed_image = classify_image(image, model)
            display_monk_class_color(prediction)
            st.image(parsed_image, caption='Parsed Image.', use_column_width=True)
        except ValueError as e:
            st.error(f"Error: {e}")