"""WebSight demo: generate HTML/CSS code from an image of a website component.

Loads the ``HuggingFaceM4/VLM_WebSight_finetuned`` vision-language model and
serves it through a Gradio interface.
"""

import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

MODEL_ID = "HuggingFaceM4/VLM_WebSight_finetuned"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Read the Hugging Face token from the environment instead of hard-coding it:
# a token committed to source code is effectively public. The token that was
# previously embedded here should be revoked. ``None`` lets the Hub client
# fall back to any locally cached login.
API_TOKEN = os.environ.get("HF_TOKEN")

PROCESSOR = AutoProcessor.from_pretrained(MODEL_ID, token=API_TOKEN)
MODEL = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=API_TOKEN,
    trust_remote_code=True,
).to(DEVICE)

# Number of <image> placeholder tokens to put in the prompt; presumably the
# number of latents the perceiver resampler produces per image (per the config
# field name) — confirm against the model's remote code.
image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
BOS_TOKEN = PROCESSOR.tokenizer.bos_token

# Placeholder tokens of the model's prompt format (see the model card).
IMAGE_TOKEN = "<image>"
FAKE_IMAGE_TOKEN = "<fake_token_around_image>"

# Prompt = BOS followed by a single image slot: image_seq_len <image> tokens
# wrapped in <fake_token_around_image> markers. The pixel values are supplied
# separately and fill this slot.
PROMPT = (
    f"{BOS_TOKEN}{FAKE_IMAGE_TOKEN}"
    f"{IMAGE_TOKEN * image_seq_len}"
    f"{FAKE_IMAGE_TOKEN}"
)

# Token-id sequences that generation is forbidden to emit: the image
# placeholders have no meaning in the generated HTML/CSS.
BAD_WORDS_IDS = PROCESSOR.tokenizer(
    [IMAGE_TOKEN, FAKE_IMAGE_TOKEN], add_special_tokens=False
).input_ids

# The processor stores mean/std as plain Python values (a list cannot be
# indexed with [:, None, None]); convert once and reshape to (C, 1, 1) so they
# broadcast over a (C, H, W) tensor. Assumes one value per RGB channel (or a
# single scalar, which broadcasts as well).
_IMAGE_MEAN = torch.tensor(
    PROCESSOR.image_processor.image_mean, dtype=torch.float32
).view(-1, 1, 1)
_IMAGE_STD = torch.tensor(
    PROCESSOR.image_processor.image_std, dtype=torch.float32
).view(-1, 1, 1)

# Target (width, height) the image is resized to before normalisation.
IMAGE_SIZE = (960, 960)


def convert_to_rgb(image):
    """Return *image* in RGB mode, compositing any transparency onto white.

    A plain ``image.convert("RGB")`` would discard the alpha channel and expose
    whatever colour values sit under transparent pixels. Compositing onto an
    opaque white background makes transparent regions white instead.
    """
    if image.mode == "RGB":
        return image
    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    return Image.alpha_composite(background, image_rgba).convert("RGB")


def custom_transform(x):
    """Turn a PIL image into the normalised pixel tensor fed to the model.

    The image is converted to RGB and resized to 960x960 with bilinear
    interpolation. It is then scaled to [0, 1] and normalised with the
    processor's mean/std.

    Returns:
        A float32 tensor of shape (1, 3, 960, 960).
    """
    x = convert_to_rgb(x)
    x = x.resize(IMAGE_SIZE, Image.BILINEAR)
    # torch.tensor() does not accept a PIL image, so go through NumPy.
    # np.array (rather than np.asarray) returns a writable copy, which
    # torch.from_numpy expects.
    pixels = torch.from_numpy(np.array(x, dtype=np.float32))  # (H, W, C)
    pixels = pixels.permute(2, 0, 1) / 255.0  # (C, H, W), values in [0, 1]
    pixels = (pixels - _IMAGE_MEAN) / _IMAGE_STD
    return pixels.unsqueeze(0)


def generate_code(image):
    """Generate HTML/CSS source for an image of a website component.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submitted without uploading anything.

    Returns:
        The first decoded sequence produced by the model, with special tokens
        removed.

    Raises:
        gr.Error: if no image was provided. Gradio shows this message to the user.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    inputs = PROCESSOR.tokenizer(
        PROMPT,
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = custom_transform(image)
    # Move every model input (ids, mask, pixels) to the model's device in one place.
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    generated_ids = MODEL.generate(
        **inputs, bad_words_ids=BAD_WORDS_IDS, max_length=4096
    )
    return PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]


iface = gr.Interface(
    fn=generate_code,
    # gr.inputs.* was removed in Gradio 4; gr.Image is the supported component.
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="WebInsight - Generate HTML/CSS from Mockup",
    description="Upload a website component image to generate corresponding HTML/CSS code.",
)

if __name__ == "__main__":
    iface.launch()