Update README.md
README.md CHANGED
@@ -160,6 +160,43 @@ prompt = "Schreibe eine Stellenanzeige für Data Scientist bei AXA!"
final_prompt = prompt_template.format(prompt=prompt)
```

#### Limit the model to reply-only output
To solve this, you need to implement a custom stopping criterion:

```python
from transformers import StoppingCriteria

class GermeoStoppingCriteria(StoppingCriteria):
    def __init__(self, target_sequence, prompt):
        self.target_sequence = target_sequence
        self.prompt = prompt

    def __call__(self, input_ids, scores, **kwargs):
        # Decode the full sequence generated so far (uses the `tokenizer` loaded during setup)
        generated_text = tokenizer.decode(input_ids[0])
        # Remove the prompt so only the newly generated reply is checked
        generated_text = generated_text.replace(self.prompt, '')
        # Check if the target sequence appears in the generated text
        if self.target_sequence in generated_text:
            return True  # Stop generation
        return False  # Continue generation

    def __len__(self):
        return 1

    def __iter__(self):
        yield self
```
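
The `__len__` and `__iter__` methods let this single criterion be passed directly where `generate` expects a `StoppingCriteriaList`. Equivalently (a sketch, not taken from this README), you can drop those two methods and wrap the criterion explicitly:

```python
from transformers import StoppingCriteriaList

# Alternative sketch: wrap the criterion in a StoppingCriteriaList, the type
# that generate() documents for `stopping_criteria`; with this wrapper the
# __len__/__iter__ methods above are not needed.
criteria = StoppingCriteriaList(
    [GermeoStoppingCriteria("<|im_end|>", final_prompt)]
)
```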
The criterion expects your input prompt (formatted exactly as it is passed to the model) and a stop sequence, in this case the `<|im_end|>` token. Simply add it to the generation call:

```python
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=1012,
    stopping_criteria=GermeoStoppingCriteria("<|im_end|>", prompt_template.format(prompt=prompt))
)
```
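
The call above relies on objects created earlier in this README (`model`, `tokenizer`, `streamer`, `prompt_template` and the tokenized `tokens`). If you are reading this section in isolation, a minimal setup sketch could look like the following; the model ID, dtype and device placement are placeholders, not values prescribed by this README:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_id = "<this-repo-id>"  # placeholder: use the Hugging Face ID of this model

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
streamer = TextStreamer(tokenizer, skip_prompt=True)

# Tokenize the formatted prompt and move it to the model's device
tokens = tokenizer(final_prompt, return_tensors="pt").input_ids.to(model.device)
```
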
### German benchmarks

| **German tasks:** | **MMLU-DE** | **Hellaswag-DE** | **ARC-DE** | **Average** |