from diffusers import DiffusionPipeline
import torch
import random
import numpy as np
import os
from huggingface_hub import hf_hub_download
from torchvision.utils import make_grid
from PIL import Image
from .vq_model import VQ_models
from .arpg import ARPG_models

# inherit from DiffusionPipeline so the model can be loaded through the Hugging Face diffusers API
class ARPGModel(DiffusionPipeline):

    def __init__(self):
        super().__init__()

    @torch.no_grad()
    def __call__(self, *args, **kwargs):
        """
        This method downloads the model and VAE components,
        then executes the forward pass based on the user's input.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # select the ARPG model variant (default: the 1B-parameter XXL model)
        model_type = kwargs.get("model_type", "ARPG-XXL")

        # map the model type to its pretrained checkpoint filename
        if model_type == "ARPG-L":
            model_path = "arpg_300m.pt"
        elif model_type == "ARPG-XL":
            model_path = "arpg_700m.pt"
        elif model_type == "ARPG-XXL":
            model_path = "arpg_1b.pt"
        else:
            raise ValueError(f"Unknown model_type: {model_type!r}")

        # download the model checkpoint from the Hub
        model_checkpoint_path = hf_hub_download(
            repo_id=kwargs.get("repo_id", "hp-l33/ARPG"),
            filename=kwargs.get("model_filename", model_path)
        )

        model_fn = ARPG_models[model_type]
        model = model_fn(
            num_classes=1000,
            vocab_size=16384
        ).to(device)

        state_dict = torch.load(model_checkpoint_path, map_location=device)['state_dict']
        model.load_state_dict(state_dict)
        model.eval()

        # download and load the LlamaGen VQ-16 tokenizer (16x spatial downsampling);
        # a separate kwarg is used here so overriding "repo_id" for the ARPG
        # weights does not also redirect the VAE download
        vae_checkpoint_path = hf_hub_download(
            repo_id=kwargs.get("vae_repo_id", "FoundationVision/LlamaGen"),
            filename=kwargs.get("vae_filename", "vq_ds16_c2i.pt")
        )

        vae = VQ_models['VQ-16']()

        vae_state_dict = torch.load(vae_checkpoint_path, map_location=device)['model']
        vae.load_state_dict(vae_state_dict)
        vae = vae.to(device).eval()

        # set up user-specified or default values for generation
        seed = kwargs.get("seed", 6)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        num_steps = kwargs.get("num_steps", 64)
        cfg_scale = kwargs.get("cfg_scale", 4)
        cfg_schedule = kwargs.get("cfg_schedule", "constant")
        sample_schedule = kwargs.get("sample_schedule", "arccos")
        temperature = kwargs.get("temperature", 1.0)
        top_k = kwargs.get("top_k", 600)
        class_labels = kwargs.get("class_labels", [207, 360, 388, 113, 355, 980, 323, 979])  # ImageNet-1k class indices

        # generate the tokens under mixed precision, then decode them to pixels
        with torch.autocast(device_type=device.type):
            sampled_tokens = model.generate(
                condition=torch.tensor(class_labels, dtype=torch.long, device=device),
                num_iter=num_steps,
                guidance_scale=cfg_scale,
                cfg_schedule=cfg_schedule,
                sample_schedule=sample_schedule,
                temperature=temperature,
                top_k=top_k,
            )
            # decode to (batch, channels, 256, 256); the VQ-16 latent grid is
            # 16x16 with an 8-dim codebook embedding
            sampled_images = vae.decode_code(sampled_tokens, shape=(len(class_labels), 8, 16, 16))

        output_dir = kwargs.get("output_dir", "./")
        os.makedirs(output_dir, exist_ok=True)

        # arrange the samples into a grid and map [-1, 1] outputs to uint8 pixels
        image_path = os.path.join(output_dir, "sampled_image.png")
        samples_per_row = kwargs.get("samples_per_row", 4)

        ndarr = make_grid(
            torch.clamp(127.5 * sampled_images + 128.0, 0, 255),
            nrow=int(samples_per_row)
        ).permute(1, 2, 0).to("cpu", dtype=torch.uint8).numpy()

        # save the grid and return it as a PIL image
        image = Image.fromarray(ndarr)
        image.save(image_path)

        return image
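
# --- usage sketch (illustrative) ---
# A minimal example of invoking the pipeline. It assumes this file sits inside
# the ARPG package (see the relative imports above), so it must be run as a
# module, e.g. `python -m arpg.pipeline` (module path is an assumption); the
# argument values below are examples, not recommended settings.
if __name__ == "__main__":
    pipe = ARPGModel()
    grid = pipe(
        model_type="ARPG-XL",               # 700M-parameter checkpoint
        seed=0,
        num_steps=64,
        cfg_scale=4,
        class_labels=[207, 360, 388, 113],  # ImageNet-1k class indices
        samples_per_row=2,
        output_dir="./samples",
    )
    print(f"saved a {grid.size[0]}x{grid.size[1]} grid to ./samples/sampled_image.png")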