"""Streamlit demo: few-shot fault detection using CLIP image embeddings.

Users upload nominal/defective reference images and test images.  Each test
image is labelled by comparing its CLIP embedding with the closest nominal
and the closest defective reference embedding.  Results are appended to a
CSV file that can be viewed and downloaded from the app.
"""

import csv
import os
from datetime import datetime
from typing import List

import clip
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from PIL import Image

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# CSV file the app appends classification results to (relative to the CWD).
RESULTS_CSV = "streamlit_results.csv"
# File extensions accepted by every uploader in the app.
IMAGE_TYPES = ['png', 'jpg', 'jpeg']

# Must run before any other Streamlit command, including the spinner that a
# cached resource may display on its first call.
st.set_page_config(page_title="Few-Shot Fault Detection", layout="wide")


@st.cache_resource
def load_clip_model(device_name: str):
    """Load the CLIP ViT-B/32 model and its preprocessing transform once.

    Streamlit re-executes this script on every user interaction; caching the
    loaded model avoids reloading it on each rerun.

    Returns:
        A ``(model, preprocess)`` tuple as returned by ``clip.load``, with
        the model switched to eval mode.
    """
    clip_model, clip_preprocess = clip.load("ViT-B/32", device=device_name)
    clip_model.eval()
    return clip_model, clip_preprocess


# Load CLIP model and preprocessor (ViT-B/32 = small model, CPU-friendly)
model, preprocess = load_clip_model(device)

# Display app title and information
st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
st.markdown("""
This demo uses the **smaller `ViT-B/32` encoder** from OpenAI's CLIP model to classify test images as **Nominal** or **Defective**, based on few-shot learning using user-provided reference images.

⚠️ **Note**: This app is running on a **free CPU tier** and is meant for demonstration purposes.

For more advanced use cases, including GPU acceleration, custom training, and larger models, please refer to:

📄 [Megahed et al. (2025)](https://arxiv.org/abs/2501.12596): *Adapting OpenAI's CLIP Model for Few-Shot Image Inspection in Manufacturing Quality Control: An Expository Case Study with Multiple Application Examples*

🔗 [GitHub & Colab links available in the paper](https://arxiv.org/abs/2501.12596)
""")


# --- Few-shot classification logic ---
def _encode_normalized(images: List[torch.Tensor]) -> torch.Tensor:
    """Encode preprocessed image tensors with CLIP and L2-normalise each row.

    Must be called with gradients disabled by the caller.
    """
    batch = torch.stack(images).to(device)
    features = model.encode_image(batch)
    return features / features.norm(dim=-1, keepdim=True)


def _best_match(query: torch.Tensor, references: torch.Tensor):
    """Return ``(similarity, index)`` of the reference row closest to *query*.

    Similarity is the dot product; ties resolve to the lowest index.
    """
    sims = (references @ query).tolist()
    best_idx = max(range(len(sims)), key=sims.__getitem__)
    return sims[best_idx], best_idx


