---
language: zh
tags:
- VAE
- Generation
inference: False
---

# Randeng-DELLA-226M-Chinese

- Github: [Fengshenbang-LM](https://github.com/IDEA-CCNL/Fengshenbang-LM)
- Docs: [Fengshenbang-Docs](https://fengshenbang-doc.readthedocs.io/)

## Brief Introduction

A deep VAE model pretrained on the general-domain Wudao dataset. Both the encoder and the decoder use the GPT-2 architecture. The model is suited to downstream tasks such as paraphrasing (sentence rewriting), semantic transformation, and fine-grained attribute control.
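For readers new to VAEs, the sketch below shows the standard Gaussian building blocks in plain Python: the reparameterization trick and the closed-form KL term of the training objective, assuming a diagonal Gaussian posterior and a standard normal prior. It is a generic, single-latent illustration only. DELLA itself infers a separate latent variable per Transformer layer (as the paper title indicates), and the function names here are hypothetical, not part of the fengshen code.

```python
import math
import random


def reparameterize(mu, logvar, rng=random):
    """Hypothetical helper: draw z = mu + sigma * eps with eps ~ N(0, 1), elementwise.

    Writing the sample this way keeps z differentiable w.r.t. mu and logvar.
    """
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def gaussian_kl(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, logvar))


def negative_elbo(reconstruction_nll, mu, logvar, beta=1.0):
    """Training loss = reconstruction NLL + beta * KL (beta < 1 down-weights the KL term)."""
    return reconstruction_nll + beta * gaussian_kl(mu, logvar)


if __name__ == "__main__":
    mu, logvar = [0.5, -0.2], [0.0, -1.0]
    z = reparameterize(mu, logvar, random.Random(0))
    print("z =", z)
    print("KL =", gaussian_kl(mu, logvar))
```

When the posterior equals the prior (`mu = 0`, `logvar = 0`) the KL term is exactly zero, which is the "posterior collapse" state that text VAEs try to avoid.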

## Model Taxonomy

| Demand | Task | Series | Model | Parameters | Extra |
| :----: | :----: | :----: | :----: | :----: | :----: |
| General | Natural Language Generation (NLG) | Randeng | DELLA | 226M | VAE-Chinese |


## Model Information

Reference paper: [Fuse It More Deeply! A Variational Transformer with Layer-Wise Latent Variable Inference for Text Generation](https://arxiv.org/abs/2207.06130)
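The usage script exposes `--beta_kl_constraints_start` (default `1e-5`), `--beta_kl_constraints_stop` (default `1.`), and `--beta_n_cycles` (default `10`). Their names suggest a cyclical annealing schedule for the KL weight beta, a common remedy for posterior collapse in text VAEs. The sketch below assumes that interpretation; the `ratio` parameter and the linear ramp shape are assumptions, and the schedule actually implemented in Fengshenbang-LM may differ.

```python
def cyclical_beta(step, total_steps, n_cycles=10, beta_start=1e-5, beta_stop=1.0, ratio=0.5):
    """Hypothetical cyclical KL-weight schedule (assumed, not taken from fengshen).

    Training is split into n_cycles equal periods. Within each period beta rises
    linearly from beta_start to beta_stop over the first `ratio` of the period,
    then stays at beta_stop until the next period begins.
    """
    if total_steps <= 0 or n_cycles <= 0:
        raise ValueError("total_steps and n_cycles must be positive")
    period = total_steps / n_cycles
    position = (step % period) / period  # fraction of the current cycle, in [0, 1)
    if position >= ratio:
        return beta_stop
    return beta_start + (beta_stop - beta_start) * position / ratio


if __name__ == "__main__":
    # 1,000 training steps split into 10 cycles of 100 steps each
    print([round(cyclical_beta(s, 1000), 3) for s in range(0, 200, 25)])
    # → [0.0, 0.5, 1.0, 1.0, 0.0, 0.5, 1.0, 1.0]
```

Restarting beta near zero at the start of each cycle gives the encoder repeated chances to place information in the latent before the KL penalty reaches full strength.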



## Usage

```python
# Check out the latest Fengshenbang-LM repository and run the following script from its root directory.
import argparse

import torch
from torch.nn.utils.rnn import pad_sequence
from fengshen.models.deepVAE.vae_pl_module import DeepVAEModule


if __name__ == "__main__":
    # TODO: Update these paths to the downloaded model directory
    checkpoint_path = '..../Randeng-DELLA-226M-Chinese'
    gpt2_model_path = '..../Randeng-DELLA-226M-Chinese'

    args_parser = argparse.ArgumentParser()
    args_parser.add_argument("--checkpoint_path", type=str, default=checkpoint_path)
    args_parser.add_argument("--gpt2_model_path", type=str, default=gpt2_model_path)
    args_parser.add_argument("--latent_dim", type=int, default=256)
    args_parser.add_argument("--beta_kl_constraints_start", type=float, default=1e-5)
    args_parser.add_argument("--beta_kl_constraints_stop", type=float, default=1.)
    args_parser.add_argument("--beta_n_cycles", type=int, default=10)
    args_parser.add_argument("--latent_lmf_rank", type=int, default=4)
    args_parser.add_argument("--CVAE", action='store_true')
    # Note: with action='store_false', share_param defaults to True;
    # passing --share_param on the command line sets it to False.
    args_parser.add_argument("--share_param", action='store_false',
        help="specify this argument if we want to share dec's and enc's params")

    args, unknown_args = args_parser.parse_known_args()

    # Load the model and tokenizer
    model, tokenizer = DeepVAEModule.load_model(args, labels_dict=None)

    # VAE generation: encode a sentence, then decode from its latent representation.
    # Input text: "This model is a VAE pretrained on a general-purpose dataset;
    # for best results, fine-tune it on your target domain before use."
    sentence = "本模型是在通用数据集下预训练的VAE模型,如要获得最佳效果请在特定领域微调后使用。"
    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentence))
    decoder_target = [tokenizer.bos_token_id] + tokenized_text + [tokenizer.eos_token_id]
    inputs = [torch.tensor(decoder_target, dtype=torch.long)]
    # Pad the batch (here a single sentence) into a (batch, seq_len) tensor
    inputs = pad_sequence(inputs, batch_first=True, padding_value=0)

    # Decoding settings
    max_length = 256
    top_p = 0.5
    top_k = 0
    temperature = .7
    repetition_penalty = 1.0
    sample = False
    device = 0  # CUDA device index
    model = model.eval()
    model = model.to(device)

    with torch.no_grad():
        outputs = model.inference(inputs.to(device), top_p=top_p, top_k=top_k, max_length=max_length, sample=sample,
            temperature=temperature, repetition_penalty=repetition_penalty)

    for gen_sent, orig_sent in zip(outputs, inputs):
        print('orig_sent:', tokenizer.decode(orig_sent).replace(' ', ''))
        print('gen_sent:', tokenizer.decode(gen_sent).replace(' ', ''))
        print("-" * 20)
```
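The script passes `top_p`, `top_k`, `temperature`, and `sample` to `model.inference`, whose decoding logic lives in Fengshenbang-LM and is not shown here. The following is a generic, hypothetical sketch of how these settings are conventionally applied to one decoding step's logits: temperature scaling, then top-k and nucleus (top-p) filtering, then either sampling or greedy argmax. `filter_logits` and `next_token` are illustrative names, not part of the fengshen API, and repetition penalty is omitted.

```python
import math
import random


def filter_logits(logits, top_k=0, top_p=1.0, temperature=1.0):
    """Return a probability distribution after temperature, top-k and top-p filtering.

    top_k=0 disables top-k filtering; top_p=1.0 disables nucleus filtering.
    temperature must be > 0; values below 1 sharpen the distribution.
    """
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]  # subtract max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]

    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(order)
    if top_k > 0:
        keep &= set(order[:top_k])  # keep only the k most likely tokens
    if top_p < 1.0:
        cumulative, nucleus = 0.0, set()
        for i in order:  # smallest prefix whose mass reaches top_p
            nucleus.add(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        keep &= nucleus

    kept_mass = sum(probs[i] for i in keep)  # the top token is always kept, so > 0
    return [probs[i] / kept_mass if i in keep else 0.0 for i in range(len(probs))]


def next_token(logits, sample, rng=random, **filter_kwargs):
    """Pick the next token id: sample from the filtered distribution, or take the argmax."""
    probs = filter_logits(logits, **filter_kwargs)
    if not sample:
        return max(range(len(probs)), key=lambda i: probs[i])  # greedy decoding
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.0, -1.0]
    print(filter_logits(logits, top_k=2, temperature=0.7))
    print(next_token(logits, sample=False, top_p=0.5, temperature=0.7))  # → 0
```

Under this convention, `sample = False` picks the argmax, which top-k and top-p filtering never remove, so those settings would only change the output when sampling is enabled. Whether `model.inference` behaves this way is an assumption; check the Fengshenbang-LM source.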

## Citation

If you use this model in your work, please cite our [paper](https://arxiv.org/abs/2209.02970):

```text
@article{fengshenbang,
  author    = {Junjie Wang and Yuxiang Zhang and Lin Zhang and Ping Yang and Xinyu Gao and Ziwei Wu and Xiaoqun Dong and Junqing He and Jianheng Zhuo and Qi Yang and Yongfeng Huang and Xiayu Li and Yanghan Wu and Junyu Lu and Xinyu Zhu and Weifeng Chen and Ting Han and Kunhao Pan and Rui Wang and Hao Wang and Xiaojun Wu and Zhongshen Zeng and Chongpei Chen and Ruyi Gan and Jiaxing Zhang},
  title     = {Fengshenbang 1.0: Being the Foundation of Chinese Cognitive Intelligence},
  journal   = {CoRR},
  volume    = {abs/2209.02970},
  year      = {2022}
}
```

You can also cite our [website](https://github.com/IDEA-CCNL/Fengshenbang-LM/):

```text
@misc{Fengshenbang-LM,
  title={Fengshenbang-LM},
  author={IDEA-CCNL},
  year={2021},
  howpublished={\url{https://github.com/IDEA-CCNL/Fengshenbang-LM}},
}
```