"""Gradio demo: predict UI click coordinates from a screenshot and a text prompt.

NOTE(review): the model-specific parts below (input resolution, passing
``pixel_values`` alongside the tokenized prompt, and how coordinates are read
out of the model output) are placeholders carried over from the original demo.
Confirm them against the neulab/UIX-Qwen2 model card before trusting results.
"""

from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
import torch
from PIL import Image
import numpy as np
import gradio as gr

# Load the model and tokenizer once at import time so every request reuses them.
tokenizer = AutoTokenizer.from_pretrained("neulab/UIX-Qwen2")
model = AutoModel.from_pretrained("neulab/UIX-Qwen2")

# Placeholder input resolution (width, height); adjust to what the model expects.
DEFAULT_IMAGE_SIZE = (224, 224)


def preprocess_image(image, size=DEFAULT_IMAGE_SIZE):
    """Turn a PIL image into a normalized ``(1, 3, H, W)`` float32 tensor.

    Args:
        image: PIL image, as delivered by ``gr.Image(type="pil")``.
        size: ``(width, height)`` passed to ``Image.resize``.

    Returns:
        A float32 tensor with values in ``[0, 1]``, channels first, with a
        leading batch dimension of size 1.
    """
    # Uploaded screenshots are often RGBA (PNG), palette or grayscale images;
    # force three channels so the HWC -> CHW permute below always gets 3-D input.
    image = image.convert("RGB").resize(size)
    array = np.asarray(image, dtype=np.float32) / 255.0  # HWC, scaled to [0, 1]
    return torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0)


def _logits_to_xy(logits):
    """Read an ``(x, y)`` pair from ``logits`` via arg-max over the last axis.

    Placeholder decoding kept from the original demo. The arg-max result is
    flattened, so size-1 dimensions such as the batch dimension do not matter.
    It must then contain exactly two values, which are returned as ``(x, y)``.

    Raises:
        ValueError: if the arg-max does not yield exactly two values.
    """
    indices = torch.argmax(logits, dim=-1).flatten().tolist()
    if len(indices) != 2:
        raise ValueError(
            f"Cannot read (x, y) from logits of shape {tuple(logits.shape)}: "
            f"arg-max over the last axis gave {len(indices)} values, expected 2."
        )
    x, y = indices
    return x, y


def predict_coordinates(screenshot, prompt):
    """Predict the (x, y) location in ``screenshot`` that ``prompt`` refers to.

    Args:
        screenshot: PIL image from the upload widget, or ``None`` if nothing
            was uploaded.
        prompt: instruction text such as ``"Click on Submit button"``.

    Returns:
        ``{"x": int, "y": int}`` for the ``gr.JSON`` output component.

    Raises:
        gr.Error: on missing input, or when the model output cannot be turned
            into coordinates. Gradio shows the message in the UI.
    """
    if screenshot is None:
        raise gr.Error("Please upload a screenshot.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    image_tensor = preprocess_image(screenshot)
    inputs = tokenizer(prompt, return_tensors="pt")

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        # NOTE(review): assumes the model accepts the tokenized text plus
        # ``pixel_values`` in a single forward call. Confirm for UIX-Qwen2.
        outputs = model(**inputs, pixel_values=image_tensor)

    # ``AutoModel`` loads a base model whose output may not carry ``logits``.
    # Report that clearly instead of surfacing an AttributeError.
    logits = getattr(outputs, "logits", None)
    if logits is None:
        raise gr.Error(
            "Model output has no 'logits'; the coordinate decoding in this "
            "demo does not match the loaded model."
        )

    try:
        x, y = _logits_to_xy(logits)
    except ValueError as err:
        raise gr.Error(str(err)) from err
    return {"x": x, "y": y}


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# UIX-Qwen2: Predict Coordinates for UI Interactions")
    with gr.Row():
        with gr.Column():
            screenshot = gr.Image(type="pil", label="Upload Screenshot")
            prompt = gr.Textbox(label="Prompt (e.g., 'Click on Submit button')")
        with gr.Column():
            output = gr.JSON(label="Predicted Coordinates (x, y)")
            submit_button = gr.Button("Get Coordinates")
    # Event listeners must be registered inside the Blocks context.
    submit_button.click(predict_coordinates, inputs=[screenshot, prompt], outputs=output)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module (e.g. from
    # tooling that picks up ``demo``) does not start a server.
    demo.launch()