File size: 4,008 Bytes
a970429
 
 
 
 
207e0eb
a970429
207e0eb
 
a970429
207e0eb
 
 
 
 
 
 
 
 
 
a405953
 
207e0eb
5d9d006
 
6863e73
760adf8
 
6863e73
207e0eb
 
5d9d006
207e0eb
a970429
 
 
a405953
a970429
 
 
a405953
a970429
 
 
 
a405953
a970429
 
 
 
 
a405953
a970429
a405953
a970429
 
207e0eb
 
 
760adf8
207e0eb
 
 
 
760adf8
207e0eb
 
 
 
a970429
207e0eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a970429
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import spaces
import gradio as gr  
import torch  
from transformers import AutoTokenizer, AutoModelForCausalLM

title = """# Minitron-8B-Base Story Generator"""
description = """
# Minitron

Minitron is a family of small language models (SLMs) obtained by pruning [NVIDIA's](https://huggingface.co/nvidia) Nemotron-4 15B model. We prune model embedding size, attention heads, and MLP intermediate dimension, following which, we perform continued training with distillation to arrive at the final models.

# Short Story Generator
Welcome to the Short Story Generator! This application helps you create unique short stories based on your inputs.

**Instructions:**
1. **Main Character:** Describe the main character of your story. For example, "a brave knight" or "a curious cat".
2. **Setting:** Describe the setting where your story takes place. For example, "in an enchanted forest" or "in a bustling city".
3. **Plot Twist:** Add an interesting plot twist to make the story exciting. For example, "discovers a hidden treasure" or "finds a secret portal to another world".

After filling in these details, click the "Submit" button, and a short story will be generated for you.
"""

# Gradio input components. Their order must match the positional parameters of
# generate_story: (character, setting, plot_twist, max_tokens, temperature, top_p).
inputs = [
    gr.Textbox(label="Main Character", placeholder="e.g. a brave knight"),
    gr.Textbox(label="Setting", placeholder="e.g. in an enchanted forest"),
    gr.Textbox(label="Plot Twist", placeholder="e.g. discovers a hidden treasure"),
    gr.Slider(minimum=1, maximum=2048, value=64, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
]

# Single text box that displays the value returned by the generation function.
outputs = gr.Textbox(label="Generated Story")

# Load the tokenizer and model
model_path = "nvidia/Minitron-8B-Base"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# The model is loaded once, at import time, in bfloat16 and placed on CUDA.
# NOTE(review): 'cuda' is hard-coded with no CPU fallback; presumably this relies
# on the Spaces GPU runtime (the file imports `spaces`) — confirm before running
# it elsewhere.
device='cuda'
dtype=torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype, device_map=device)

# Define the prompt format
def create_prompt(instruction):
    """Embed ``instruction`` in a fixed instruction/response prompt template.

    The returned text ends with the ``### Response:`` marker, so generation
    continues right after it.

    Args:
        instruction: The task text to insert under ``### Instruction:``.

    Returns:
        The complete prompt string.
    """
    preamble = (
        "Below is an instruction that describes a task.\n\n"
        "Write a response that appropriately completes the request.\n\n"
    )
    return f"{preamble}### Instruction:\n{instruction}\n\n### Response:"

@spaces.GPU
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Produce a single model reply to ``message`` using the instruction prompt.

    The signature matches a chat-style callback. ``history`` and
    ``system_message`` are accepted for compatibility but are not used here.

    Args:
        message: Instruction text from the user.
        history: Previous conversation turns (ignored).
        system_message: System prompt text (ignored).
        max_tokens: Maximum number of *new* tokens to generate (the prompt is
            not counted).
        temperature: Sampling temperature; a value <= 0 selects greedy decoding.
        top_p: Nucleus-sampling probability mass (used only when sampling).

    Returns:
        The decoded continuation, without the echoed prompt.
    """
    prompt = create_prompt(message)
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    # temperature/top_p only take effect when sampling is enabled; without
    # do_sample=True, generate() decodes greedily and ignores them. Sampling
    # requires a strictly positive temperature, so fall back to greedy otherwise.
    do_sample = temperature > 0
    sampling_kwargs = (
        {"temperature": float(temperature), "top_p": float(top_p)} if do_sample else {}
    )

    output_ids = model.generate(
        input_ids,
        # A single unpadded sequence: every position is a real token.
        attention_mask=torch.ones_like(input_ids),
        # Budget counts generated tokens only, not the prompt.
        max_new_tokens=int(max_tokens),
        do_sample=do_sample,
        num_return_sequences=1,
        **sampling_kwargs,
    )

    # generate() returns prompt + continuation; decode only the new tokens.
    new_tokens = output_ids[0, input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

@spaces.GPU
def generate_story(character, setting, plot_twist, max_tokens, temperature, top_p):
    """Generate a short story from the three user-supplied story details.

    Args:
        character: Description of the main character, e.g. "a brave knight".
        setting: Where the story takes place, e.g. "in an enchanted forest".
        plot_twist: A twist to include, e.g. "discovers a hidden treasure".
        max_tokens: Maximum number of *new* tokens to generate. The prompt is
            not counted, matching the "Max new tokens" slider label.
        temperature: Sampling temperature; a value <= 0 selects greedy decoding.
        top_p: Nucleus-sampling probability mass (used only when sampling).

    Returns:
        The generated story text, without the echoed prompt.
    """
    prompt = (
        "Write a short story with the following details:\n"
        f"Main character: {character}\n"
        f"Setting: {setting}\n"
        f"Plot twist: {plot_twist}\n\n"
        "Story:"
    )
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    # temperature/top_p only take effect when sampling is enabled; without
    # do_sample=True, generate() decodes greedily and ignores them. Sampling
    # requires a strictly positive temperature, so fall back to greedy otherwise.
    do_sample = temperature > 0
    sampling_kwargs = (
        {"temperature": float(temperature), "top_p": float(top_p)} if do_sample else {}
    )

    output_ids = model.generate(
        input_ids,
        # A single unpadded sequence: every position is a real token.
        attention_mask=torch.ones_like(input_ids),
        # max_length would include the prompt and can leave no room for the
        # story; max_new_tokens budgets the generated text only.
        max_new_tokens=int(max_tokens),
        do_sample=do_sample,
        num_return_sequences=1,
        **sampling_kwargs,
    )

    # generate() returns prompt + continuation; show only the story itself.
    new_tokens = output_ids[0, input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
  
# Form-style Gradio UI: the components in `inputs` feed generate_story, and its
# return value is shown in the `outputs` text box.
demo = gr.Interface(
    fn=generate_story,
    inputs=inputs,
    outputs=outputs,
    description=description,
    title="Short Story Generator",
)


def _main():
    """Start the Gradio server for this app."""
    demo.launch()


if __name__ == "__main__":
    _main()