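# app.py - 967 bytes; last commit 54c64de "Update app.py" by SoSa123456.
# (Hugging Face file-viewer chrome - the avatar caption and the raw / history /
# blame / contribute / delete links - was captured here along with the file.)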
import gradio as gr
#gr.load("models/mrm8488/bertin-gpt-j-6B-ES-8bit").launch()
import gradio as gr
import torch
from transformers import AutoTokenizer, GPTJForCausalLM
from Utils import GPTJBlock # Assuming Utils.py is in the same directory
# The monkey-patch below needs the module name `transformers` to be bound;
# `from transformers import ...` only binds the imported classes, so without
# this import the patch line raises NameError.
import transformers.models.gptj.modeling_gptj

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Monkey-patch GPT-J: swap the library's GPTJBlock for the one from Utils.
# This is done before the model is created, presumably so that the model is
# built from the replacement block (confirm against Utils.py).
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock

ckpt = "mrm8488/bertin-gpt-j-6B-ES-8bit"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
# The tokenizer's end-of-sequence id is passed as the padding token id.
model = GPTJForCausalLM.from_pretrained(
    ckpt,
    pad_token_id=tokenizer.eos_token_id,
    low_cpu_mem_usage=True,
).to(device)
def generate_text(prompt):
    """Generate sampled text from *prompt* with the module-level model.

    Args:
        prompt: The text entered by the user.

    Returns:
        The decoded first sequence returned by ``model.generate``, or an
        empty string when *prompt* is empty.
    """
    # Nothing to generate from: skip tokenisation and generation entirely
    # instead of handing the model an empty input.
    if not prompt:
        return ""

    encoded = tokenizer(prompt, return_tensors='pt')
    # Move every tensor of the encoding onto the same device as the model.
    encoded = {key: value.to(device) for key, value in encoded.items()}
    # do_sample=True enables sampling, so repeated calls may differ.
    out = model.generate(**encoded, max_length=64, do_sample=True)
    return tokenizer.decode(out[0])
# Web UI: one text input wired to generate_text, one text output.
# live=True is passed through unchanged from the original configuration.
iface = gr.Interface(generate_text, "text", "text", live=True)

# Start serving the interface.
iface.launch()