# Scraped from a Hugging Face Spaces page; the Space status read: "Runtime error".
#gr.load("models/mrm8488/bertin-gpt-j-6B-ES-8bit").launch()
import gradio as gr
import gradio as gr
import torch
from transformers import AutoTokenizer, GPTJForCausalLM

from Utils import GPTJBlock  # Assuming Utils.py is in the same directory
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Monkey-patch GPT-J: replace the stock block class in transformers' GPT-J
# modeling module with the GPTJBlock from Utils. This has to run before the
# model is loaded, because only blocks constructed after the patch use it.
# NOTE(review): presumably needed so the 8-bit checkpoint's layer layout
# matches; confirm against Utils.py.
#
# Import the modeling module explicitly. `from transformers import ...` does
# not bind the name `transformers`, so the attribute path
# `transformers.models.gptj...` raised NameError at startup.
import transformers.models.gptj.modeling_gptj as modeling_gptj

modeling_gptj.GPTJBlock = GPTJBlock
# Hugging Face Hub checkpoint id for the model and its tokenizer.
ckpt = "mrm8488/bertin-gpt-j-6B-ES-8bit"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
# Reuse the EOS token id as the padding id for generation (presumably because
# the tokenizer defines no dedicated pad token; confirm). low_cpu_mem_usage
# reduces peak host RAM while the weights are loaded.
model = GPTJForCausalLM.from_pretrained(
    ckpt,
    pad_token_id=tokenizer.eos_token_id,
    low_cpu_mem_usage=True,
).to(device)
def generate_text(prompt, max_length=64):
    """Sample a continuation of *prompt* from the loaded GPT-J model.

    Args:
        prompt: Input text to continue. An empty string or None returns ""
            without calling the model.
        max_length: Cap on the total sequence length in tokens (prompt plus
            generated tokens), forwarded to ``model.generate``.

    Returns:
        The decoded first sequence returned by ``model.generate``. For a causal
        LM this is the prompt followed by the sampled continuation.
    """
    # In live mode Gradio calls this with "" whenever the textbox is cleared.
    # An empty prompt presumably tokenizes to zero ids (no BOS is added), and
    # generate() cannot start from that, so short-circuit instead.
    if not prompt:
        return ""
    encoded = tokenizer(prompt, return_tensors="pt")
    # Move every tensor the tokenizer produced onto the model's device.
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    output_ids = model.generate(**inputs, max_length=max_length, do_sample=True)
    return tokenizer.decode(output_ids[0])
# Web UI: one text box in, generated text out.
# live=True reruns generate_text whenever the input changes.
# NOTE(review): every rerun samples from a 6B model, which can be slow or
# expensive (especially on CPU); consider live=False with a submit button.
iface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    live=True,
)

# Start the server only when run as a script (e.g. `python app.py` on Spaces),
# not when this module is imported.
if __name__ == "__main__":
    iface.launch()