File size: 2,352 Bytes
6998869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Define the model and processor
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Never commit credentials: the Hugging Face token is read from the
# ``HF_TOKEN`` environment variable. When it is unset, ``None`` is passed and
# ``from_pretrained`` uses its default credential lookup (e.g. a cached
# ``huggingface-cli login``).
# NOTE(review): the token that used to be hard-coded here should be treated as
# leaked and revoked.
API_TOKEN = os.environ.get("HF_TOKEN")
MODEL_ID = "HuggingFaceM4/VLM_WebSight_finetuned"

PROCESSOR = AutoProcessor.from_pretrained(
    MODEL_ID,
    token=API_TOKEN,
)
MODEL = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=API_TOKEN,
    trust_remote_code=True,  # allow the hub repo's custom modeling code to run
).to(DEVICE)

# Number of latents the perceiver resampler emits per image; the prompt built
# in generate_code reserves one ``<image>`` placeholder token for each.
image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
BOS_TOKEN = PROCESSOR.tokenizer.bos_token
# Token-id sequences that generation is forbidden to emit (the image
# placeholder tokens used only in the prompt).
BAD_WORDS_IDS = PROCESSOR.tokenizer(
    ["<image>", "<fake_token_around_image>"], add_special_tokens=False
).input_ids

# Image preprocessing
def convert_to_rgb(image):
    """Return *image* as an RGB PIL image.

    An image already in ``"RGB"`` mode is returned as-is (the same object).
    Any other mode is promoted to RGBA and alpha-composited onto a white
    canvas of the same size, so transparent areas come out white.
    """
    if image.mode != "RGB":
        rgba = image.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255))
        flattened = Image.alpha_composite(white, rgba)
        return flattened.convert("RGB")
    return image

def custom_transform(x):
    """Convert a PIL image into the normalized pixel tensor fed to the model.

    Steps: flatten to RGB (``convert_to_rgb``), resize to 960x960 with bilinear
    resampling, scale to [0, 1], then normalize each channel with the
    processor's ``image_mean`` / ``image_std``.

    Args:
        x: a PIL image in any mode.

    Returns:
        A float32 CPU tensor of shape (1, 3, 960, 960).
    """
    x = convert_to_rgb(x)
    x = x.resize((960, 960), Image.BILINEAR)
    # torch.tensor() cannot consume a PIL image directly; go through NumPy.
    # np.array copies into a writable float32 (H, W, 3) buffer, which
    # torch.from_numpy can share without warnings.
    pixels = torch.from_numpy(np.array(x, dtype=np.float32) / 255.0)
    pixels = pixels.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
    # image_mean / image_std may be plain Python lists, which do not support
    # [:, None, None]; turn them into (C, 1, 1) tensors so they broadcast.
    mean = torch.as_tensor(
        PROCESSOR.image_processor.image_mean, dtype=torch.float32
    ).reshape(-1, 1, 1)
    std = torch.as_tensor(
        PROCESSOR.image_processor.image_std, dtype=torch.float32
    ).reshape(-1, 1, 1)
    return ((pixels - mean) / std).unsqueeze(0)

# Function to generate HTML/CSS code
def generate_code(image):
    """Generate HTML/CSS source code from a website mockup image.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading an image.

    Returns:
        The decoded model output with special tokens stripped, or ``""``
        when no image was provided.
    """
    if image is None:
        # Without this guard an empty submission would fail inside
        # convert_to_rgb with an AttributeError on ``None.mode``.
        return ""

    # Prompt: BOS, then one <image> placeholder per resampler latent, framed
    # by <fake_token_around_image> markers.
    inputs = PROCESSOR.tokenizer(
        f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = custom_transform(image)
    # Move every input tensor to the model's device in one pass.
    inputs = {name: tensor.to(DEVICE) for name, tensor in inputs.items()}
    # max_length bounds prompt + generated tokens together.
    generated_ids = MODEL.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_length=4096)
    return PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Gradio Interface
# ``gr.inputs.*`` is the legacy component namespace (deprecated in Gradio 3.x,
# removed in 4.x); ``gr.Image`` is the supported equivalent in both.
iface = gr.Interface(
    fn=generate_code,
    inputs=gr.Image(type="pil"),  # deliver uploads to generate_code as PIL images
    outputs="text",
    title="WebInsight - Generate HTML/CSS from Mockup",
    description="Upload a website component image to generate corresponding HTML/CSS code.",
)

iface.launch()