File size: 5,563 Bytes
399674e
92582b8
 
 
 
 
 
399674e
92582b8
 
 
 
 
 
 
 
 
 
 
 
 
399674e
 
92582b8
399674e
92582b8
 
399674e
92582b8
 
 
 
399674e
92582b8
399674e
92582b8
399674e
92582b8
 
399674e
92582b8
 
 
 
 
 
399674e
92582b8
399674e
92582b8
 
399674e
92582b8
 
 
 
 
 
399674e
92582b8
 
399674e
92582b8
 
399674e
92582b8
 
 
 
 
 
 
399674e
92582b8
399674e
92582b8
 
399674e
92582b8
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
---
base_model:
- tokyotech-llm/Swallow-7b-hf
- tokyotech-llm/Swallow-7b-instruct-hf
- nitky/Superswallow-7b-v0.1
- nitky/Superswallow-7b-v0.2
- nitky/Superswallow-7b-v0.3
library_name: transformers
tags:
- merge
- moe
- lisa
license: cc-by-nc-sa-4.0
datasets:
- kunishou/amenokaku-code-instruct
- llm-jp/oasst1-21k-en
- hieunguyenminh/roleplay
- meta-math/MetaMathQA
- kunishou/jp-effective-instructions
language:
- ja
---

# Swallow-MoE-4x7B-lisa

