|
import streamlit as st |
|
import torch |
|
import clip |
|
from PIL import Image |
|
import os |
|
import pandas as pd |
|
from datetime import datetime |
|
import torch.nn.functional as F |
|
from typing import List |
|
|
|
|
|
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource(show_spinner=False)
def _load_clip(device_name: str):
    """Load the CLIP ViT-B/32 model and its preprocessing transform.

    Streamlit re-executes this script from the top on every widget
    interaction. Caching the loaded model as a resource means the weights are
    read once per server process instead of once per rerun.

    ``show_spinner=False`` keeps the cache decorator from drawing a spinner,
    which would otherwise be a UI element emitted before the
    ``st.set_page_config`` call that follows in this script.

    Args:
        device_name: Device string handed to ``clip.load`` (``"cuda"`` or
            ``"cpu"``). It is also the cache key, so each device gets its own
            cached model.

    Returns:
        The ``(model, preprocess)`` pair returned by ``clip.load``, with the
        model switched to evaluation mode.
    """
    clip_model, clip_preprocess = clip.load("ViT-B/32", device=device_name)
    clip_model.eval()  # inference only
    return clip_model, clip_preprocess


model, preprocess = _load_clip(device)
|
|
|
|
|
# Page chrome: wide layout plus the title shown at the top of the app.
st.set_page_config(layout="wide", page_title="Few-Shot Fault Detection")
st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")

# Introductory blurb rendered as Markdown directly under the title. It names
# the encoder used by the demo and points readers to the companion paper.
_INTRO_MARKDOWN = """
This demo uses the **smaller `ViT-B/32` encoder** from OpenAI's CLIP model to classify test images as **Nominal** or **Defective**, based on few-shot learning using user-provided reference images.

⚠️ **Note**: This app is running on a **free CPU tier** and is meant for demonstration purposes. For more advanced use cases, including GPU acceleration, custom training, and larger models, please refer to:

📄 [Megahed et al. (2025)](https://arxiv.org/abs/2501.12596):
*Adapting OpenAI's CLIP Model for Few-Shot Image Inspection in Manufacturing Quality Control: An Expository Case Study with Multiple Application Examples*

🔗 [GitHub & Colab links available in the paper](https://arxiv.org/abs/2501.12596)
"""
st.markdown(_INTRO_MARKDOWN)
|
|
|
|
|
def _encode_reference_set(images):
    """Encode reference image tensors into L2-normalized CLIP feature rows.

    Returns ``None`` for an empty list, because ``torch.stack`` rejects empty
    input. Intended to be called inside a ``torch.no_grad()`` block.
    """
    if not images:
        return None
    features = torch.stack(
        [model.encode_image(img.unsqueeze(0)).squeeze(0).to(device) for img in images]
    )
    return features / features.norm(dim=-1, keepdim=True)


def _best_reference_match(test_features, reference_features):
    """Return ``(similarity, index)`` of the reference most similar to the test image.

    Both inputs are unit-normalized, so the dot product is the cosine
    similarity. With no references the result is ``(-inf, -1)``.
    """
    if reference_features is None:
        return -float('inf'), -1
    similarities = reference_features @ test_features  # one score per reference row
    best_idx = int(torch.argmax(similarities))  # first index wins on ties
    return similarities[best_idx].item(), best_idx


