---
license: apache-2.0
datasets:
- emozilla/booksum-summary-analysis_gptneox-8192
- kmfoda/booksum
---

# mpt-7b-storysummarizer

This is a fine-tuned version of [mosaicml/mpt-7b-storywriter](https://huggingface.co/mosaicml/mpt-7b-storywriter) trained on [emozilla/booksum-summary-analysis_gptneox-8192](https://huggingface.co/datasets/emozilla/booksum-summary-analysis_gptneox-8192), which is adapted from [kmfoda/booksum](https://huggingface.co/datasets/kmfoda/booksum).

The training run was performed using [llm-foundry](https://github.com/mosaicml/llm-foundry) on an 8xA100 80 GB node at a context length of 8192. The run can be viewed on [wandb](https://wandb.ai/emozilla/booksum/runs/457ym4r9).

## How to Use

This model is intended for summarization and literary analysis of fiction stories. It can be prompted in one of two ways:

```
SOME_FICTION

### SUMMARY:
```

or

```
SOME_FICTION

### ANALYSIS:
```
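
In code, a prompt can be assembled along these lines (the `make_prompt` helper and the exact newline placement are illustrative assumptions, not a required format):

```python
def make_prompt(text: str, mode: str = "SUMMARY") -> str:
    # mode is "SUMMARY" or "ANALYSIS", matching the prompt headers above
    return f"{text}\n\n### {mode}:"
```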

A `repetition_penalty` of ~1.04 seems to work best. For summary prompts, simple greedy decoding suffices, while sampling at a temperature of 0.8 works well for analysis.

The model often prints `#` to delineate the end of a summary or analysis. You can use a custom `transformers.StoppingCriteria` subclass to end generation there:

```python
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop once the most recently generated token is one of the stop ids
        for stop_id in self.stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

# Assumes `tokenizer` has been loaded (see the loading example below)
stop_ids = tokenizer("#").input_ids
stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
```

Pass `stopping_criteria` as an argument to the model's `generate` function to stop on `#`.

The code for this model includes adaptations from [Birchlabs/mosaicml-mpt-7b-chat-qlora](https://huggingface.co/Birchlabs/mosaicml-mpt-7b-chat-qlora) which allow MPT models to be loaded with `device_map="auto"` and `load_in_8bit=True`. For longer contexts, the following is recommended:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("emozilla/mpt-7b-storysummarizer")
model = AutoModelForCausalLM.from_pretrained(
    "emozilla/mpt-7b-storysummarizer",
    load_in_8bit=True,
    trust_remote_code=True,
    device_map="auto")
```
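
As a closing sketch, here is how the pieces above might fit together for a summary prompt; `model` and `tokenizer` come from the loading example, `stopping_criteria` from the earlier snippet, and the story text and `max_new_tokens` value are placeholder assumptions:

```python
story_text = "..."  # your fiction goes here

prompt = f"{story_text}\n\n### SUMMARY:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Greedy decoding with the suggested repetition penalty; for an analysis
# prompt, use "### ANALYSIS:" and sample with temperature=0.8 instead.
output = model.generate(
    **inputs,
    max_new_tokens=512,
    repetition_penalty=1.04,
    stopping_criteria=stopping_criteria)

# Decode only the newly generated tokens
print(tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))
```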