File size: 4,689 Bytes
59696ca
e128669
eb0afed
 
3b81b26
f8b495d
e128669
eb0afed
 
8615850
 
 
 
 
 
 
3b81b26
 
eb0afed
59696ca
f8b495d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb0afed
 
 
 
f8b495d
 
 
eb0afed
 
 
 
 
 
f8b495d
eb0afed
 
f8b495d
eb0afed
 
 
 
 
 
f8b495d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8615850
 
 
f8b495d
 
 
 
 
 
 
 
 
8615850
f8b495d
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
#import peft
import transformers
import os
import re

# Runtime configuration. MODEL_ID and HUB_TOKEN can be supplied through the
# environment; an unset or empty value falls back to the default on the right.
device = "cpu"
is_peft = False
model_id = os.environ.get("MODEL_ID") or "treadon/prompt-fungineer-355M"
# Either the token string from HUB_TOKEN or the literal True.
# NOTE(review): True presumably tells the hub client to use locally stored
# credentials -- confirm against the installed transformers version.
auth_token = os.environ.get("HUB_TOKEN") or True

print(f"Using model {model_id}.")

# auth_token is either True or a token string, so an identity check against
# the True singleton is the precise way to detect an explicit token.
if auth_token is not True:
    print("Using auth token.")

model = transformers.AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True,use_auth_token=auth_token)
# NOTE(review): the tokenizer is always loaded from "gpt2", independent of
# MODEL_ID -- verify this matches whatever checkpoint MODEL_ID points to.
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")


# Matches one "TAG: text" section. The lazy text group stops right before the
# next " TAG:" marker or at the end of the string.
_SECTION_RE = re.compile(
    r"(BRF:|POS:|ENH:|INS:|NEG:) (.*?)(?= (BRF:|POS:|ENH:|INS:|NEG:)|$)"
)


def format_prompt(prompt, enhancers=True, inspiration=False, negative_prompt=False):
    """Assemble a display prompt from tagged model output.

    The text is expected to contain sections introduced by the tags
    ``BRF:``, ``POS:``, ``ENH:``, ``INS:`` and ``NEG:``. The ``POS:`` section
    is the base of the result; the other sections are appended on request.

    Args:
        prompt: Tagged text to parse.
        enhancers: Append the ``ENH:`` section, separated by a space.
        inspiration: Append the ``INS:`` section, separated by a space.
        negative_prompt: Append the ``NEG:`` section after a blank line,
            prefixed with ``--no ``.

    Returns:
        The assembled prompt, or ``"Failed to generate prompt."`` when the
        text has no ``POS:`` section or ``prompt`` is not a string. A
        requested optional section that is missing or empty is skipped
        instead of failing the whole prompt.
    """
    failure = "Failed to generate prompt."

    try:
        matches = _SECTION_RE.findall(prompt)
    except TypeError:
        # Non-string input (e.g. None) cannot be parsed.
        return failure

    # When a tag occurs more than once, the last occurrence wins.
    sections = {tag: text.strip() for tag, text, _ in matches}

    if "POS:" not in sections:
        return failure

    result = sections["POS:"]
    if enhancers and sections.get("ENH:"):
        result += " " + sections["ENH:"]
    if inspiration and sections.get("INS:"):
        result += " " + sections["INS:"]
    if negative_prompt and sections.get("NEG:"):
        result += "\n\n--no " + sections["NEG:"]

    return result

    
