# AiPrompt / app.py
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

def load_model(model_name):
    try:
        # Load the tokenizer and model weights from the Hugging Face Hub
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        return pipeline("text-generation", model=model, tokenizer=tokenizer)
    except Exception as e:
        # On failure, return the error message as a string; callers check for this
        return str(e)
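
# Optional addition (not part of the original app): cache loaded pipelines so
# repeated requests for the same model name don't reload weights from disk.
# A minimal sketch using functools.lru_cache; "load_model_cached" is a
# hypothetical helper name, and note that error strings get cached too
# (call load_model_cached.cache_clear() to reset).
from functools import lru_cache


@lru_cache(maxsize=2)
def load_model_cached(model_name):
    return load_model(model_name)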

def refine_prompt(user_prompt, model_name):
    # Load the specified model
    text_generator = load_model(model_name)
    if isinstance(text_generator, str):  # load_model returned an error string
        return text_generator

    # Refinement instructions prepended to the user's prompt
    guidelines = (
        "Refine the following prompt according to these guidelines:\n"
        "1. Be concise\n"
        "2. Be specific and well-defined\n"
        "3. Ask one task at a time\n"
        "4. Turn generative tasks into classification tasks\n"
        "5. Improve response quality by including examples\n\n"
        f"Original Prompt: {user_prompt}\n"
        "Refined Prompt:"
    )

    # Generate the refined prompt; max_new_tokens caps only the continuation,
    # unlike max_length, which would also count the already-long guideline text
    refined_prompt = text_generator(guidelines, max_new_tokens=100, num_return_sequences=1)[0]["generated_text"]

    # Keep only the text generated after the "Refined Prompt:" marker
    refined_prompt = refined_prompt.split("Refined Prompt:")[-1].strip()
    return refined_prompt
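
# Example usage outside the Gradio UI (hypothetical; "distilgpt2" is a small
# model that downloads on first use, so this is left commented out):
#   print(refine_prompt("Tell me about dogs", "distilgpt2"))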

# Create a Gradio interface
iface = gr.Interface(
    fn=refine_prompt,
    inputs=[
        gr.Textbox(label="User Prompt", placeholder="Enter your prompt here..."),
        gr.Textbox(label="Model Name", placeholder="Enter a Hugging Face model name (e.g., gpt2, distilgpt2)..."),
    ],
    outputs="text",
    title="Prompt Refinement Tool",
    description="Input a prompt and a model name to get a refined version that follows specific guidelines.",
)

# Launch the app
if __name__ == "__main__":
    iface.launch()
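    # Note: iface.launch(share=True) would also create a temporary public URL
    # (a standard Gradio option), which can be handy when testing off-Space.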