Spaces:
Runtime error
Runtime error
import os
import sys

# Make the `petals/` checkout that sits next to this script importable, so the
# `src.client.remote_model` import below resolves. Resolving the path against
# this file (instead of the relative './petals/') keeps the import working no
# matter which directory the process was started from.
_PETALS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "petals")
sys.path.insert(0, _PETALS_DIR)

import torch
import transformers
import gradio as gr
from src.client.remote_model import DistributedBloomForCausalLM

# Hugging Face Hub id of the BLOOM checkpoint used with Petals.
MODEL_NAME = "bigscience/bloom-petals"
# Loaded once at import time; `inference` reads these module-level objects on
# every request.
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)
def inference(text, seq_length=1):
    """Continue ``text`` with the distributed BLOOM model and return the decoded output.

    Args:
        text: Prompt string to condition generation on.
        seq_length: Number of new tokens to generate. The Gradio slider may
            deliver this as a float (e.g. ``42.0``), while ``max_new_tokens``
            must be an integer, so it is coerced with ``int`` first.

    Returns:
        The decoded text of the first (and only) sequence in the batch. For a
        decoder-only model this presumably includes the prompt tokens followed
        by the continuation.
    """
    input_ids = tokenizer([text], return_tensors="pt").input_ids
    output = model.generate(input_ids, max_new_tokens=int(seq_length))
    return tokenizer.batch_decode(output)[0]
# Gradio UI: a free-text prompt plus a slider choosing how many new tokens
# `inference` should generate. Uses the top-level `gr.Slider` component (with
# `value=`) rather than the deprecated `gr.inputs.Slider(default=...)`, which
# matches the `gr.Textbox` beside it; the legacy `gr.inputs` namespace was
# removed in Gradio 4.
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(lines=10, label="Input text"),
        # Minimum is 1: requesting zero new tokens is not a meaningful
        # generation request (and may be rejected by `generate`).
        gr.Slider(
            minimum=1,
            maximum=1000,
            step=1,
            value=42,
            label="Sequence length for generation",
        ),
    ],
    outputs="text",
)
iface.launch()