File size: 3,560 Bytes
e3475d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.hugging_face import (
    HuggingFaceCTRLGenerator,
    HuggingFaceGenerationAlgorithm,
    HuggingFaceGPT2Generator,
    HuggingFaceTransfoXLGenerator,
    HuggingFaceOpenAIGPTGenerator,
    HuggingFaceXLMGenerator,
    HuggingFaceXLNetGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry


# Library-style logger: a NullHandler keeps this module silent unless the
# embedding application configures logging itself.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Registry mapping each algorithm-application name (as reported by the
# GT4SD ApplicationsRegistry) to its generator configuration class.
MODEL_FN = dict(
    HuggingFaceCTRLGenerator=HuggingFaceCTRLGenerator,
    HuggingFaceGPT2Generator=HuggingFaceGPT2Generator,
    HuggingFaceTransfoXLGenerator=HuggingFaceTransfoXLGenerator,
    HuggingFaceOpenAIGPTGenerator=HuggingFaceOpenAIGPTGenerator,
    HuggingFaceXLMGenerator=HuggingFaceXLMGenerator,
    HuggingFaceXLNetGenerator=HuggingFaceXLNetGenerator,
)


def run_inference(
    model_type: str,
    prompt: str,
    length: float,
    temperature: float,
    prefix: str,
    k: float,
    p: float,
    repetition_penalty: float,
) -> str:
    """Generate one text sample with the selected HuggingFace language model.

    Args:
        model_type: algorithm application and version joined by an
            underscore, e.g. ``"HuggingFaceGPT2Generator_gpt2"``.
        prompt: text prompt to condition generation on.
        length: maximal number of generated tokens (gradio sliders emit
            floats, hence the cast below).
        temperature: decoding temperature.
        prefix: optional text placed before the prompt.
        k: top-k sampling cutoff.
        p: nucleus (top-p) sampling cutoff.
        repetition_penalty: penalty applied to repeated tokens.

    Returns:
        The generated text.

    Raises:
        ValueError: if the model name is not a key of ``MODEL_FN``.
    """
    # Split only at the FIRST underscore: version identifiers may themselves
    # contain underscores, and str.partition avoids splitting twice.
    model_name, _, version = model_type.partition("_")

    if model_name not in MODEL_FN:
        raise ValueError(f"Model type {model_name} not supported")
    config = MODEL_FN[model_name](
        algorithm_version=version,
        prompt=prompt,
        # Sliders deliver floats; length and k are token counts, so cast to
        # int here (presumably required by the generator config — confirm).
        length=int(length),
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        k=int(k),
        p=p,
        prefix=prefix,
    )

    # Distinct name: the original reused `model` for both the name string
    # and the algorithm instance, which obscured the code.
    algorithm = HuggingFaceGenerationAlgorithm(config)
    # Draw a single sample lazily instead of materializing a list.
    return next(iter(algorithm.sample(1)))


if __name__ == "__main__":

    # Preparation: collect every registered HuggingFace algorithm as
    # "<application>_<version>" strings for the dropdown.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        x["algorithm_application"] + "_" + x["algorithm_version"]
        for x in list(filter(lambda x: "HuggingFace" in x["algorithm_name"], all_algos))
    ]

    # Load metadata (model cards) shipped next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna(
        ""
    )
    # Debug aid (was a bare print); visible only if the host configures logging.
    logger.debug("Examples: %s", examples.values.tolist())

    # Markdown content is expected to be UTF-8; be explicit rather than
    # relying on the platform default encoding.
    with open(metadata_root.joinpath("article.md"), "r", encoding="utf-8") as f:
        article = f.read()
    with open(metadata_root.joinpath("description.md"), "r", encoding="utf-8") as f:
        description = f.read()

    demo = gr.Interface(
        fn=run_inference,
        title="HuggingFace language models",
        inputs=[
            gr.Dropdown(
                algos,
                label="Language model",
                value="HuggingFaceGPT2Generator_gpt2",
            ),
            gr.Textbox(
                label="Text prompt",
                placeholder="I'm a stochastic parrot.",
                lines=1,
            ),
            gr.Slider(minimum=5, maximum=100, value=20, label="Maximal length", step=1),
            gr.Slider(
                minimum=0.6, maximum=1.5, value=1.1, label="Decoding temperature"
            ),
            gr.Textbox(
                label="Prefix", placeholder="Some prefix (before the prompt)", lines=1
            ),
            gr.Slider(minimum=2, maximum=500, value=50, label="Top-k", step=1),
            # BUGFIX: step was 1 on a [0.5, 1] range, which made nucleus-p
            # effectively binary (only 0.5 or 1.0 selectable). Use a fine step.
            gr.Slider(minimum=0.5, maximum=1, value=1.0, label="Decoding-p", step=0.01),
            gr.Slider(minimum=0.5, maximum=5, value=1.0, label="Repetition penalty"),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)