import gradio as gr
import base64
import json
import os
import shutil
import uuid
import glob
from huggingface_hub import CommitScheduler, HfApi, snapshot_download
from pathlib import Path
import git
from datasets import Dataset, Features, Value, Sequence, Image as ImageFeature
import threading
import time
from utils import process_and_push_dataset
from datasets import load_dataset

api = HfApi(token=os.environ["HF_TOKEN"])

VALID_DATASET = load_dataset("taesiri/IERv2-Subset-Validation-150", split="train")
VALID_DATASET_POST_IDS = (
    load_dataset(
        "taesiri/IERv2-Subset-Validation-150", split="train", columns=["post_id"]
    )
    .to_pandas()["post_id"]
    .tolist()
)
POST_ID_TO_ID_MAP = {post_id: idx for idx, post_id in enumerate(VALID_DATASET_POST_IDS)}

DATASET_REPO = "taesiri/AIImageEditingResults_Intemediate2"
FINAL_DATASET_REPO = "taesiri/AIImageEditingResults"


# Download existing data from hub
def sync_with_hub():
    """
    Synchronize local data with the hub by cloning the dataset repo
    """
    print("Starting sync with hub...")
    data_dir = Path("./data")
    if data_dir.exists():
        # Backup existing data
        backup_dir = Path("./data_backup")
        if backup_dir.exists():
            shutil.rmtree(backup_dir)
        shutil.copytree(data_dir, backup_dir)

    # Clone/pull latest data from hub
    # Use token in the URL for authentication following HF's new format
    token = os.environ["HF_TOKEN"]
    username = "taesiri"  # Extracted from DATASET_REPO
    repo_url = f"https://{username}:{token}@huggingface.co/datasets/{DATASET_REPO}"
    hub_data_dir = Path("hub_data")

    if hub_data_dir.exists():
        # If repo exists, do a git pull
        print("Pulling latest changes...")
        repo = git.Repo(hub_data_dir)
        origin = repo.remotes.origin
        # Set the new URL with token
        if "https://" in origin.url:
            origin.set_url(repo_url)
        origin.pull()
    else:
        # Clone the repo with token
        print("Cloning repository...")
        git.Repo.clone_from(repo_url, hub_data_dir)

    # Merge hub data with local data
    hub_data_source = hub_data_dir / "data"
    if hub_data_source.exists():
        # Create data dir if it doesn't exist
        data_dir.mkdir(exist_ok=True)

        # Copy files from hub
        for item in hub_data_source.glob("*"):
            if item.is_dir():
                dest = data_dir / item.name
                if not dest.exists():  # Only copy if it doesn't exist locally
                    shutil.copytree(item, dest)

    # Clean up cloned repo
    if hub_data_dir.exists():
        shutil.rmtree(hub_data_dir)
    print("Finished syncing with hub!")


scheduler = CommitScheduler(
    repo_id=DATASET_REPO,
    repo_type="dataset",
    folder_path="./data",
    path_in_repo="data",
    every=1,
)


