|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
|
|
def load_model(model_name="gpt2"):
    """Build a ``text-generation`` pipeline for a Hugging Face causal LM.

    Args:
        model_name (str): Checkpoint identifier handed to both
            ``from_pretrained`` calls. Defaults to ``"gpt2"``.

    Returns:
        The transformers ``pipeline`` object wrapping the loaded tokenizer
        and model.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    lm = AutoModelForCausalLM.from_pretrained(model_name)
    text_pipe = pipeline(task="text-generation", model=lm, tokenizer=tok)
    return text_pipe
|
|
|
|
|
# Build the pipeline once, at import time, so every call to generate_text
# reuses the same loaded model instead of reloading it per request.
generator = load_model(model_name="gpt2")
|
|
|
def generate_text(prompt, max_length=100, temperature=1.0, top_p=0.9):
    """
    Generate text from ``prompt`` using the module-level ``generator`` pipeline.

    Args:
        prompt (str): Input text from the user.
        max_length (int | float): Max tokens in the prompt + generation.
            Values coming from a UI slider may arrive as floats, so the value
            is converted to ``int`` before being passed to the pipeline.
        temperature (float): Sampling temperature; higher values give more
            random output.
        top_p (float): Nucleus-sampling hyperparameter (cumulative probability
            mass of the candidate tokens kept at each step).

    Returns:
        str: The ``generated_text`` field of the single returned sequence.
    """
    # NOTE(review): if the prompt alone is already >= max_length tokens,
    # transformers may warn or raise depending on its version; consider
    # switching to max_new_tokens if that becomes a problem.
    results = generator(
        prompt,
        # The length budget is a token count and must be an integer.
        max_length=int(max_length),
        temperature=temperature,
        top_p=top_p,
        # temperature/top_p only influence sampling; request it explicitly so
        # the sliders work regardless of the model's own generation defaults.
        do_sample=True,
        num_return_sequences=1,

        # Reuse EOS as the pad id (GPT-2-style tokenizers typically define no
        # dedicated pad token).
        pad_token_id=generator.tokenizer.eos_token_id,
    )
    return results[0]["generated_text"]
|
|
|
|
|
# Gradio UI: prompt and sampling controls on the left, generated text on the
# right; the button feeds the controls into generate_text.
with gr.Blocks() as demo:
    # Introductory text rendered above the controls.
    gr.Markdown(
        value="""
        # Educational GPT-2 Demo
        This demo demonstrates how a smaller Large Language Model (GPT-2) predicts text.
        Change the parameters below to see how the model's output is affected:
        - **Max Length** controls the total number of tokens in the output.
        - **Temperature** controls randomness (higher means more creative/chaotic).
        - **Top-p** controls the diversity of tokens (lower means more conservative choices).
        """
    )

    with gr.Row():
        # Left column: user-editable inputs.
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                value="Once upon a time,",
                placeholder="Type a prompt here",
                lines=4,
            )
            # Token budget for prompt + generation, forwarded as max_length.
            max_len = gr.Slider(
                label="Max Length",
                value=100,
                minimum=20,
                maximum=200,
                step=1,
            )
            # Forwarded as the sampling temperature.
            temp = gr.Slider(
                label="Temperature",
                value=1.0,
                minimum=0.1,
                maximum=2.0,
                step=0.1,
            )
            # Forwarded as the nucleus-sampling threshold.
            top_p = gr.Slider(
                label="Top-p",
                value=0.9,
                minimum=0.1,
                maximum=1.0,
                step=0.05,
            )
            generate_button = gr.Button(value="Generate")

        # Right column: where the model's output is displayed.
        with gr.Column():
            output_box = gr.Textbox(lines=10, label="Generated Text")

    # On click, call generate_text with the four control values (in this
    # order) and show its return value in output_box.
    generate_button.click(
        inputs=[prompt, max_len, temp, top_p],
        outputs=[output_box],
        fn=generate_text,
    )

demo.launch()
|
|