## 概要
[tokyotech-llm/Swallow-7b-hf](https://huggingface.co/tokyotech-llm/Swallow-7b-hf)をベースに、以下の4モデルをgate_mode=randomでMoEし、その後[LISA](https://arxiv.org/abs/2403.17919)という手法でインストラクションチューニングを施したモデルです。

- [tokyotech-llm/Swallow-7b-instruct-hf](https://huggingface.co/tokyotech-llm/Swallow-7b-instruct-hf)
- [nitky/Superswallow-7b-v0.1](https://huggingface.co/nitky/Superswallow-7b-v0.1)
- [nitky/Superswallow-7b-v0.2](https://huggingface.co/nitky/Superswallow-7b-v0.2)
- [nitky/Superswallow-7b-v0.3](https://huggingface.co/nitky/Superswallow-7b-v0.3)

お試しで作ってみたものなので、性能にはあまり期待しないでください。以下にベンチマーク結果も記載しております。

**なお、この学習で使ったLISAの実装には[不具合がある可能性](https://github.com/OptimalScale/LMFlow/issues/726)が指摘されており、正常に学習できていない可能性があります。**

## データセット
以下の合計14327件のデータを学習に利用しました。プロンプトフォーマットはAlpacaを利用しています。

- [kunishou/amenokaku-code-instruct](https://huggingface.co/datasets/kunishou/amenokaku-code-instruct)の各sourceから最大100件、計1475件
- [kunishou/jp-effective-instructions](https://huggingface.co/datasets/kunishou/jp-effective-instructions)のinstructionとoutputがともに11文字以上のデータ、計5050件
- [llm-jp/oasst1-21k-en](https://huggingface.co/datasets/llm-jp/oasst1-21k-en)よりランダムな1000件(英語)
- [hieunguyenminh/roleplay](https://huggingface.co/datasets/hieunguyenminh/roleplay)よりランダムな1000件(英語)
- [meta-math/MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA)よりランダムな1000件(英語)
- [ichikara-instruction](https://liat-aip.sakura.ne.jp/wp/llm%E3%81%AE%E3%81%9F%E3%82%81%E3%81%AE%E6%97%A5%E6%9C%AC%E8%AA%9E%E3%82%A4%E3%83%B3%E3%82%B9%E3%83%88%E3%83%A9%E3%82%AF%E3%82%B7%E3%83%A7%E3%83%B3%E3%83%87%E3%83%BC%E3%82%BF%E4%BD%9C%E6%88%90/llm%E3%81%AE%E3%81%9F%E3%82%81%E3%81%AE%E6%97%A5%E6%9C%AC%E8%AA%9E%E3%82%A4%E3%83%B3%E3%82%B9%E3%83%88%E3%83%A9%E3%82%AF%E3%82%B7%E3%83%A7%E3%83%B3%E3%83%87%E3%83%BC%E3%82%BF-%E5%85%AC%E9%96%8B/)より、4802件

なお、ichikara-instructionの利用によりCC-BY-NC-SAを継承します。

## 学習の設定
主な学習パラメータは以下の通りです。なお、学習途中でのエラーのため2epochs程度しか学習できておりません。

- lisa_activated_layers: 8
- lisa_interval_steps: 13
- learning_rate: 5e-5
- num_train_epochs: 約2epochs
- batch_size: 64
- max_seq_length: 2048

## 評価
マージに利用したモデル群と本モデルの[japanese-mt-bench](https://github.com/Stability-AI/FastChat/tree/jp-stable/fastchat/llm_judge)の結果は以下の通りです。(シングルターン)

Swallow-instructよりはスコアが高く、Superswallowよりは低いという何とも言えない結果になっております。
とはいえ、少量のデータセット・たった2epochsの学習でSwallow-instructを超えられているのは一定の成果とも言えるかもしれません。

|Model|Size|Coding|Extraction|Humanities|Math|Reasoning|Roleplay|STEM|Writing|avg_score|
|---|---|---|---|---|---|---|---|---|---|---|
| Swallow-7b-instruct-hf | 7B | 2.0 | 4.6 | 5.4 | 1.7 | 2.8 | 5.0 | 5.9 | 6.9 | 4.2875 |
| Superswallow-7b-v0.1  | 7B | 2.0 | 5.1 | 7.8 | 2.1 | 3.6 | 6.2 | 7.3 | 7.5 | 5.2000 |
| Superswallow-7b-v0.2  | 7B | 2.2 | 5.8 | 6.7 | 2.5 | 4.3 | 5.5 | 6.6 | 5.8 | 4.9250 |
| Superswallow-7b-v0.3  | 7B | 2.1 | 4.6 | 8.3 | 2.1 | 5.0 | 6.3 | 7.7 | 8.9 | 5.6250 |
| **This model**  | **4x7B** | **2.0** | **3.4** | **7.5** | **1.9** | **2.6** | **5.5** | **6.3** | **7.5** | **4.5875** |

![レーダーチャート](./japanese_mt_bench.png)

同様に、jsquad(jsquad-1.1-0.3, 2-shots)、jcommonsenseqa(jcommonsenseqa-1.1-0.3, 3-shots)、jnli(jnli-1.3-0.3, 3-shots)、marc_ja(marc_ja-1.1-0.3, 3-shots)結果は以下の通りです。(jsquadは100で割り、それぞれ小数点以下第4位を四捨五入)
ここでもSwallow-instructよりはスコアが高く、Superswallowよりは低い結果になっています。なお、こちらは参考として本モデルのインストラクションチューニング前(MoEのみ)のモデルのスコアも載せてあります。

|Model|Size|jsquad(exact_match)|jcommonsenseqa(acc)|jnli(acc)|marc_ja(acc)|average|
|---|---|---|---|---|---|---|
| Swallow-7b-instruct-hf | 7B | 0.757 | 0.831 | 0.212 | 0.945 | 0.686 |
| Superswallow-7b-v0.1  | 7B | 0.441 | 0.846 | 0.374 | 0.966 | 0.657 |
| Superswallow-7b-v0.2  | 7B | 0.722 | 0.846 | 0.381 | 0.964 | 0.728 |
| Superswallow-7b-v0.3  | 7B | 0.721 | 0.850 | 0.362 | 0.964 | 0.724 |
| **This model without fine-tuning**  | **4x7B** | **0.674** | **0.809** | **0.333** | **0.952** | **0.692** |
| **This model**  | **4x7B** | **0.741** | **0.806** | **0.385** | **0.948** | **0.719** |