def few_shot_fault_classification(
    test_images: List[torch.Tensor],
    test_image_filenames: List[str],
    nominal_images: List[torch.Tensor],
    nominal_descriptions: List[str],
    defective_images: List[torch.Tensor],
    defective_descriptions: List[str],
    num_few_shot_nominal_imgs: int,
    file_path: str = '.',
    file_name: str = 'image_classification_results.csv',
    print_one_liner: bool = False
):
    """Classify test images as Nominal or Defective and append the results to a CSV.

    Every image is encoded with the module-level CLIP ``model`` and
    unit-normalized. For each test image the highest cosine similarity to any
    nominal reference and to any defective reference is found; a softmax over
    those two maxima gives the reported probabilities, and the larger one
    decides the label (ties go to "Nominal").

    Args:
        test_images: Image tensors to classify. Each item has ``.unsqueeze(0)``
            called on it before encoding, so it must be a tensor rather than a
            PIL image (callers in this file pass the output of ``preprocess``).
        test_image_filenames: One name/path per test image, in the same order.
        nominal_images: Reference tensors for the nominal class; must not be
            empty.
        nominal_descriptions: One description per nominal reference.
        defective_images: Reference tensors for the defective class. May be
            empty, in which case every test image is labelled "Nominal" and
            the defective description is recorded as "N/A".
        defective_descriptions: One description per defective reference.
        num_few_shot_nominal_imgs: Recorded verbatim in each CSV row; not used
            in the computation.
        file_path: Directory of the results CSV.
        file_name: File name of the results CSV.
        print_one_liner: When true, print a one-line summary per test image.

    Returns:
        An empty string (kept for compatibility with existing callers).

    Raises:
        ValueError: If ``nominal_images`` is empty.

    Any argument that expects a list may also be given as a single bare item.
    Rows are appended to an existing CSV; the header is written only when the
    file is new or empty.
    """
    import csv  # function-scope import: `csv` is not among the file's top-level imports

    # Accept a single bare item wherever a list is expected.
    if not isinstance(test_images, list):
        test_images = [test_images]
    if not isinstance(test_image_filenames, list):
        test_image_filenames = [test_image_filenames]
    if not isinstance(nominal_images, list):
        nominal_images = [nominal_images]
    if not isinstance(nominal_descriptions, list):
        nominal_descriptions = [nominal_descriptions]
    if not isinstance(defective_images, list):
        defective_images = [defective_images]
    if not isinstance(defective_descriptions, list):
        defective_descriptions = [defective_descriptions]

    if not nominal_images:
        raise ValueError("At least one nominal reference image is required.")

    csv_file = os.path.join(file_path, file_name)
    csv_data = []

    with torch.no_grad():
        nominal_features = _encode_reference_set(nominal_images)
        defective_features = _encode_reference_set(defective_images)

        for idx, test_img in enumerate(test_images):
            test_features = model.encode_image(test_img.unsqueeze(0)).squeeze(0).to(device)
            test_features = test_features / test_features.norm(dim=-1, keepdim=True)

            max_nom_sim, max_nom_idx = _best_reference_match(test_features, nominal_features)
            max_def_sim, max_def_idx = _best_reference_match(test_features, defective_features)

            # Softmax over the two best similarities; a -inf entry (no
            # defective references) contributes probability 0.
            similarities = torch.tensor([max_nom_sim, max_def_sim])
            prob_nom, prob_def = F.softmax(similarities, dim=0).tolist()

            classification = "Defective" if prob_def > prob_nom else "Nominal"
            image_path = test_image_filenames[idx]

            csv_data.append({
                "datetime_of_operation": datetime.now().isoformat(),
                "num_few_shot_nominal_imgs": num_few_shot_nominal_imgs,
                "image_path": image_path,
                "image_name": os.path.basename(image_path),
                "classification_result": classification,
                "non_defect_prob": round(prob_nom, 3),
                "defect_prob": round(prob_def, 3),
                "nominal_description": nominal_descriptions[max_nom_idx],
                "defective_description": (
                    defective_descriptions[max_def_idx] if max_def_idx >= 0 else "N/A"
                ),
            })

            if print_one_liner:
                print(f"{image_path} classified as {classification} "
                      f"(Nominal Prob: {prob_nom:.3f}, Defective Prob: {prob_def:.3f})")

    fieldnames = [
        "datetime_of_operation", "num_few_shot_nominal_imgs", "image_path", "image_name",
        "classification_result", "non_defect_prob", "defect_prob",
        "nominal_description", "defective_description",
    ]

    # Append mode creates the file when it is missing. The header is needed
    # for a new file and also for an existing but empty one.
    needs_header = not os.path.isfile(csv_file) or os.path.getsize(csv_file) == 0
    # UTF-8 so file names with non-ASCII characters round-trip through
    # pandas.read_csv, whose default encoding is UTF-8.
    with open(csv_file, mode='a', newline='', encoding='utf-8') as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerows(csv_data)

    return ""
