Spaces:
Runtime error
Runtime error
File size: 967 Bytes
8589c7a 54c64de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import gradio as gr
#gr.load("models/mrm8488/bertin-gpt-j-6B-ES-8bit").launch()
import gradio as gr
import torch
import transformers.models.gptj.modeling_gptj  # binds `transformers` for the GPTJBlock monkey-patch
from transformers import AutoTokenizer, GPTJForCausalLM
from Utils import GPTJBlock # Assuming Utils.py is in the same directory
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Monkey-patch GPT-J: replace the stock transformers GPTJBlock with the one
# from Utils before the model is built. Presumably GPTJForCausalLM constructs
# its layers from the module-level name in modeling_gptj, which is why this
# assignment has to run before from_pretrained() below — TODO confirm
# against the installed transformers version.
# Requires the module-level `import transformers.models.gptj.modeling_gptj`
# so that the name `transformers` is bound here.
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock

ckpt = "mrm8488/bertin-gpt-j-6B-ES-8bit"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
# pad_token_id is set to the EOS id so generate() has a padding token to use.
model = GPTJForCausalLM.from_pretrained(
    ckpt,
    pad_token_id=tokenizer.eos_token_id,
    low_cpu_mem_usage=True,
).to(device)
def generate_text(prompt):
    """Generate a sampled continuation of *prompt* with the GPT-J model.

    Args:
        prompt: Text from the Gradio input textbox.

    Returns:
        The decoded model output (prompt followed by the sampled
        continuation, up to 64 tokens in total) with special tokens such as
        the end-of-text marker removed. A blank prompt yields "".
    """
    # With live=True the interface re-runs on every input change, including
    # when the box is cleared. An empty prompt tokenizes to zero-length
    # input_ids, which generate() cannot start from, so return early.
    if not prompt or not prompt.strip():
        return ""
    encoded = tokenizer(prompt, return_tensors="pt")
    # Move every tensor (input_ids, attention_mask, ...) onto the model's device.
    encoded = {key: value.to(device) for key, value in encoded.items()}
    out = model.generate(**encoded, max_length=64, do_sample=True)
    # EOS doubles as pad_token_id, so it can appear in the output. Strip it
    # (and any other special tokens) before showing the text to the user.
    return tokenizer.decode(out[0], skip_special_tokens=True)
# One textbox in, one textbox out, backed by generate_text.
# NOTE(review): live=True makes Gradio re-run generate_text whenever the
# input changes (no Submit button), which is costly for a 6B-parameter
# model — consider live=False if the Space gets overloaded.
iface = gr.Interface(generate_text, "text", "text", live=True)
iface.launch()
|