Update README.md
Browse files
README.md
CHANGED
@@ -26,6 +26,7 @@ Despite the powerful capabilities of this model, users should be aware of its li
|
|
26 |
How to use
|
27 |
Here is an example of how to use this model:
|
28 |
|
|
|
29 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
30 |
import time
|
31 |
import torch
|
@@ -33,9 +34,7 @@ import torch
|
|
33 |
class Chatbot:
|
34 |
def __init__(self, model_name):
|
35 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
|
36 |
-
self.model = AutoModelForCausalLM.from_pretrained(model_name,
|
37 |
-
load_in_4bit=True,
|
38 |
-
torch_dtype=torch.bfloat16)
|
39 |
if self.tokenizer.pad_token_id is None:
|
40 |
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
41 |
|
@@ -45,9 +44,9 @@ class Chatbot:
|
|
45 |
inputs = {name: tensor.to('cuda') for name, tensor in inputs.items()}
|
46 |
start_time = time.time()
|
47 |
tokens = self.model.generate(input_ids=inputs['input_ids'],
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
end_time = time.time()
|
52 |
output_tokens = tokens[0][inputs['input_ids'].shape[-1]:]
|
53 |
output = self.tokenizer.decode(output_tokens, skip_special_tokens=True)
|
|
|
26 |
How to use
|
27 |
Here is an example of how to use this model:
|
28 |
|
29 |
+
```python
|
30 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
31 |
import time
|
32 |
import torch
|
|
|
34 |
class Chatbot:
|
35 |
def __init__(self, model_name):
|
36 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
|
37 |
+
self.model = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit=True, torch_dtype=torch.bfloat16)
|
|
|
|
|
38 |
if self.tokenizer.pad_token_id is None:
|
39 |
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
40 |
|
|
|
44 |
inputs = {name: tensor.to('cuda') for name, tensor in inputs.items()}
|
45 |
start_time = time.time()
|
46 |
tokens = self.model.generate(input_ids=inputs['input_ids'],
|
47 |
+
attention_mask=inputs['attention_mask'],
|
48 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
49 |
+
max_new_tokens=400)
|
50 |
end_time = time.time()
|
51 |
output_tokens = tokens[0][inputs['input_ids'].shape[-1]:]
|
52 |
output = self.tokenizer.decode(output_tokens, skip_special_tokens=True)
|