sygma-damage-annotation / functions.py
ychafiqui's picture
fixed filtering bug by removing count from part name
b7e9fbb
raw
history blame
2.6 kB
import boto3
from PIL import Image
import pandas as pd
import streamlit as st
import random
import io
# Shared S3 client; credentials come from Streamlit's secrets store.
s3_client = boto3.client(
    's3',
    aws_access_key_id=st.secrets["aws_access_key_id"],
    aws_secret_access_key=st.secrets["aws_secret_access_key"],
    region_name='eu-west-3',
)

# Bucket and key prefixes used by the functions in this module.
bucket_name = "sygma-global-data-storage"
folder = "car-damage-detection/scrappedImages/"
csv_folder = "car-damage-detection/CSVs/"
s3_df_path = csv_folder + "70k_old_annotations_fixed.csv"

# The annotation CSV is downloaded once, when this module is first imported,
# and then kept in memory as the module-level ``df``.
response = s3_client.get_object(Bucket=bucket_name, Key=s3_df_path)
with io.BytesIO(response['Body'].read()) as bio:
    df = pd.read_csv(bio, low_memory=False)

# Keep only rows flagged as available in S3. The explicit ``.copy()`` makes
# ``df`` an independent frame: it is mutated in place later via ``.loc``
# assignments, which on a filtered result would otherwise trigger pandas'
# SettingWithCopyWarning.
# NOTE(review): process_image() uploads this *filtered* frame back to
# ``s3_df_path``, so rows where ``s3_available`` is not True are dropped from
# the stored CSV on the first save - confirm that this is intended.
df = df[df['s3_available'] == True].copy()  # noqa: E712 - element-wise comparison
def get_car_parts_count():
    """Build one ``"<part> (<count>)"`` label per car-part column of ``df``.

    Every column of ``df`` from position 6 onwards is treated as a part
    column. A row contributes to a part's count when its value in that
    column is greater than 0.

    Returns:
        A list of label strings, in column order.
    """
    labels = []
    for part_name in df.columns[6:]:
        # Summing the boolean mask counts the rows whose value is > 0.
        damaged_rows = int((df[part_name] > 0).sum())
        labels.append(f"{part_name} ({damaged_rows})")
    return labels
def get_random_image(parts_filter=False, max_attempts=10):
    """Pick a random not-yet-validated image and load it from S3.

    Args:
        parts_filter: Optional list of part labels in the
            ``"<part> (<count>)"`` format produced by
            ``get_car_parts_count``. When non-empty, only images whose value
            is > 0 for *every* selected part are eligible. Any falsy value
            disables filtering.
        max_attempts: Upper bound on how many distinct candidate images are
            tried when a download or decode fails.

    Returns:
        ``(image, image_name)`` where ``image`` is a PIL image resized to
        1000x800, or ``(None, None)`` when there is no eligible image or
        none of the tried candidates could be loaded.
    """
    eligible = df["validated"] == False  # noqa: E712 - element-wise comparison
    if parts_filter:
        # Strip only the trailing " (<count>)" so that part names which
        # themselves contain " (" are not truncated.
        part_columns = [label.rsplit(" (", 1)[0] for label in parts_filter]
        # Keep rows where all selected parts are damaged (> 0).
        eligible &= (df[part_columns] > 0).all(axis=1)

    candidates = df.loc[eligible, "img_name"].tolist()
    if not candidates:
        return None, None

    # Try a bounded number of distinct candidates. Unlike an unbounded
    # recursive retry, this keeps the parts filter in force on every attempt
    # and cannot exhaust the call stack when S3 is unreachable.
    attempts = min(max(max_attempts, 1), len(candidates))
    for image_name in random.sample(candidates, attempts):
        try:
            response = s3_client.get_object(Bucket=bucket_name, Key=folder + image_name)
            image = Image.open(io.BytesIO(response['Body'].read())).resize((1000, 800))
        except Exception:
            # Best effort: a missing or unreadable object just means we move
            # on to the next candidate. KeyboardInterrupt and SystemExit are
            # deliberately not caught.
            continue
        return image, image_name
    return None, None
def get_img_damages(img_name):
    """Return the part-column values of the first row matching ``img_name``.

    Args:
        img_name: Value to look up in the ``img_name`` column of ``df``.

    Returns:
        A dict mapping each column from position 6 onwards to that row's
        value.

    Raises:
        IndexError: If no row of ``df`` has this ``img_name``.
    """
    is_requested = df["img_name"] == img_name
    # Row 0 of the matches, columns from position 6 onwards.
    return df.loc[is_requested].iloc[0, 6:].to_dict()
def process_image(img_name, annotator_name, is_car, skip, rotation, damaged_parts):
    """Record an annotation for one image and upload the CSV back to S3.

    Args:
        img_name: Value identifying the row(s) of ``df`` to update.
        annotator_name: Written to the ``annotator_name`` column.
        is_car: Written to the ``is_car`` column.
        skip: When true, the part values are left untouched and the row is
            marked as not validated; otherwise it is marked validated.
        rotation: Written to the ``rotation`` column.
        damaged_parts: Mapping of part column name to the value to store.

    Side effects:
        Mutates the module-level ``df`` in place and overwrites the object at
        ``s3_df_path`` with the whole frame serialised as CSV.
    """
    # Compute the row selector once instead of once per assignment.
    row_mask = df["img_name"] == img_name

    df.loc[row_mask, "annotator_name"] = annotator_name
    df.loc[row_mask, "is_car"] = is_car
    df.loc[row_mask, "rotation"] = rotation
    if not skip and damaged_parts:
        # Materialise the dict views so pandas receives plain list-likes for
        # both the column indexer and the values.
        part_columns = list(damaged_parts.keys())
        part_values = list(damaged_parts.values())
        df.loc[row_mask, part_columns] = part_values
    df.loc[row_mask, "validated"] = not skip

    # NOTE(review): the entire CSV is re-uploaded on every call and concurrent
    # sessions would overwrite each other's changes - confirm this is acceptable.
    csv_bytes = df.to_csv(index=False).encode("utf-8")
    s3_client.put_object(Bucket=bucket_name, Key=s3_df_path, Body=csv_bytes)