koziev ilya
commited on
Commit
·
e3f5def
1
Parent(s):
a31c007
немного причесал код, убрал лишние манипуляции с выдачей gpt
Browse files
README.md
CHANGED
|
@@ -44,6 +44,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
| 44 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 45 |
|
| 46 |
tokenizer = AutoTokenizer.from_pretrained("inkoziev/rugpt_interpreter")
|
|
|
|
| 47 |
model = AutoModelForCausalLM.from_pretrained("inkoziev/rugpt_interpreter")
|
| 48 |
model.to(device)
|
| 49 |
|
|
@@ -51,8 +52,10 @@ model.to(device)
|
|
| 51 |
# В конце добавляем символ "#"
|
| 52 |
input_text = """<s>- Как тебя зовут?
|
| 53 |
- Джульетта Мао #"""
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
| 56 |
|
| 57 |
output_sequences = model.generate(
|
| 58 |
input_ids=encoded_prompt,
|
|
@@ -63,12 +66,10 @@ output_sequences = model.generate(
|
|
| 63 |
repetition_penalty=1.2,
|
| 64 |
do_sample=True,
|
| 65 |
num_return_sequences=1,
|
| 66 |
-
pad_token_id=
|
| 67 |
)
|
| 68 |
|
| 69 |
-
|
| 70 |
-
text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)
|
| 71 |
text = text[: text.find('</s>')]
|
| 72 |
-
text = text[text.find('#')+1:].strip() # Результат генерации содержит входную строку, поэтому отрезаем ее до символа "#".
|
| 73 |
print(text)
|
| 74 |
```
|
|
|
|
| 44 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 45 |
|
| 46 |
tokenizer = AutoTokenizer.from_pretrained("inkoziev/rugpt_interpreter")
|
| 47 |
+
tokenizer.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': '<pad>'})
|
| 48 |
model = AutoModelForCausalLM.from_pretrained("inkoziev/rugpt_interpreter")
|
| 49 |
model.to(device)
|
| 50 |
|
|
|
|
| 52 |
# В конце добавляем символ "#"
|
| 53 |
input_text = """<s>- Как тебя зовут?
|
| 54 |
- Джульетта Мао #"""
|
| 55 |
+
#input_text = """<s>- Что Предтечи забрали у Предшественников?
|
| 56 |
+
#- Они узурпировали у них Мантию — защиту всего живого в галактике #"""
|
| 57 |
+
|
| 58 |
+
encoded_prompt = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt").to(device)
|
| 59 |
|
| 60 |
output_sequences = model.generate(
|
| 61 |
input_ids=encoded_prompt,
|
|
|
|
| 66 |
repetition_penalty=1.2,
|
| 67 |
do_sample=True,
|
| 68 |
num_return_sequences=1,
|
| 69 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 70 |
)
|
| 71 |
|
| 72 |
+
text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)[len(input_text)+1:]
|
|
|
|
| 73 |
text = text[: text.find('</s>')]
|
|
|
|
| 74 |
print(text)
|
| 75 |
```
|