---
license: apache-2.0
datasets:
- wikitext
- ptb_text_only
language:
- en
metrics:
- perplexity
pipeline_tag: text-generation
model-index:
- name: distilgpt2
  results:
  - task:
      type: text-generation
    dataset:
      name: penn_treebank
      type: ptb_text_only
    metrics:
    - name: perplexity@BASELINE
      type: dmx-perplexity
      value: 63.45857238769531
    - name: perplexity@FALLBACK
      type: dmx-perplexity
      value: 64.36720275878906
  - task:
      type: text-generation
    dataset:
      name: wikitext2
      type: wikitext-2-raw-v1
    metrics:
    - name: perplexity@BASELINE
      type: dmx-perplexity
      value: 46.05925369262695
    - name: perplexity@FALLBACK
      type: dmx-perplexity
      value: 46.570838928222656
---

This is a quantized version of [DistilGPT2](https://huggingface.co/distilbert/distilgpt2). We provide the following two quantization configurations:

- BASELINE: everything kept in the original format; equivalent to the original model.
- FALLBACK: Linear and Conv1D layers quantized to BFP16, with approximation functions added for LayerNorm, GELU, and Softmax (see the illustrative sketch below).
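
For intuition about what BFP16 does to a tensor, here is a minimal, self-contained numpy sketch of block floating point: each block of values shares a single exponent and keeps low-precision per-element mantissas. This illustrates the number format only, not dmx-mltools' actual implementation; the block size and mantissa width here are assumptions.

```python
import numpy as np

def bfp_quantize(block: np.ndarray, mantissa_bits: int = 8) -> np.ndarray:
    """Toy block floating point: one shared exponent for the whole block,
    per-element signed mantissas rounded to `mantissa_bits` bits."""
    max_mag = np.abs(block).max()
    if max_mag == 0.0:
        return np.zeros_like(block)
    # Shared exponent taken from the largest magnitude in the block.
    shared_exp = np.floor(np.log2(max_mag))
    # Scale so the largest value maps near the top of the mantissa range.
    scale = 2.0 ** (shared_exp - (mantissa_bits - 2))
    lo, hi = -(2 ** (mantissa_bits - 1)), 2 ** (mantissa_bits - 1) - 1
    mantissas = np.clip(np.round(block / scale), lo, hi)
    return mantissas * scale

x = np.random.randn(16).astype(np.float32)  # one block of 16 values (assumed size)
print(np.abs(x - bfp_quantize(x)).max())    # quantization error stays small
```

Small values sharing a block with one large outlier lose precision, which is why block size matters; the shipped FALLBACK.yaml encodes the actual configuration applied here.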

### Usage Example

Prerequisites:

- Install dmx-mltools: `pip install dmx-mltools`
- Clone this repo and `cd` into the cloned directory.

```python
>>> import os
>>> from mltools import dmx
>>> from transformers import pipeline
>>> import evaluate
>>> from datasets import load_dataset

>>> my_hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")

>>> # Build a standard transformers pipeline for the quantized checkpoint
>>> pipe = pipeline(
...     "text-generation",
...     model="d-matrix/distilgpt2",
...     use_auth_token=my_hf_token,
...     trust_remote_code=True,
...     # device_map="auto",  # uncomment to enable pipeline parallelism
... )
>>> # Wrap the model so dmx transformations can be applied to it
>>> pipe.model = dmx.Model(
...     pipe.model, monkey_patched=False, hf=True, input_names=["input_ids", "labels"]
... )

>>> # Apply the FALLBACK quantization configuration shipped with this repo
>>> pipe.model.transform("FALLBACK.yaml")

>>> # Evaluate perplexity on the Penn Treebank test split
>>> perplexity = evaluate.load("d-matrix/dmx_perplexity", module_type="metric")
>>> input_texts = load_dataset("ptb_text_only", "penn_treebank", split="test")["sentence"]
>>> results = perplexity.compute(model=pipe.model.body, references=input_texts)
>>> print(results)
{'loss': 4.164604187011719, 'perplexity': 64.36720275878906}
```
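
The wikitext2 numbers in the metadata above can be reproduced the same way by swapping the dataset. A minimal sketch reusing the `pipe` and `perplexity` objects from the example (the exact preprocessing behind the reported values may differ, e.g. whether empty lines are filtered):

```python
>>> # Score the same transformed model on the wikitext-2-raw-v1 test split
>>> wikitext = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"]
>>> wikitext = [t for t in wikitext if t.strip()]  # drop empty lines (assumption)
>>> results = perplexity.compute(model=pipe.model.body, references=wikitext)
>>> print(results["perplexity"])
```

To score the BASELINE configuration instead, pass `BASELINE.yaml` to the `transform` call (assuming it ships alongside `FALLBACK.yaml` in this repo).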