Update README.md
README.md CHANGED
@@ -62,11 +62,7 @@ The following hyperparameters were used during training:
 - rougeLsum: 33.414162%
 
 ### How to use
-
-Even though the model checkpoint is small, a huge input can crash memory. Batching the inputs is advised.
-A rule of thumb is that a T4 can handle a list of up to 14-15 elements at a time, with each element having 4,000 words.
-
-Note 'max_new_tokens=60' is used in the example below to limit the summary size. The BART model has a max generation length of 142 (default) and a min generation length of 56.
+Note 'max_new_tokens=60' is used in the example below to control the summary size. The BART model has a max generation length of 142 (default) and a min generation length of 56.
 
 ```python
 import torch
@@ -85,9 +81,10 @@ tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
 model = PeftModel.from_pretrained(model, peft_model_id, device_map='auto')
 
 # Tokenize the text inputs
-texts = "<e.g.
-inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
+texts = "<e.g. Transcript>"
+inputs = tokenizer(texts, return_tensors="pt", padding=True)  # truncation=True
 
+# Make inferences
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 with torch.no_grad():
     output = model.generate(input_ids=inputs["input_ids"].to(device), max_new_tokens=60, do_sample=True, top_p=0.9)
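For context, the hunks above touch only fragments of the README's usage snippet. Reassembled, a minimal end-to-end sketch of the intended usage could look like the following; the adapter id is a placeholder, and every line not visible in the diff (the imports, the `PeftConfig` load, the decode step) is an assumption rather than the card's exact code.

```python
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

peft_model_id = "<peft-adapter-id>"  # placeholder: the adapter repo id from this card

# Load the PEFT config to find the base model, then wrap the base model with the adapter
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)

# Tokenize the text inputs (truncation=True caps inputs at the model's max length)
texts = "<e.g. Transcript>"
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)

# Make inferences; max_new_tokens=60 limits the summary size
# (BART's generation defaults are max_length=142, min_length=56)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
with torch.no_grad():
    output = model.generate(
        input_ids=inputs["input_ids"].to(device),
        attention_mask=inputs["attention_mask"].to(device),
        max_new_tokens=60,
        do_sample=True,
        top_p=0.9,
    )

print(tokenizer.batch_decode(output, skip_special_tokens=True))
```

Dropping `device_map='auto'` in favor of an explicit `.to(device)` is a simplification here; the card's own snippet uses `device_map='auto'`, which requires `accelerate`.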
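The commit removes the batching advice, but it remains useful operational guidance. Below is a minimal sketch of batched summarization under the removed rule of thumb (roughly 14-15 documents of about 4,000 words per batch on a T4), reusing `model`, `tokenizer`, and `device` from the sketch above; `batch_size`, `summarize_batch`, and `documents` are illustrative names, not part of the model card.

```python
# Split a long list of documents into fixed-size chunks so that a single
# forward pass does not exhaust GPU memory.
batch_size = 14  # illustrative; tune to your GPU

def summarize_batch(batch):
    """Summarize one batch of texts; assumes model, tokenizer, device from above."""
    inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        output = model.generate(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
            max_new_tokens=60,
            do_sample=True,
            top_p=0.9,
        )
    return tokenizer.batch_decode(output, skip_special_tokens=True)

documents = ["<transcript 1>", "<transcript 2>"]  # placeholder inputs
summaries = []
for start in range(0, len(documents), batch_size):
    summaries.extend(summarize_batch(documents[start:start + batch_size]))
```

Since padding is to the longest element in each batch, peak memory grows with both batch size and document length, which is why capping the batch size avoids the out-of-memory crashes the removed note warned about.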