Update README.md with CPU/GPU runtime guide
Browse files
README.md
CHANGED
@@ -30,9 +30,16 @@ pip install pip --upgrade && pip install transformers --upgrade
|
|
30 |
``` Python
|
31 |
#Load model
|
32 |
import transformers, torch
|
|
|
|
|
|
|
33 |
compute_dtype = torch.float16
|
|
|
|
|
|
|
|
|
|
|
34 |
cache_path = ''
|
35 |
-
device = 'cuda'
|
36 |
model_id = "mobiuslabsgmbh/aanaphi2-v0.1"
|
37 |
model = transformers.AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype,
|
38 |
cache_dir=cache_path,
|
@@ -50,7 +57,7 @@ model.eval();
|
|
50 |
@torch.no_grad()
|
51 |
def generate(prompt, max_length=1024):
|
52 |
prompt_chat = prompt_format(prompt)
|
53 |
-
inputs = tokenizer(prompt_chat, return_tensors="pt", return_attention_mask=True).to('cuda')
|
54 |
outputs = model.generate(**inputs, max_length=max_length, eos_token_id= tokenizer.eos_token_id)
|
55 |
text = tokenizer.batch_decode(outputs[:,:-1])[0]
|
56 |
return text
|
|
|
30 |
``` Python
|
31 |
#Load model
|
32 |
import transformers, torch
|
33 |
+
|
34 |
+
#GPU runtime
|
35 |
+
device = 'cuda'
|
36 |
compute_dtype = torch.float16
|
37 |
+
|
38 |
+
#CPU runtime (to run on CPU, uncomment the two lines below and comment out the GPU lines above)
|
39 |
+
#device = 'cpu'
|
40 |
+
#compute_dtype = torch.float32
|
41 |
+
|
42 |
cache_path = ''
|
|
|
43 |
model_id = "mobiuslabsgmbh/aanaphi2-v0.1"
|
44 |
model = transformers.AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype,
|
45 |
cache_dir=cache_path,
|
|
|
57 |
@torch.no_grad()
|
58 |
def generate(prompt, max_length=1024):
|
59 |
prompt_chat = prompt_format(prompt)
|
60 |
+
inputs = tokenizer(prompt_chat, return_tensors="pt", return_attention_mask=True).to(device)
|
61 |
outputs = model.generate(**inputs, max_length=max_length, eos_token_id= tokenizer.eos_token_id)
|
62 |
text = tokenizer.batch_decode(outputs[:,:-1])[0]
|
63 |
return text
|