Update README.md
README.md CHANGED
@@ -16,13 +16,12 @@ pipeline_tag: text2text-generation
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
-device = 'cuda'
 tokenizer = AutoTokenizer.from_pretrained('TeraSpace/dialofred')
-model = AutoModelForSeq2SeqLM.from_pretrained('TeraSpace/dialofred').to(device)
+model = AutoModelForSeq2SeqLM.from_pretrained('TeraSpace/dialofred', device_map={'': 0})  # add torch_dtype=torch.bfloat16 to use less memory
 while True:
     text_inp = input("=>")
     lm_text = f'<SC1>- {text_inp}\n- <extra_id_0>'
-    input_ids = torch.tensor([tokenizer.encode(lm_text)]).to(device)
+    input_ids = torch.tensor([tokenizer.encode(lm_text)]).to(model.device)
     # outputs = model.generate(input_ids=input_ids,
     #                          max_length=200,
     #                          eos_token_id=tokenizer.eos_token_id,
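For reference, below is a self-contained sketch of the updated snippet. device_map={'': 0} loads every module onto GPU 0 (via accelerate), so the old explicit .to(device) call is no longer needed and inputs can simply follow model.device. The diff window cuts the model.generate call off after eos_token_id, so the do_sample flag and the way the reply is decoded here are illustrative assumptions, not part of the original README.

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained('TeraSpace/dialofred')
# device_map={'': 0} places the whole model on GPU 0;
# add torch_dtype=torch.bfloat16 to use less memory.
model = AutoModelForSeq2SeqLM.from_pretrained('TeraSpace/dialofred', device_map={'': 0})

while True:
    text_inp = input("=>")
    # FRED-T5-style denoiser prompt: the model fills the <extra_id_0> slot with its reply.
    lm_text = f'<SC1>- {text_inp}\n- <extra_id_0>'
    input_ids = torch.tensor([tokenizer.encode(lm_text)]).to(model.device)
    outputs = model.generate(input_ids=input_ids,
                             max_length=200,
                             eos_token_id=tokenizer.eos_token_id,
                             do_sample=True)  # sampling flag is an assumption
    # skip_special_tokens drops the <extra_id_0> sentinel from the decoded text.
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))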