def few_shot_fault_classification(
    test_images: List[torch.Tensor],
    test_image_filenames: List[str],
    nominal_images: List[torch.Tensor],
    nominal_descriptions: List[str],
    defective_images: List[torch.Tensor],
    defective_descriptions: List[str],
    num_few_shot_nominal_imgs: int,
    file_path: str = '.',
    file_name: str = 'image_classification_results.csv',
    print_one_liner: bool = False
):
    """Classify test images as "Nominal" or "Defective" and log to a CSV.

    Each test image is embedded with the module-level CLIP ``model``.  Its
    highest similarity to any nominal reference and its highest similarity
    to any defective reference are passed through a two-way softmax; the
    larger probability decides the label ("Nominal" on a tie).  One row per
    test image is appended to ``file_path/file_name``; a header row is
    written first when the file is missing or empty.

    Args:
        test_images: Preprocessed image tensors to classify (each item is
            stacked into a batch and fed to ``model.encode_image``).
        test_image_filenames: Name/path for each test image, index-aligned
            with ``test_images``.
        nominal_images: Preprocessed tensors of nominal reference images.
            At least one is required.
        nominal_descriptions: Description per nominal reference image.
        defective_images: Preprocessed tensors of defective reference
            images.  May be empty, in which case the defective probability
            is 0 and the reported defective description is "N/A".
        defective_descriptions: Description per defective reference image.
        num_few_shot_nominal_imgs: Value recorded verbatim in the CSV.
        file_path: Directory of the results CSV.
        file_name: File name of the results CSV.
        print_one_liner: If true, print a one-line summary per test image.

    A non-list value for any of the six list arguments is wrapped in a
    single-element list.

    Returns:
        An empty string (kept for backward compatibility).

    Raises:
        ValueError: If no nominal images are given, or there are fewer
            test filenames than test images.
    """
    if not isinstance(test_images, list):
        test_images = [test_images]
    if not isinstance(test_image_filenames, list):
        test_image_filenames = [test_image_filenames]
    if not isinstance(nominal_images, list):
        nominal_images = [nominal_images]
    if not isinstance(nominal_descriptions, list):
        nominal_descriptions = [nominal_descriptions]
    if not isinstance(defective_images, list):
        defective_images = [defective_images]
    if not isinstance(defective_descriptions, list):
        defective_descriptions = [defective_descriptions]

    if not nominal_images:
        raise ValueError("At least one nominal reference image is required.")
    if len(test_image_filenames) < len(test_images):
        raise ValueError(
            "test_image_filenames must provide a name for every test image."
        )

    csv_file = os.path.join(file_path, file_name)
    csv_data = []

    with torch.no_grad():
        nominal_features = _encode_normalized(nominal_images)
        # torch.stack rejects an empty list, so only encode when present.
        defective_features = (
            _encode_normalized(defective_images) if defective_images else None
        )

        for idx, test_img in enumerate(test_images):
            test_features = _encode_normalized([test_img])[0]

            max_nom_sim, max_nom_idx = _best_match(
                test_features, nominal_features
            )
            if defective_features is not None:
                max_def_sim, max_def_idx = _best_match(
                    test_features, defective_features
                )
            else:
                # -inf similarity becomes probability 0 after the softmax.
                max_def_sim, max_def_idx = -float('inf'), -1

            # NOTE: the similarities are dot products of unit vectors and
            # are fed to softmax without any temperature scaling.
            similarities = torch.tensor([max_nom_sim, max_def_sim])
            prob_nom, prob_def = F.softmax(similarities, dim=0).tolist()
            classification = "Defective" if prob_def > prob_nom else "Nominal"

            csv_data.append({
                "datetime_of_operation": datetime.now().isoformat(),
                "num_few_shot_nominal_imgs": num_few_shot_nominal_imgs,
                "image_path": test_image_filenames[idx],
                "image_name": os.path.basename(test_image_filenames[idx]),
                "classification_result": classification,
                "non_defect_prob": round(prob_nom, 3),
                "defect_prob": round(prob_def, 3),
                "nominal_description": nominal_descriptions[max_nom_idx],
                "defective_description": (
                    defective_descriptions[max_def_idx]
                    if defective_images else "N/A"
                ),
            })

            if print_one_liner:
                print(f"{test_image_filenames[idx]} classified as {classification} "
                      f"(Nominal Prob: {prob_nom:.3f}, Defective Prob: {prob_def:.3f})")

    fieldnames = [
        "datetime_of_operation", "num_few_shot_nominal_imgs", "image_path",
        "image_name", "classification_result", "non_defect_prob",
        "defect_prob", "nominal_description", "defective_description",
    ]
    # A missing or zero-length file still needs its header row.
    needs_header = (
        not os.path.isfile(csv_file) or os.path.getsize(csv_file) == 0
    )
    with open(csv_file, mode='a', newline='', encoding='utf-8') as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerows(csv_data)

    return ""


def _load_uploads(files) -> List[torch.Tensor]:
    """Open uploaded files as RGB images and return preprocessed tensors."""
    return [
        preprocess(Image.open(file).convert("RGB")).to(device)
        for file in files
    ]


# --- App state ---
# The description lists are initialised too, so running a classification
# before any reference upload cannot hit a missing session-state attribute.
for _state_key in (
    'nominal_images', 'nominal_descriptions',
    'defective_images', 'defective_descriptions',
    'test_images', 'results',
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = []

# --- Tabs ---
tab1, tab2, tab3 = st.tabs(["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"])

# Tab 1: Upload Reference Images
with tab1:
    st.header("Upload Reference Images")
    nominal_files = st.file_uploader("Upload Nominal Images", accept_multiple_files=True, type=IMAGE_TYPES)
    defective_files = st.file_uploader("Upload Defective Images", accept_multiple_files=True, type=IMAGE_TYPES)

    if nominal_files:
        st.session_state.nominal_images = _load_uploads(nominal_files)
        st.session_state.nominal_descriptions = [file.name for file in nominal_files]
        st.success(f"Uploaded {len(nominal_files)} nominal images.")

    if defective_files:
        st.session_state.defective_images = _load_uploads(defective_files)
        st.session_state.defective_descriptions = [file.name for file in defective_files]
        st.success(f"Uploaded {len(defective_files)} defective images.")

# Tab 2: Test Classification
with tab2:
    st.header("Upload Test Image(s)")
    test_files = st.file_uploader("Upload Test Images", accept_multiple_files=True, type=IMAGE_TYPES)

    if st.button("🔍 Run Classification"):
        if not test_files:
            st.warning("Please upload at least one test image.")
        elif not st.session_state.nominal_images:
            st.warning("Please upload at least one nominal reference image first.")
        else:
            test_images = _load_uploads(test_files)
            test_filenames = [file.name for file in test_files]

            few_shot_fault_classification(
                test_images=test_images,
                test_image_filenames=test_filenames,
                nominal_images=st.session_state.nominal_images,
                nominal_descriptions=st.session_state.nominal_descriptions,
                defective_images=st.session_state.defective_images,
                defective_descriptions=st.session_state.defective_descriptions,
                num_few_shot_nominal_imgs=len(st.session_state.nominal_images),
                file_path=".",
                file_name=RESULTS_CSV,
                print_one_liner=False
            )

            st.success("Classification complete!")
            st.session_state.results = RESULTS_CSV

# Tab 3: View/Download Results
with tab3:
    st.header("Classification Results")
    if os.path.exists(RESULTS_CSV):
        df = pd.read_csv(RESULTS_CSV)
        st.dataframe(df)
        st.download_button("📥 Download Results", data=df.to_csv(index=False), file_name="classification_results.csv", mime="text/csv")
    else:
        st.info("No results yet. Please classify some test images.")