def generate_text(prompt, extra=False, top_k=100, top_p=0.95, temperature=0.85, enhancers=True, inpspiration=False, negative_prompt=False):
    """Generate four expanded prompts from a short base prompt.

    The base prompt is prefixed with ``BRF: `` unless it already starts with
    ``BRF:``. When ``extra`` is false, `` POS:`` is appended to the prompt
    before sampling. Each sampled sequence is decoded and passed through
    ``format_prompt``.

    Args:
        prompt: Base prompt text.
        extra: When false, `` POS:`` is appended to the prompt before sampling.
        top_k: Top-k sampling cutoff passed to ``model.generate``.
        top_p: Nucleus sampling threshold passed to ``model.generate``.
        temperature: Sampling temperature passed to ``model.generate``.
        enhancers: Forwarded to ``format_prompt``.
        inpspiration: Forwarded to ``format_prompt`` as its ``inspiration``
            argument (the parameter name is kept as-is for compatibility).
        negative_prompt: Forwarded to ``format_prompt``.

    Returns:
        A list of exactly four strings. If generation fails part-way, the
        error is printed and the remaining slots hold
        ``"Failed to generate prompt."`` so the caller always receives one
        value per output slot.
    """
    num_samples = 4

    if not prompt.startswith("BRF:"):
        prompt = "BRF: " + prompt

    if not extra:
        prompt = prompt + " POS:"

    model.eval()
    # SOFT SAMPLE
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    samples = []
    try:
        outputs = model.generate(
            **inputs,
            max_length=256,
            do_sample=True,
            # Defensive cast: a UI slider may deliver a float-typed whole
            # number -- presumably generate() wants an int here; TODO confirm.
            top_k=int(top_k),
            top_p=top_p,
            temperature=temperature,
            num_return_sequences=num_samples,
            pad_token_id=tokenizer.eos_token_id,
        )
        for output in outputs:
            sample = tokenizer.decode(output, skip_special_tokens=True)
            samples.append(format_prompt(sample, enhancers, inpspiration, negative_prompt))
    except Exception as e:
        # Best-effort boundary: report the error and fall through to padding.
        print(e)

    # Always hand back one value per output slot, even after a failure.
    samples.extend(["Failed to generate prompt."] * (num_samples - len(samples)))
    return samples


# UI definition: inputs in the left column, four result boxes in the right
# column, and a single button that runs generate_text.
with gr.Blocks() as fungineer:
    with gr.Row():
        gr.Markdown("""# Midjourney / Dalle 2 / Stable Diffusion Prompt Generator

This is the 355M parameter model.  There is also a 7B parameter model that is much better but far slower (access coming soon).

Just enter a basic prompt and the fungineering model will use its wildest imagination to expand the prompt in detail.""")
    with gr.Row():
        with gr.Column():
            base_prompt = gr.Textbox(lines=5, label="Base Prompt", placeholder="An astronaut in space", info="Enter a very simple prompt that will be fungineered into something exciting!")
            extra = gr.Checkbox(value=True, label="Extra Fungineer Imagination", info="If checked, the model will be allowed to go wild with its imagination.")
            # Sampling parameters forwarded to generate_text.
            with gr.Accordion("Advanced Generation Settings", open=False):
                top_k = gr.Slider( minimum=10, maximum=1000, value=100, label="Top K", info="Top K sampling")
                top_p = gr.Slider( minimum=0.1, maximum=1, value=0.95, step=0.01, label="Top P", info="Top P sampling")
                temperature = gr.Slider( minimum=0.1, maximum=1.2, value=0.85, step=0.01, label="Temperature", info="Temperature sampling.  Higher values will make the model more creative")

            # Toggles selecting which generated sections end up in the output.
            with gr.Accordion("Advanced Output Settings", open=False):
                enh = gr.Checkbox(value=True, label="Enhancers", info="Add image meta information such as lens type, shutter speed, camera model, etc.")
                insp = gr.Checkbox(value=False, label="Inspiration", info="Include inspirational photographers that are known for this type of photography.  Sometimes random people will appear here, needs more training.")
                neg = gr.Checkbox(value=False, label="Negative Prompt", info="Include a negative prompt, more often used in Stable Diffusion.  If you're a Stable Diffusion user, chances are you already have a better negative prompt you like to use.")

        with gr.Column():
            # One textbox per generated sample; generate_text requests four.
            outputs = [
                gr.Textbox(lines=5, label="Fungineered Text 1"),
                gr.Textbox(lines=5, label="Fungineered Text 2"),
                gr.Textbox(lines=5, label="Fungineered Text 3"),
                gr.Textbox(lines=5, label="Fungineered Text 4"),
            ]

    # Order must match the positional parameters of generate_text.
    inputs = [base_prompt, extra, top_k, top_p, temperature, enh, insp, neg]


    # The button caption is the component's value, not a label.
    submit = gr.Button(value="Fungineer", variant="primary")
    submit.click(generate_text, inputs=inputs, outputs=outputs)

fungineer.launch(enable_queue=True)