import gradio as gr
from transformers import pipeline

# Load the Stable Diffusion prompt-generator model from the Hugging Face Hub
# (the pipeline keyword for selecting a checkpoint is `model`).
pipe = pipeline('text-generation', model='Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator')

def get_valid_prompt(text: str) -> str:
    # Keep the generated text only up to the first period or newline,
    # whichever comes first, so a single clean prompt is returned.
    dot_split = text.split('.')[0]
    n_split = text.split('\n')[0]
    return min(dot_split, n_split, key=len)
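
# Illustration (hypothetical generated text, not real model output):
#   get_valid_prompt("a castle on a hill, trending on artstation. 4k\nmasterpiece")
#   -> "a castle on a hill, trending on artstation"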

def generate_prompt(prompt):
    # Generate up to 77 tokens (the length limit of Stable Diffusion's CLIP
    # text encoder) and trim the result to a single clean prompt.
    valid_prompt = get_valid_prompt(pipe(prompt, max_length=77)[0]['generated_text'])
    return valid_prompt

iface = gr.Interface(
    fn=generate_prompt,
    inputs="text",
    outputs="text",
    title="Prompt Generator",
    description="Enter a prompt and get an expanded Stable Diffusion prompt generated by the model."
)
iface.launch()
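
# Optional: Gradio's launch() also accepts share=True to expose a temporary
# public URL, e.g. iface.launch(share=True).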