"""Streamlit app that classifies uploaded documents with a fine-tuned
LayoutLMv3 model and collects user corrections into a W&B table.

Streamlit re-executes this whole script on every widget interaction, so
state that must survive between interactions (pending feedback) lives in
``st.session_state``, and expensive work (model loading, PDF rasterisation,
inference) is cached.
"""
import hashlib
import io
import logging
import os

import streamlit as st
import torch
import wandb
from pdf2image import convert_from_bytes
from PIL import Image
from transformers import LayoutLMv3ForSequenceClassification, LayoutLMv3Processor

logger = logging.getLogger(__name__)

MODEL_DIR = "model/layoutlmv3/"
WANDB_PROJECT = "hydra-classifier"
WANDB_RUN_NAME = "feedback-loop"

labels = [
    'budget',
    'email',
    'form',
    'handwritten',
    'invoice',
    'language',
    'letter',
    'memo',
    'news article',
    'questionnaire',
    'resume',
    'scientific publication',
    'specification',
]
id2label = {i: label for i, label in enumerate(labels)}
label2id = {v: k for k, v in id2label.items()}

# Column order of the logged W&B table; feedback rows are stored as tuples
# in exactly this order (with the raw PIL image in the first slot).
FEEDBACK_COLUMNS = [
    'image',
    'filetype',
    'predicted_label',
    'predicted_label_id',
    'correct_label',
    'correct_label_id',
]


@st.cache_resource
def load_model_and_processor():
    """Load the LayoutLMv3 classifier and its processor from ``MODEL_DIR``.

    ``st.cache_resource`` loads them once per server process and shares them
    across sessions, instead of one copy per browser session.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_DIR)
    processor = LayoutLMv3Processor.from_pretrained(MODEL_DIR)
    return model, processor


@st.cache_data
def load_images(data: bytes, mime_type: str):
    """Decode uploaded file bytes into a list of RGB PIL images, one per page.

    Cached on the raw bytes so a PDF is not re-rasterised on every rerun.

    Args:
        data: Raw bytes of the uploaded file.
        mime_type: The upload's MIME type; ``"application/pdf"`` is
            rasterised with pdf2image, anything else is opened with PIL.

    Returns:
        A list of ``PIL.Image.Image`` objects in RGB mode.
    """
    if mime_type == "application/pdf":
        pages = convert_from_bytes(data)
    else:
        pages = [Image.open(io.BytesIO(data))]
    # Uploaded PNG/JPEGs may be RGBA, palette or greyscale; the LayoutLMv3
    # image processor normalises with 3-channel statistics, so force RGB.
    return [page.convert("RGB") for page in pages]


@st.cache_data
def classify_image(_image, image_key):
    """Return the predicted label id (an index into ``labels``) for ``_image``.

    Streamlit does not hash underscore-prefixed arguments, so ``image_key``
    (a content-derived string) is what makes each cache entry unique per
    image. Without it, every image would get the first cached prediction.

    Args:
        _image: RGB PIL image to classify (excluded from the cache key).
        image_key: Hashable key that uniquely identifies ``_image``.
    """
    model, processor = load_model_and_processor()
    logger.info("Classifying image %s", image_key)
    encoding = processor(
        _image,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(**encoding).logits
    return logits.argmax(-1)[0].item()


def submit_feedback(api_key, rows):
    """Log ``rows`` to W&B as one ``feedback_table`` in a fresh run.

    A new run is opened and finished per submission, so repeated submits
    never log into an already-finished run.

    Args:
        api_key: W&B API key used for ``wandb.login``.
        rows: Iterable of tuples ordered like ``FEEDBACK_COLUMNS``, with a
            PIL image in the first position.
    """
    table = wandb.Table(columns=FEEDBACK_COLUMNS)
    for image, *fields in rows:
        table.add_data(wandb.Image(image), *fields)
    wandb.login(key=api_key)
    # The run context manager finishes the run when the block exits.
    with wandb.init(project=WANDB_PROJECT, name=WANDB_RUN_NAME) as run:
        run.log({'feedback_table': table})


wandb_api_key = os.getenv("WANDB_API_KEY")
if not wandb_api_key:
    st.error(
        "Couldn't find WandB API key. Please set it up as an environment variable",
        icon="🚨",
    )

if 'feedback' not in st.session_state:
    # (file key, page index) -> feedback row. Kept in session_state because
    # the script reruns on every interaction; keying by page means adding
    # feedback for the same page again overwrites instead of duplicating.
    st.session_state.feedback = {}

st.title("Document Classification with LayoutLMv3")

uploaded_file = st.file_uploader(
    "Upload Document", type=["pdf", "jpg", "png"], accept_multiple_files=False
)

if uploaded_file:
    file_bytes = uploaded_file.getvalue()
    # Content hash: distinguishes uploads for the inference cache and keeps
    # widget state from leaking between different files.
    file_key = hashlib.sha256(file_bytes).hexdigest()
    images = load_images(file_bytes, uploaded_file.type)

    for i, image in enumerate(images):
        st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)
        prediction = classify_image(image, f"{file_key}-{i}")
        st.write(f"Prediction: {id2label[prediction]}")

        feedback = st.radio(
            "Is the classification correct?",
            ("Yes", "No"),
            key=f'prediction-{file_key}-{i}',
        )
        if feedback == "No":
            correct_label = st.selectbox(
                "Please select the correct label:",
                labels,
                key=f'selectbox-{file_key}-{i}',
            )
            if st.button(f"Add feedback for Image {i}", key=f'add-{file_key}-{i}'):
                st.session_state.feedback[(file_key, i)] = (
                    image,
                    uploaded_file.type,
                    id2label[prediction],
                    prediction,
                    correct_label,
                    label2id[correct_label],
                )
                logger.info("Correct label for image %d: %s", i, correct_label)

    # Submitting needs W&B credentials; without them the button is disabled
    # (the error above already tells the user why).
    if st.button("Submit all feedback", key='submit', disabled=not wandb_api_key):
        pending_rows = list(st.session_state.feedback.values())
        if not pending_rows:
            st.warning("No feedback to submit yet.")
        else:
            submit_feedback(wandb_api_key, pending_rows)
            st.session_state.feedback.clear()
            st.success("Feedback submitted!")