j.gilyazev commited on
Commit
0766044
·
1 Parent(s): 776e43c

add personalized-chat-bot

Browse files
personalized-chat-bot/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # coding=utf-8
personalized-chat-bot/models/personality_clustering.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from sentence_transformers import SentenceTransformer
3
+ from sklearn.cluster import KMeans
4
+ import pickle
5
+
6
+
7
+ class PersonalityClustering:
8
+ DEFAULT_SENTENCE_TRANSFORMER = 'paraphrase-MiniLM-L6-v2'
9
+
10
+ @property
11
+ def sentence_transformer(self):
12
+ """Ленивая инициализация sentence_transformer."""
13
+ if not self.__sentence_transformer:
14
+ self.__sentence_transformer = SentenceTransformer(self.model_name, device=self.device)
15
+ return self.__sentence_transformer
16
+
17
+ @property
18
+ def clustering(self):
19
+ """Ленивая инициализация кластеризации."""
20
+ if not self.__clustering:
21
+ self.__clustering = KMeans(n_clusters=self.n_clusters)
22
+ return self.__clustering
23
+
24
+ def __init__(self, n_clusters=None, device='cpu', model_name=None):
25
+ if model_name is None:
26
+ self.model_name = self.DEFAULT_SENTENCE_TRANSFORMER
27
+ else:
28
+ self.model_name = model_name
29
+ self.device = device
30
+ self.n_clusters = n_clusters
31
+ self._cluster_centers = None
32
+ self.__clustering = None
33
+ self.__sentence_transformer = None
34
+
35
+ def load(self, path):
36
+ with open(path, "rb") as f:
37
+ self.__clustering, self._cluster_centers = pickle.load(f)
38
+
39
+ def save(self, path):
40
+ with open(path, "wb") as f:
41
+ pickle.dump((self.__clustering, self._cluster_centers), f)
42
+
43
+ def fit(self, personalities):
44
+ personalities = np.array(list(personalities))
45
+ train_embeddings = self.sentence_transformer.encode(personalities)
46
+ clusters = self.clustering.fit_predict(train_embeddings)
47
+ persona_cluster_centers = []
48
+ for clust, center in enumerate(self.clustering.cluster_centers_):
49
+ cur_clust_embed = train_embeddings[clusters == clust]
50
+ cur_clust_personalities = personalities[clusters == clust]
51
+ min_distance_to_center = np.inf
52
+ persona_center = None
53
+ for embed, persona in zip(cur_clust_embed, cur_clust_personalities):
54
+ cur_distance_to_center = np.linalg.norm(embed - center)
55
+ if cur_distance_to_center < min_distance_to_center:
56
+ min_distance_to_center = cur_distance_to_center
57
+ persona_center = persona
58
+ persona_cluster_centers.append(persona_center)
59
+ self._cluster_centers = np.array(persona_cluster_centers)
60
+ return self
61
+
62
+ def predict(self, personalities):
63
+ personalities = np.array(list(personalities))
64
+ embeddings = self.sentence_transformer.encode(personalities)
65
+ clusters = self.clustering.predict(embeddings)
66
+ return clusters
67
+
68
+ def predict_nearest_personality(self, personalities):
69
+ clusters = self.predict(personalities)
70
+ return np.array([self._cluster_centers[clust] for clust in clusters])
71
+
72
+ def fit_predict(self, personalities):
73
+ self.fit(personalities)
74
+ return self.predict(personalities)
personalized-chat-bot/scripts/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # coding=utf-8
personalized-chat-bot/scripts/config_176b.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "PERSONACHAT_DATASET_NAME": "bavard/personachat_truecased",
3
+ "MODEL_NAME": "bigscience/bloom-petals",
4
+ "INITIAL_PEERS": [],
5
+ "NUM_PREFIX_TOKENS": 16,
6
+ "DEVICE": "cpu",
7
+ "BATCH_SIZE": 4,
8
+ "LR": 0.01,
9
+ "WEIGHT_DECAY": 0.0,
10
+ "NUM_SAMPLES": 1000,
11
+ "SEED": 42,
12
+ "MODEL_MAX_LENGTH": 256,
13
+ "TUNING_MODE": "ptune",
14
+ "N_EPOCH": 10,
15
+ "PADDING_SIDE": "right"
16
+ }
personalized-chat-bot/scripts/config_6b.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "PERSONACHAT_DATASET_NAME": "bavard/personachat_truecased",
3
+ "MODEL_NAME": "bigscience/test-bloomd-6b3",
4
+ "INITIAL_PEERS":["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"],
5
+ "NUM_PREFIX_TOKENS": 16,
6
+ "DEVICE": "cpu",
7
+ "BATCH_SIZE": 4,
8
+ "LR": 0.01,
9
+ "WEIGHT_DECAY": 0.0,
10
+ "NUM_SAMPLES": 1000,
11
+ "SEED": 42,
12
+ "MODEL_MAX_LENGTH": 256,
13
+ "TUNING_MODE": "ptune",
14
+ "N_EPOCH": 1,
15
+ "PADDING_SIDE": "right"
16
+ }
personalized-chat-bot/scripts/fit_personality_clustering.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from datasets import load_dataset
3
+ from models.personality_clustering import PersonalityClustering
4
+ import os
5
+
6
+ """Пример запуска
7
+ python -m scripts.fit_personality_clustering --clustering-path data/models --n-clusters 500
8
+ """
9
+
10
+ PERSONACHAT_DATASET = "bavard/personachat_truecased"
11
+
12
+
13
+ def load_persona_chat_personalities(personachat_dataset):
14
+ dataset = load_dataset(personachat_dataset)
15
+ train_personalities = [sent for persona in dataset['train']['personality']
16
+ for sent in persona]
17
+ test_personalities = [sent for persona in dataset['train']['personality']
18
+ for sent in persona]
19
+ personalities = list(set(train_personalities) | set(test_personalities))
20
+ return personalities
21
+
22
+
23
+ def parse_args(args=None):
24
+ parser = argparse.ArgumentParser(add_help=True, description="Class for personality clustering.")
25
+
26
+ parser.add_argument('-clustering-path', '--clustering-path', type=str,
27
+ help='Path to clustering data.')
28
+ parser.add_argument('-n-clusters', '--n-clusters', type=int, default=500,
29
+ help='The number of clusters to form.')
30
+ parser.add_argument('-model-name', '--model-name', type=str, default=None, required=False)
31
+ args = parser.parse_args(args)
32
+ return args
33
+
34
+
35
+ def main():
36
+ args = parse_args()
37
+ personalities = load_persona_chat_personalities(PERSONACHAT_DATASET)
38
+ print('Data loaded')
39
+ model = PersonalityClustering(n_clusters=args.n_clusters)
40
+ print('Model fitting')
41
+ model.fit(personalities)
42
+ print('Model fitted')
43
+ if args.model_name is None:
44
+ model_name = f'personality_clustering_{model.n_clusters}_{model.model_name}_k-means.pkl'
45
+ else:
46
+ model_name = args.model_name
47
+ model.save(os.path.join(args.clustering_path, model_name))
48
+ print(f'{model_name} saved')
49
+
50
+
51
+ if __name__ == '__main__':
52
+ main()
personalized-chat-bot/scripts/train_all.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ #python -m scripts.train_bloom_personachat --persona-ids 113 54 169 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
4
+ #python -m scripts.train_bloom_personachat --persona-ids 364 214 125 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
5
+ #python -m scripts.train_bloom_personachat --persona-ids 103 200 296 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
6
+ #python -m scripts.train_bloom_personachat --persona-ids 20 384 365 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
7
+ #python -m scripts.train_bloom_personachat --persona-ids 208 43 99 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
8
+ #python -m scripts.train_bloom_personachat --persona-ids 426 477 470 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
9
+ python -m scripts.train_bloom_personachat --persona-ids 470 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
10
+
11
+ python -m scripts.train_bloom_personachat --persona-ids 329 402 382 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
personalized-chat-bot/scripts/train_bloom_personachat.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch.cuda
4
+ from datasets import load_dataset
5
+ import json
6
+ import os
7
+ import transformers
8
+ from torch.utils.data import Subset
9
+ import wandb
10
+ import numpy as np
11
+ import gc
12
+
13
+ from models.personality_clustering import PersonalityClustering
14
+ from util.bloom_trainer import BloomTrainer
15
+ from util.data import PersonaChatDataset
16
+ from util.metrics import perplexity
17
+
18
+ from petals.client.remote_model import DistributedBloomForCausalLM
19
+
20
+ """Пример запуска
21
+ python -m scripts.train_bloom_personachat --persona-ids 6 --config scripts/config.json --prompt-path data/models/
22
+ """
23
+
24
+ DEFAULT_CLUSTERING_MODEL = './data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl'
25
+ MAX_VAL_DATA_SIZE = 4
26
+
27
+
28
+ def load_config(path):
29
+ with open(path, 'r') as f:
30
+ config = json.load(f)
31
+ return argparse.Namespace(**config)
32
+
33
+
34
+ def main():
35
+ args = parse_args()
36
+ persona_clustering = PersonalityClustering()
37
+ persona_clustering.load(args.clustering_model_path)
38
+
39
+ config = load_config(args.config)
40
+
41
+ tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
42
+ tokenizer.padding_side = config.PADDING_SIDE
43
+ tokenizer.model_max_length = config.MODEL_MAX_LENGTH
44
+
45
+ dataset = load_dataset(config.PERSONACHAT_DATASET_NAME)
46
+ personachat_train_dataset = PersonaChatDataset(persona_clustering,
47
+ dataset['train'],
48
+ tokenizer)
49
+ personachat_val_dataset = PersonaChatDataset(persona_clustering,
50
+ dataset['validation'],
51
+ tokenizer)
52
+
53
+ for id in args.persona_ids:
54
+ prompt_path = os.path.join(args.prompt_path, f'{id}_persona_prompt_embedding.pt')
55
+ train_dataset = personachat_train_dataset[id]
56
+ val_dataset = personachat_val_dataset[id]
57
+ honest_validation = True
58
+ if len(val_dataset) < 4:
59
+ val_dataset = personachat_train_dataset[id]
60
+ honest_validation = False
61
+ # для ускорения обрежем размер валидации до некоторой границы
62
+ if len(val_dataset) > MAX_VAL_DATA_SIZE:
63
+ subset_indexes = np.random.choice(len(val_dataset), MAX_VAL_DATA_SIZE, replace=False)
64
+ val_dataset = Subset(val_dataset, subset_indexes)
65
+ # train_dataset.shuffle()
66
+
67
+ wandb_run = wandb.init(
68
+ project=args.wandb_project,
69
+ config={
70
+ 'lr': config.LR,
71
+ 'batch_size': config.BATCH_SIZE,
72
+ 'persona_id': id,
73
+ 'device': config.DEVICE,
74
+ 'model_name': config.MODEL_NAME,
75
+ 'n_epoch': config.N_EPOCH,
76
+ 'honest_validation': honest_validation
77
+ },
78
+ name=f'id{id}',
79
+ reinit=True
80
+ )
81
+ if len(config.INITIAL_PEERS) == 0:
82
+ model = DistributedBloomForCausalLM.from_pretrained(
83
+ config.MODEL_NAME,
84
+ pre_seq_len=config.NUM_PREFIX_TOKENS,
85
+ tuning_mode=config.TUNING_MODE
86
+ ).to(config.DEVICE)
87
+ else:
88
+ model = DistributedBloomForCausalLM.from_pretrained(
89
+ config.MODEL_NAME,
90
+ initial_peers=config.INITIAL_PEERS,
91
+ pre_seq_len=config.NUM_PREFIX_TOKENS,
92
+ tuning_mode=config.TUNING_MODE
93
+ ).to(config.DEVICE)
94
+
95
+ trainer = BloomTrainer(model, config, train_dataset, val_dataset, wandb_run, prompt_path)
96
+ trainer.train()
97
+ eval_perplexity = trainer.evaluate(perplexity)
98
+ trainer.save_model(prompt_path)
99
+ wandb_run.log({'perplexity': eval_perplexity, 'model_path': prompt_path})
100
+
101
+ del model
102
+ gc.collect()
103
+ torch.cuda.empty_cache()
104
+
105
+
106
+ def parse_args(args=None):
107
+ parser = argparse.ArgumentParser(add_help=True,
108
+ description="bloom training script")
109
+ parser.add_argument('--persona-ids', type=int, nargs='+',
110
+ help='Ids of persona')
111
+ parser.add_argument('-clustering-model-path', '--clustering-model-path', type=str,
112
+ default=DEFAULT_CLUSTERING_MODEL,
113
+ help='Path to clustering model')
114
+ parser.add_argument('--config', type=str, help='Path to training config file')
115
+ parser.add_argument('--prompt-path', type=str,
116
+ help='Path to dir with trained soft prompts')
117
+ parser.add_argument('--wandb-project', type=str, default='test_bloom_personachat_176b_v3')
118
+ args = parser.parse_args(args)
119
+ return args
120
+
121
+
122
+ if __name__ == '__main__':
123
+ main()