def load_question_data(question_id):
    """
    Load a specific question's data
    Returns a tuple of all form fields
    """
    if not question_id:
        return [None] * 11  # Reduced number of fields

    # Extract the ID part before the colon from the dropdown selection
    question_id = (
        question_id.split(":")[0].strip() if ":" in question_id else question_id
    )

    json_path = os.path.join("./data", question_id, "question.json")
    if not os.path.exists(json_path):
        print(f"Question file not found: {json_path}")
        return [None] * 11

    try:
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.loads(f.read().strip())

        # Load images
        def load_image(image_path):
            if not image_path:
                return None
            full_path = os.path.join(
                "./data", question_id, os.path.basename(image_path)
            )
            return full_path if os.path.exists(full_path) else None

        question_images = data.get("question_images", [])
        rationale_images = data.get("rationale_images", [])

        return [
            (
                ",".join(data["question_categories"])
                if isinstance(data["question_categories"], list)
                else data["question_categories"]
            ),
            data["question"],
            data["final_answer"],
data.get("rationale_text", ""), load_image(question_images[0] if question_images else None), load_image(question_images[1] if len(question_images) > 1 else None), load_image(question_images[2] if len(question_images) > 2 else None), load_image(question_images[3] if len(question_images) > 3 else None), load_image(rationale_images[0] if rationale_images else None), load_image(rationale_images[1] if len(rationale_images) > 1 else None), question_id, ] except Exception as e: print(f"Error loading question {question_id}: {str(e)}") return [None] * 11 def load_post_image(post_id): if not post_id: return [ None ] * 33 # source image + instruction + simplified_instruction + 10 triplets idx = POST_ID_TO_ID_MAP[post_id] source_image = VALID_DATASET[idx]["image"] instruction = VALID_DATASET[idx]["instruction"] simplified_instruction = VALID_DATASET[idx]["simplified_instruction"] # Load existing responses if any post_folder = os.path.join("./data", str(post_id)) metadata_path = os.path.join(post_folder, "metadata.json") if os.path.exists(metadata_path): with open(metadata_path, "r") as f: metadata = json.load(f) # Initialize response data responses = [(None, "", "")] * 10 # Initialize with empty notes # Fill in existing responses for response in metadata["responses"]: idx = response["response_id"] if idx < 10: # Ensure we don't exceed our UI limit image_path = os.path.join(post_folder, response["image_path"]) responses[idx] = ( image_path, response["answer_text"], response.get("notes", ""), ) # Flatten responses for output flat_responses = [item for triplet in responses for item in triplet] return [source_image, instruction, simplified_instruction] + flat_responses # If no existing responses, return source image, instructions and empty responses return [source_image, instruction, simplified_instruction] + [None] * 30 def generate_json_files(source_image, responses, post_id): """ Save the source image and multiple responses to the data directory Args: source_image: Path to the source image responses: List of (image, answer, notes) tuples post_id: The post ID from the dataset """ # Create parent data folder if it doesn't exist parent_data_folder = "./data" os.makedirs(parent_data_folder, exist_ok=True) # Create/clear post_id folder post_folder = os.path.join(parent_data_folder, str(post_id)) if os.path.exists(post_folder): shutil.rmtree(post_folder) os.makedirs(post_folder) # Save source image source_image_path = os.path.join(post_folder, "source_image.png") if isinstance(source_image, str): shutil.copy2(source_image, source_image_path) else: gr.processing_utils.save_image(source_image, source_image_path) # Create responses data responses_data = [] for idx, (response_image, answer_text, notes) in enumerate(responses): if response_image and answer_text: # Only process if both image and text exist response_folder = os.path.join(post_folder, f"response_{idx}") os.makedirs(response_folder) # Save response image response_image_path = os.path.join(response_folder, "response_image.png") if isinstance(response_image, str): shutil.copy2(response_image, response_image_path) else: gr.processing_utils.save_image(response_image, response_image_path) # Add to responses data responses_data.append( { "response_id": idx, "answer_text": answer_text, "notes": notes, "image_path": f"response_{idx}/response_image.png", } ) # Create metadata JSON metadata = { "post_id": post_id, "source_image": "source_image.png", "responses": responses_data, } # Save metadata with open(os.path.join(post_folder, "metadata.json"), "w", 
encoding="utf-8") as f: json.dump(metadata, f, ensure_ascii=False, indent=2) return post_folder def get_statistics(): """ Scan the data folder and return statistics about the responses """ data_dir = Path("./data") if not data_dir.exists(): return "No data directory found" total_expected_posts = len(VALID_DATASET_POST_IDS) processed_post_ids = set() posts_with_responses = 0 total_responses = 0 responses_per_post = [] # List to track number of responses for each post for metadata_file in data_dir.glob("*/metadata.json"): post_id = metadata_file.parent.name if post_id in VALID_DATASET_POST_IDS: # Only count valid posts processed_post_ids.add(post_id) try: with open(metadata_file, "r") as f: metadata = json.load(f) num_responses = len(metadata.get("responses", [])) responses_per_post.append(num_responses) if num_responses > 0: posts_with_responses += 1 total_responses += num_responses except: continue missing_posts = set(map(str, VALID_DATASET_POST_IDS)) - processed_post_ids total_processed = len(processed_post_ids) # Calculate additional statistics if responses_per_post: responses_per_post.sort() median_responses = responses_per_post[len(responses_per_post) // 2] max_responses = max(responses_per_post) avg_responses = ( total_responses / posts_with_responses if posts_with_responses > 0 else 0 ) else: median_responses = max_responses = avg_responses = 0 stats = f""" 📊 Collection Statistics: Dataset Coverage: - Total Expected Posts: {total_expected_posts} - Posts Processed: {total_processed} - Missing Posts: {len(missing_posts)} ({', '.join(list(missing_posts)[:5])}{'...' if len(missing_posts) > 5 else ''}) - Coverage Rate: {(total_processed/total_expected_posts*100):.2f}% Response Statistics: - Posts with Responses: {posts_with_responses} - Posts without Responses: {total_processed - posts_with_responses} - Total Individual Responses: {total_responses} Response Distribution: - Median Responses per Post: {median_responses} - Average Responses per Post: {avg_responses:.2f} - Maximum Responses for a Post: {max_responses} """ return stats # Build the Gradio app with gr.Blocks() as demo: gr.Markdown("# Image Response Collector") # Source image selection at the top with gr.Row(): with gr.Column(): post_id_dropdown = gr.Dropdown( label="Select Post ID to Load Image", choices=VALID_DATASET_POST_IDS, type="value", allow_custom_value=False, ) instruction_text = gr.Textbox(label="Instruction", interactive=False) simplified_instruction_text = gr.Textbox( label="Simplified Instruction", interactive=False ) source_image = gr.Image(label="Source Image", type="filepath", height=300) # Responses in tabs with gr.Tabs() as response_tabs: responses = [] for i in range(10): with gr.Tab(f"Response {i+1}"): img = gr.Image( label=f"Response Image {i+1}", type="filepath", height=300 ) txt = gr.Textbox(label=f"Model Name {i+1}", lines=2) notes = gr.Textbox(label=f"Miscellaneous Notes {i+1}", lines=3) responses.append((img, txt, notes)) with gr.Row(): submit_btn = gr.Button("Submit All Responses") clear_btn = gr.Button("Clear Form") # Add statistics accordion with gr.Accordion("Collection Statistics", open=False): stats_text = gr.Markdown("Loading statistics...") refresh_stats_btn = gr.Button("Refresh Statistics") def update_stats(): return get_statistics() refresh_stats_btn.click(fn=update_stats, outputs=[stats_text]) # Move the load event inside the Blocks context demo.load( fn=get_statistics, outputs=[stats_text], ) def submit_responses( source_img, post_id, instruction, simplified_instruction, *response_data ): if 
        if not source_img:
            gr.Warning("Please select a source image first!")
            return

        if not post_id:
            gr.Warning("Please select a post ID first!")
            return

        # Convert flat response_data into triplets of (image, text, notes)
        response_triplets = list(
            zip(response_data[::3], response_data[1::3], response_data[2::3])
        )

        # Check for responses with images but no model names
        incomplete_responses = [
            i + 1
            for i, (img, txt, _) in enumerate(response_triplets)
            if img is not None and not txt.strip()
        ]
        if incomplete_responses:
            gr.Warning(
                f"Please provide model names for responses: {', '.join(map(str, incomplete_responses))}!"
            )
            return

        # Filter out empty responses (where both image and model name are empty)
        valid_responses = [
            (img, txt, notes)
            for img, txt, notes in response_triplets
            if img is not None and txt.strip()
        ]

        if not valid_responses:
            gr.Warning("Please provide at least one response (image + model name)!")
            return

        # Generate JSON files with the valid responses
        generate_json_files(source_img, valid_responses, post_id)
        gr.Info("Responses saved successfully! 🎉")

    def clear_form():
        outputs = [None] * (
            1 + 2 + 30
        )  # source image + 2 instruction fields + 10 triplets
        return outputs

    # Connect components
    post_id_dropdown.change(
        fn=load_post_image,
        inputs=[post_id_dropdown],
        outputs=[source_image, instruction_text, simplified_instruction_text]
        + [comp for triplet in responses for comp in triplet],
    )

    submit_inputs = [
        source_image,
        post_id_dropdown,
        instruction_text,
        simplified_instruction_text,
    ] + [comp for triplet in responses for comp in triplet]
    submit_btn.click(fn=submit_responses, inputs=submit_inputs)

    clear_outputs = [source_image, instruction_text, simplified_instruction_text] + [
        comp for triplet in responses for comp in triplet
    ]
    clear_btn.click(fn=clear_form, outputs=clear_outputs)


def process_thread():
    while True:
        try:
            pass
            # process_and_push_dataset(
            #     "./data",
            #     FINAL_DATASET_REPO,
            #     token=os.environ["HF_TOKEN"],
            #     private=True,
            # )
        except Exception as e:
            print(f"Error in process thread: {e}")
        time.sleep(120)  # Sleep for 2 minutes


if __name__ == "__main__":
    print("Initializing app...")
    sync_with_hub()  # Sync before launching the app
    print("Starting Gradio interface...")

    # Start the processing thread when the app starts
    processing_thread = threading.Thread(target=process_thread, daemon=True)
    processing_thread.start()

    demo.launch()