TeraSpace commited on
Commit
6e0fe75
1 Parent(s): 4d20203

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -3
README.md CHANGED
@@ -16,13 +16,12 @@ pipeline_tag: text2text-generation
16
  import torch
17
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
18
 
19
- device='cuda'
20
  tokenizer = AutoTokenizer.from_pretrained('TeraSpace/dialofred')
21
- model = AutoModelForSeq2SeqLM.from_pretrained('TeraSpace/dialofred').to(device)
22
  while True:
23
  text_inp = input("=>")
24
  lm_text=f'<SC1>- {text_inp}\n- <extra_id_0>'
25
- input_ids=torch.tensor([tokenizer.encode(lm_text)]).to(device)
26
  # outputs=model.generate(input_ids=input_ids,
27
  # max_length=200,
28
  # eos_token_id=tokenizer.eos_token_id,
 
16
  import torch
17
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
18
 
 
19
  tokenizer = AutoTokenizer.from_pretrained('TeraSpace/dialofred')
20
+ model = AutoModelForSeq2SeqLM.from_pretrained('TeraSpace/dialofred', device_map={'':0}).to(device)# Add torch_dtype=torch.bfloat16 to use less memory
21
  while True:
22
  text_inp = input("=>")
23
  lm_text=f'<SC1>- {text_inp}\n- <extra_id_0>'
24
+ input_ids=torch.tensor([tokenizer.encode(lm_text)]).to(model.device)
25
  # outputs=model.generate(input_ids=input_ids,
26
  # max_length=200,
27
  # eos_token_id=tokenizer.eos_token_id,