# chat-gradio / app.py
# Author: artek0chumak — commit e3e6a48 ("another test"), 1.1 kB
import sys
sys.path.insert(0, './petals/')
import torch
import transformers
import gradio as gr
from src.client.remote_model import DistributedBloomForCausalLM
# MODEL_NAME = "bigscience/test-bloomd-6b3" # select model you like
# INITIAL_PEERS = ["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"]
# tokenizer = transformers.BloomTokenizerFast.from_pretrained("bigscience/test-bloomd-6b3")
# model = DistributedBloomForCausalLM.from_pretrained("bigscience/test-bloomd-6b3", initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32)
def inference(text, seq_length=1):
    """Return the prompt unchanged (placeholder for model generation).

    The distributed BLOOM generation call is disabled in this version, so
    the function simply echoes its input back to the UI.

    Args:
        text: Prompt string, supplied by the "Input text" textbox.
        seq_length: Requested number of new tokens, supplied by the slider.
            Currently ignored; it would be passed as ``max_new_tokens`` once
            generation is re-enabled.

    Returns:
        The very same ``text`` object that was passed in.
    """
    # Disabled generation path (requires the tokenizer/model setup that is
    # commented out at module level):
    #   input_ids = tokenizer([text], return_tensors="pt").input_ids
    #   output = model.generate(input_ids, max_new_tokens=seq_length)
    #   return tokenizer.batch_decode(output)[0]
    echoed = text
    return echoed
# Build the demo UI: a multi-line prompt box plus a slider choosing the
# generation length. Gradio passes the component values positionally to
# `inference(text, seq_length)` in the order listed in `inputs`.
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(lines=10, label="Input text"),
        # Use the top-level `gr.Slider` with `value=` (the current component
        # API), consistent with `gr.Textbox` above. The legacy
        # `gr.inputs.Slider(default=...)` form is deprecated in Gradio 3.x
        # and removed in Gradio 4, where it fails at startup.
        gr.Slider(
            minimum=0,
            maximum=1000,
            step=1,
            value=42,
            label="Sequence length for generation",
        ),
    ],
    outputs="text",
)
# Start the Gradio web server.
iface.launch()