|
|
|
|
|
# Seed per-session state so later code can read every key unconditionally.
# The description lists are seeded together with the image lists: the
# classification tab reads them even when a reference set was never uploaded,
# and an unseeded key would raise there.
for _state_key in (
    "nominal_images",
    "nominal_descriptions",
    "defective_images",
    "defective_descriptions",
    "test_images",
    "results",
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = []

# The three top-level tabs of the app, in display order.
tab1, tab2, tab3 = st.tabs(["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"])
|
|
|
|
|
def _preprocess_uploads(uploads):
    """Open each upload as RGB, apply ``preprocess``, and move the tensor to ``device``."""
    tensors = []
    for upload in uploads:
        rgb_image = Image.open(upload).convert("RGB")
        tensors.append(preprocess(rgb_image).to(device))
    return tensors


# Tab 1: collect the nominal and defective reference sets.
# A set in session state is overwritten only while its uploader holds files,
# so an empty uploader leaves the previously stored set untouched.
with tab1:
    st.header("Upload Reference Images")

    nominal_files = st.file_uploader(
        "Upload Nominal Images",
        accept_multiple_files=True,
        type=['png', 'jpg', 'jpeg'],
    )
    defective_files = st.file_uploader(
        "Upload Defective Images",
        accept_multiple_files=True,
        type=['png', 'jpg', 'jpeg'],
    )

    if nominal_files:
        st.session_state.nominal_images = _preprocess_uploads(nominal_files)
        # The uploaded file name doubles as the reference's description.
        st.session_state.nominal_descriptions = [upload.name for upload in nominal_files]
        st.success(f"Uploaded {len(nominal_files)} nominal images.")

    if defective_files:
        st.session_state.defective_images = _preprocess_uploads(defective_files)
        st.session_state.defective_descriptions = [upload.name for upload in defective_files]
        st.success(f"Uploaded {len(defective_files)} defective images.")
|
|
|
|
|
# Tab 2: upload test images and run the few-shot classifier on them.
with tab2:
    st.header("Upload Test Image(s)")
    test_files = st.file_uploader("Upload Test Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])

    if st.button("🔍 Run Classification"):
        # Read with .get so a key that was never written yields an empty list
        # instead of raising.
        nominal_refs = st.session_state.get("nominal_images", [])
        nominal_names = st.session_state.get("nominal_descriptions", [])
        defective_refs = st.session_state.get("defective_images", [])
        defective_names = st.session_state.get("defective_descriptions", [])

        if not test_files:
            # Previously a click without test files did nothing at all.
            st.warning("Please upload at least one test image before running classification.")
        elif not nominal_refs or not defective_refs:
            # Both reference classes are needed to compare against.
            st.warning(
                "Please upload at least one nominal and one defective reference image "
                "in the first tab before running classification."
            )
        else:
            test_images = [preprocess(Image.open(file).convert("RGB")).to(device) for file in test_files]
            test_filenames = [file.name for file in test_files]

            few_shot_fault_classification(
                test_images=test_images,
                test_image_filenames=test_filenames,
                nominal_images=nominal_refs,
                nominal_descriptions=nominal_names,
                defective_images=defective_refs,
                defective_descriptions=defective_names,
                num_few_shot_nominal_imgs=len(nominal_refs),
                file_path=".",
                file_name="streamlit_results.csv",
                print_one_liner=False
            )

            st.success("Classification complete!")
            # Record where the results were written.
            st.session_state.results = "streamlit_results.csv"
|
|
|
|
|
# Tab 3: show the results CSV, if one exists, and offer it for download.
# The file is looked up by its fixed name in the working directory.
with tab3:
    st.header("Classification Results")

    if not os.path.exists("streamlit_results.csv"):
        st.info("No results yet. Please classify some test images.")
    else:
        results_df = pd.read_csv("streamlit_results.csv")
        st.dataframe(results_df)

        # Re-serialize without the index column for the download payload.
        csv_payload = results_df.to_csv(index=False)
        st.download_button(
            "📥 Download Results",
            data=csv_payload,
            file_name="classification_results.csv",
            mime="text/csv",
        )
|
|