File size: 4,008 Bytes
a970429 207e0eb a970429 207e0eb a970429 207e0eb a405953 207e0eb 5d9d006 6863e73 760adf8 6863e73 207e0eb 5d9d006 207e0eb a970429 a405953 a970429 a405953 a970429 a405953 a970429 a405953 a970429 a405953 a970429 207e0eb 760adf8 207e0eb 760adf8 207e0eb a970429 207e0eb a970429 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import spaces
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Markdown heading string for the app.
title = "# Minitron-8B-Base Story Generator"

# Markdown passed to the interface as its description: model background
# followed by usage instructions. It is assembled line by line; the empty
# first and last entries reproduce the leading and trailing newline.
description = "\n".join(
    (
        "",
        "# Minitron",
        "Minitron is a family of small language models (SLMs) obtained by "
        "pruning [NVIDIA's](https://huggingface.co/nvidia) Nemotron-4 15B model. "
        "We prune model embedding size, attention heads, and MLP intermediate "
        "dimension, following which, we perform continued training with "
        "distillation to arrive at the final models.",
        "# Short Story Generator",
        "Welcome to the Short Story Generator! This application helps you "
        "create unique short stories based on your inputs.",
        "**Instructions:**",
        '1. **Main Character:** Describe the main character of your story. '
        'For example, "a brave knight" or "a curious cat".',
        '2. **Setting:** Describe the setting where your story takes place. '
        'For example, "in an enchanted forest" or "in a bustling city".',
        '3. **Plot Twist:** Add an interesting plot twist to make the story '
        'exciting. For example, "discovers a hidden treasure" or "finds a secret '
        'portal to another world".',
        'After filling in these details, click the "Submit" button, and a '
        'short story will be generated for you.',
        "",
    )
)
# Input widgets, in the positional order of generate_story's parameters:
# three free-text story details followed by three generation sliders.
inputs = [
    *(
        gr.Textbox(label=field_label, placeholder=field_hint)
        for field_label, field_hint in (
            ("Main Character", "e.g. a brave knight"),
            ("Setting", "e.g. in an enchanted forest"),
            ("Plot Twist", "e.g. discovers a hidden treasure"),
        )
    ),
    *(
        gr.Slider(minimum=low, maximum=high, value=start, step=increment, label=caption)
        for low, high, start, increment, caption in (
            (1, 2048, 64, 1, "Max new tokens"),
            (0.1, 4.0, 0.7, 0.1, "Temperature"),
            (0.1, 1.0, 0.95, 0.05, "Top-p (nucleus sampling)"),
        )
    ),
]

# Single text area that receives the generated story.
outputs = gr.Textbox(label="Generated Story")
# Checkpoint identifier shared by the tokenizer and the model.
model_path = "nvidia/Minitron-8B-Base"
# Placement and precision used when loading the weights.
device = "cuda"
dtype = torch.bfloat16

# Load the tokenizer and the causal language model once, at import time.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=dtype,
    device_map=device,
)
# Define the prompt format
def create_prompt(instruction):
    """Wrap *instruction* in the instruction/response prompt template.

    Args:
        instruction: Text placed under the ``### Instruction:`` heading.

    Returns:
        The full prompt, ending with the ``### Response:`` heading so the
        continuation follows it.
    """
    preamble = "Below is an instruction that describes a task."
    request = "Write a response that appropriately completes the request."
    return (
        f"{preamble}\n\n{request}\n\n"
        f"### Instruction:\n{instruction}\n\n### Response:"
    )
@spaces.GPU
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Generate a reply to *message* using the instruction prompt template.

    Args:
        message: The user's instruction text.
        history: Accepted for interface compatibility; not used.
        system_message: Accepted for interface compatibility; not used.
        max_tokens: Maximum number of tokens to generate after the prompt.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.
    """
    prompt = create_prompt(message)
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        input_ids,
        # Budget only the newly generated tokens: a total-length cap would
        # be consumed almost entirely by the prompt template itself.
        max_new_tokens=int(max_tokens),
        num_return_sequences=1,
        # temperature / top_p only take effect when sampling is enabled.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return output_text
@spaces.GPU
def generate_story(character, setting, plot_twist, max_tokens, temperature, top_p):
    """Generate a short story from a character, a setting and a plot twist.

    Args:
        character: Description of the main character.
        setting: Description of where the story takes place.
        plot_twist: Plot twist to include in the story.
        max_tokens: Maximum number of tokens to generate after the prompt.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.
    """
    prompt = f"Write a short story with the following details:\nMain character: {character}\nSetting: {setting}\nPlot twist: {plot_twist}\n\nStory:"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        input_ids,
        # The UI labels this value "Max new tokens", so it must not count
        # the prompt; int() guards against a slider delivering a float.
        max_new_tokens=int(max_tokens),
        num_return_sequences=1,
        # Without sampling, temperature and top_p are ignored and decoding
        # is greedy regardless of the slider values.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return output_text
#demo = gr.ChatInterface(
# title=gr.Markdown(title),
# description=gr.Markdown(description),
# fn=generate_story,
# additional_inputs=[
# gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
# gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
# gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
# ],
#)
# Create the Gradio interface
# Connect the story generator to the input widgets and the output textbox.
demo = gr.Interface(
    fn=generate_story,
    inputs=inputs,
    outputs=outputs,
    title="Short Story Generator",
    description=description,
)

if __name__ == "__main__":
    # Start the Gradio app only when this file is run as a script.
    demo.launch()