SoSa123456 committed on
Commit
54c64de
1 Parent(s): 8589c7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -1
app.py CHANGED
@@ -1,3 +1,35 @@
1
  import gradio as gr
2
 
3
- gr.load("models/mrm8488/bertin-gpt-j-6B-ES-8bit").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
+ #gr.load("models/mrm8488/bertin-gpt-j-6B-ES-8bit").launch()
4
+
5
+ import gradio as gr
6
+ import torch
7
+ from transformers import AutoTokenizer, GPTJForCausalLM
8
+
9
+ from Utils import GPTJBlock # Assuming Utils.py is in the same directory
10
+
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+
13
+ # Monkey-patch GPT-J
14
+ transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock
15
+
16
+ ckpt = "mrm8488/bertin-gpt-j-6B-ES-8bit"
17
+
18
+ tokenizer = AutoTokenizer.from_pretrained(ckpt)
19
+ model = GPTJForCausalLM.from_pretrained(ckpt, pad_token_id=tokenizer.eos_token_id, low_cpu_mem_usage=True).to(device)
20
+
def generate_text(prompt):
    """Generate a sampled continuation of *prompt* with the loaded GPT-J model.

    Parameters
    ----------
    prompt : str
        Input text to continue.

    Returns
    -------
    str
        The decoded model output (the prompt is included in the decoded text,
        and special tokens are not stripped — same as the original behavior).
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Move every input tensor to the same device as the model.
    inputs = {key: value.to(device) for key, value in inputs.items()}
    # Inference only: no_grad avoids building an autograd graph, a large
    # memory saving for a 6B-parameter generate call.
    with torch.no_grad():
        out = model.generate(**inputs, max_length=64, do_sample=True)
    return tokenizer.decode(out[0])
26
+
27
+ iface = gr.Interface(
28
+ fn=generate_text,
29
+ inputs="text",
30
+ outputs="text",
31
+ live=True
32
+ )
33
+
34
+ iface.launch()
35
+