mole_analyzer / app.py
Jean-Baptiste's picture
fixing gradio start issue
0a3d2f1
raw
history blame
5.03 kB
import gradio as gr
import os
from useful_functions import *
import useful_functions
from dotenv import load_dotenv
# Populate os.environ from a local .env file (when one exists) before any
# environment variable is read below.
load_dotenv()

# Load both models once at startup. These helpers come from useful_functions;
# image_classifier later reads useful_functions.lesion_model.
f_load_cancer_classifier()
f_load_cnn_model()

# Hugging Face token for the flagging callback; None when the variable is unset.
HF_TOKEN = os.environ.get('HF_TOKEN')
# Flagged samples are written to the private "mole-dataset" dataset.
hf_writer = gr.HuggingFaceDatasetSaver(
    HF_TOKEN,
    "mole-dataset",
    private=True,
)
def image_classifier(file_path, age, sex, localization):
    """Classify a skin-lesion image and rate it as suspicious or not.

    Args:
        file_path: Path of the uploaded image (the image input uses
            ``type="filepath"``).
        age: Value of the optional age field. ``0`` or a missing value
            (``None``) is replaced by the default age of 40.
        sex: Value of the optional sex dropdown. An empty or missing value
            becomes ``"unknown"``.
        localization: Value of the optional body-location dropdown. An empty
            or missing value becomes ``"unknown"``.

    Returns:
        A 2-tuple ``(lesion_scores, label)``.

        - ``lesion_scores`` maps each class name in
          ``useful_functions.lesion_model.dls.vocab`` to its CNN score as a
          plain ``float`` (the format ``gr.Label`` displays).
        - ``label`` is whatever ``f_predict_cancer`` returns for the scores
          plus the patient metadata.
    """
    # Optional fields may arrive as None instead of 0 / "" when left empty
    # (depends on the Gradio version), so rely on truthiness to cover both.
    if not age:
        age = 40
    if not sex:
        sex = "unknown"
    if not localization:
        localization = "unknown"

    preds = f_predict_cnn_with_tta(file_path)
    label = f_predict_cancer(preds, age, sex, localization)

    # Scores may be tensor/NumPy scalars (not visible from here); float()
    # turns each into a plain Python number for the Label component.
    lesion_scores = {
        name: float(score)
        for name, score in zip(useful_functions.lesion_model.dls.vocab, preds)
    }
    return lesion_scores, label
# Image input: upload only, passed to the callback as a file path; the editor
# tool lets users crop the picture before submitting.
input_img = gr.Image(source="upload", type="filepath", tool="editor")

# Optional patient metadata; image_classifier substitutes defaults for
# 0 / empty values.
input_age = gr.Number(label="age (optionnel)")
input_sex = gr.Dropdown(choices=["male", "female"], label="sex (optionnel)")
input_localization = gr.Dropdown(
    choices=[
        "abdomen", "back", "chest", "ear",
        "face", "foot", "genital", "hand",
        "lower extremity", "neck", "scalp", "trunk", "upper extremity",
    ],
    label="localization (optionnel)",
)

# Outputs: per-class lesion scores, then the suspicious / not-suspicious label.
output_lesion = gr.Label(label="Lesion detected")
output_malign = gr.Label(label="Classification")
# Example rows displayed under the interface. Each row follows the order of the
# Interface inputs: [image path, age, sex, localization]. Images live in the
# local "examples" folder.
_EXAMPLE_ROWS = [
    ("PXL_20221103_153018529.jpg", 40, "female", "back"),
    ("PXL_20221103_153129579.jpg", 40, "male", "neck"),
    ("PXL_20221103_153137616.jpg", 40, "male", "neck"),
    ("PXL_20221103_153217034.jpg", 40, "male", "back"),
    ("PXL_20221103_153256612.jpg", 40, "male", "upper extremity"),
    ("ISIC_0025402.jpg", 70, "male", "lower extremity"),
]
examples = [
    [os.path.join("examples", filename), age, sex, localization]
    for filename, age, sex, localization in _EXAMPLE_ROWS
]
# Text shown above the interface. It mixes plain text and HTML tags; lines are
# kept unindented because Gradio may render it as Markdown, where leading
# indentation would turn lines into code blocks.
DESCRIPTION = r"""This is a side project I have been working on to practice working with images.
The purpose is to classify skin lesions (Based on kaggle dataset Skin Cancer MNIST: HAM10000).
The framework used is FastAI/pytorch and the model used is a pre-trained cnn (resnet152).
I added an extra layer to use age, sex, localization and output of resnet152 to classify the lesion
as suspicious or not (randomForest model).
The lesions detected are the following:
<ul>
<li>Actinic keratoses and intraepithelial carcinoma / Bowen's disease (akiec)</li>
<li>basal cell carcinoma (bcc)</li>
<li>benign keratosis-like lesions (bkl)</li>
<li>dermatofibroma (df)</li>
<li>melanoma (mel)</li>
<li>melanocytic nevi (nv)</li>
<li>vascular lesions (vasc)</li>
</ul>
<b> This is in no way intended as medical advice, just a pedagogical exercise.</b> <br />
<i>*Pictures should be relatively well centered on the mole to obtain the best results (cf examples).
You can use the tools available in the right corner to crop optimally.</i>
"""

# References shown below the interface (dataset citations).
ARTICLE = """[1] Noel Codella, Veronica Rotemberg, Philipp Tschandl, M. Emre Celebi, Stephen Dusza, David Gutman,
Brian Helba, Aadi Kalloo, Konstantinos Liopyris, Michael Marchetti, Harald Kittler, Allan Halpern:
"Skin Lesion Analysis Toward Melanoma Detection 2018: A Challenge Hosted by the International Skin Imaging Collaboration (ISIC)",
2018; <a href="https://arxiv.org/abs/1902.03368">https://arxiv.org/abs/1902.03368</a><br />
[2] Tschandl, P., Rosendahl, C. & Kittler, H. The HAM10000 dataset, a large collection of multi-source dermatoscopic
images of common pigmented skin lesions. Sci. Data 5, 180161 doi:10.1038/sdata.2018.161 (2018)."""

demo = gr.Interface(
    title="Skin mole analyzer",
    description=DESCRIPTION,
    article=ARTICLE,
    fn=image_classifier,
    inputs=[input_img, input_age, input_sex, input_localization],
    outputs=[output_lesion, output_malign],
    examples=examples,
    # NOTE(review): per the Gradio docs, "auto" flags every submission
    # (i.e. the uploaded photo and metadata are saved through hf_writer)
    # without any user action, and flagging_options only takes effect when
    # allow_flagging="manual". Confirm which mode is intended, and consider
    # telling users that their uploads are stored.
    allow_flagging="auto",
    flagging_options=list(useful_functions.lesion_model.dls.vocab) + ["other"],
    flagging_callback=hf_writer,
)
demo.launch()