# app.py by CreitinGameplays — last commit 4950e4d ("Update app.py"), 2.53 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Hugging Face Hub checkpoint of the conversational BLOOM-3b fine-tune served by this app.
model_name = "CreitinGameplays/bloom-3b-conversational"

# The tokenizer and the causal-LM weights are both loaded from that same checkpoint,
# once, at import time, so every request reuses them.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_name)
def generate_text(
    user_prompt,
    *,
    max_new_tokens=256,
    temperature=0.2,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.155,
):
    """Generate the assistant's reply to *user_prompt* with the BLOOM chat model.

    The user's message is wrapped in this app's chat template
    (``<|system|> ... </s> <|prompter|> ... </s> <|assistant|>``) and sent to
    the model in a single ``generate`` call. Only the newly generated tokens
    are decoded, so neither the system text nor the user's prompt is echoed
    back in the result.

    Args:
        user_prompt: The user's message.
        max_new_tokens: Upper bound on the number of tokens generated for the
            reply. Unlike ``max_length``, the prompt does not count against it.
        temperature, top_k, top_p, repetition_penalty: Sampling settings
            forwarded unchanged to ``model.generate``.

    Returns:
        The decoded reply, with special tokens removed and surrounding
        whitespace stripped. Returns ``""`` when *user_prompt* is empty or
        only whitespace.
    """
    if not user_prompt or not user_prompt.strip():
        # Nothing to answer; avoid an expensive generate() call.
        return ""

    # Construct the full prompt with system introduction, user prompt, and assistant role
    prompt = f"<|system|> You are a helpful AI assistant. </s> <|prompter|> {user_prompt} </s> <|assistant|>"

    # Tokenize the whole templated prompt once. The attention mask is passed
    # along so generate() does not have to guess which positions are padding.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    output = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        num_beams=1,
        num_return_sequences=1,
        do_sample=True,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )

    # For a causal LM, generate() returns the prompt followed by the
    # continuation. Slice the prompt off so only the reply is decoded.
    reply_ids = output[0, prompt_length:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()
# Gradio UI: a single text box in, and generate_text's returned string shown
# as plain text out. The reply is delivered in one piece, not streamed.
interface = gr.Interface(
    fn=generate_text,
    inputs=[
        # Textbox only accepts type "text" / "password" / "email"; the
        # default ("text") is what a free-form prompt needs.
        gr.Textbox(label="Text Prompt", value=""),
    ],
    outputs="text",
    description="Interact with BLOOM-3b-conversational (Loaded with Hugging Face Transformers)",
)

# Start the web server only when this file is run as a script (as a Hugging
# Face Space does), not when the module is merely imported.
if __name__ == "__main__":
    interface.launch()