Spaces:
Runtime error
Runtime error
j.gilyazev
committed on
Commit
·
0766044
1
Parent(s):
776e43c
add personalized-chat-bot
Browse files- personalized-chat-bot/models/__init__.py +1 -0
- personalized-chat-bot/models/personality_clustering.py +74 -0
- personalized-chat-bot/scripts/__init__.py +1 -0
- personalized-chat-bot/scripts/config_176b.json +16 -0
- personalized-chat-bot/scripts/config_6b.json +16 -0
- personalized-chat-bot/scripts/fit_personality_clustering.py +52 -0
- personalized-chat-bot/scripts/train_all.sh +11 -0
- personalized-chat-bot/scripts/train_bloom_personachat.py +123 -0
personalized-chat-bot/models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# coding=utf-8
|
personalized-chat-bot/models/personality_clustering.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pickle

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans


class PersonalityClustering:
    """K-means clustering of persona sentences in sentence-embedding space.

    Sentences are embedded with a SentenceTransformer model and clustered
    with KMeans.  After :meth:`fit`, each cluster is represented by the
    persona sentence whose embedding is closest to the cluster centroid.
    """

    DEFAULT_SENTENCE_TRANSFORMER = 'paraphrase-MiniLM-L6-v2'

    def __init__(self, n_clusters=None, device='cpu', model_name=None):
        """
        Args:
            n_clusters: number of KMeans clusters (None -> sklearn's default).
            device: device for the sentence-transformer model ('cpu'/'cuda').
            model_name: sentence-transformer checkpoint name; falls back to
                DEFAULT_SENTENCE_TRANSFORMER when None.
        """
        if model_name is None:
            self.model_name = self.DEFAULT_SENTENCE_TRANSFORMER
        else:
            self.model_name = model_name
        self.device = device
        self.n_clusters = n_clusters
        self._cluster_centers = None          # representative persona per cluster
        self.__clustering = None              # lazily-created / loaded KMeans
        self.__sentence_transformer = None    # lazily-created encoder

    @property
    def sentence_transformer(self):
        """Lazy initialization of the sentence transformer."""
        # `is None` instead of truthiness: the encoder object itself is
        # always truthy, but explicit None-checks are unambiguous.
        if self.__sentence_transformer is None:
            self.__sentence_transformer = SentenceTransformer(self.model_name, device=self.device)
        return self.__sentence_transformer

    @property
    def clustering(self):
        """Lazy initialization of the KMeans clustering model."""
        if self.__clustering is None:
            self.__clustering = KMeans(n_clusters=self.n_clusters)
        return self.__clustering

    def load(self, path):
        """Restore a fitted KMeans model and cluster-center personas from *path*.

        NOTE(review): n_clusters/model_name are not persisted; loading only
        restores the KMeans object and the representative personas.
        """
        with open(path, "rb") as f:
            self.__clustering, self._cluster_centers = pickle.load(f)

    def save(self, path):
        """Pickle the fitted KMeans model and cluster-center personas to *path*."""
        with open(path, "wb") as f:
            pickle.dump((self.__clustering, self._cluster_centers), f)

    def fit(self, personalities):
        """Cluster persona sentences and pick a representative per cluster.

        The representative of a cluster is the member sentence whose
        embedding has the smallest Euclidean distance to the centroid.

        Returns:
            self, to allow chaining.
        """
        personalities = np.array(list(personalities))
        train_embeddings = self.sentence_transformer.encode(personalities)
        clusters = self.clustering.fit_predict(train_embeddings)
        persona_cluster_centers = []
        for clust, center in enumerate(self.clustering.cluster_centers_):
            mask = clusters == clust
            cur_clust_embed = train_embeddings[mask]
            cur_clust_personalities = personalities[mask]
            if len(cur_clust_personalities) == 0:
                # Preserve legacy behavior for an (unlikely) empty cluster:
                # the original loop appended None in this case.
                persona_cluster_centers.append(None)
                continue
            # Vectorized nearest-to-centroid lookup; np.argmin keeps the
            # first minimum, matching the original strict `<` comparison.
            distances = np.linalg.norm(cur_clust_embed - center, axis=1)
            persona_cluster_centers.append(cur_clust_personalities[int(np.argmin(distances))])
        self._cluster_centers = np.array(persona_cluster_centers)
        return self

    def predict(self, personalities):
        """Return the cluster id for each input persona sentence."""
        personalities = np.array(list(personalities))
        embeddings = self.sentence_transformer.encode(personalities)
        clusters = self.clustering.predict(embeddings)
        return clusters

    def predict_nearest_personality(self, personalities):
        """Map each input persona to its cluster's representative sentence."""
        clusters = self.predict(personalities)
        return np.array([self._cluster_centers[clust] for clust in clusters])

    def fit_predict(self, personalities):
        """Fit on the given personas, then return their cluster ids."""
        self.fit(personalities)
        return self.predict(personalities)
personalized-chat-bot/scripts/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# coding=utf-8
|
personalized-chat-bot/scripts/config_176b.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"PERSONACHAT_DATASET_NAME": "bavard/personachat_truecased",
|
3 |
+
"MODEL_NAME": "bigscience/bloom-petals",
|
4 |
+
"INITIAL_PEERS": [],
|
5 |
+
"NUM_PREFIX_TOKENS": 16,
|
6 |
+
"DEVICE": "cpu",
|
7 |
+
"BATCH_SIZE": 4,
|
8 |
+
"LR": 0.01,
|
9 |
+
"WEIGHT_DECAY": 0.0,
|
10 |
+
"NUM_SAMPLES": 1000,
|
11 |
+
"SEED": 42,
|
12 |
+
"MODEL_MAX_LENGTH": 256,
|
13 |
+
"TUNING_MODE": "ptune",
|
14 |
+
"N_EPOCH": 10,
|
15 |
+
"PADDING_SIDE": "right"
|
16 |
+
}
|
personalized-chat-bot/scripts/config_6b.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"PERSONACHAT_DATASET_NAME": "bavard/personachat_truecased",
|
3 |
+
"MODEL_NAME": "bigscience/test-bloomd-6b3",
|
4 |
+
"INITIAL_PEERS":["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"],
|
5 |
+
"NUM_PREFIX_TOKENS": 16,
|
6 |
+
"DEVICE": "cpu",
|
7 |
+
"BATCH_SIZE": 4,
|
8 |
+
"LR": 0.01,
|
9 |
+
"WEIGHT_DECAY": 0.0,
|
10 |
+
"NUM_SAMPLES": 1000,
|
11 |
+
"SEED": 42,
|
12 |
+
"MODEL_MAX_LENGTH": 256,
|
13 |
+
"TUNING_MODE": "ptune",
|
14 |
+
"N_EPOCH": 1,
|
15 |
+
"PADDING_SIDE": "right"
|
16 |
+
}
|
personalized-chat-bot/scripts/fit_personality_clustering.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from datasets import load_dataset
|
3 |
+
from models.personality_clustering import PersonalityClustering
|
4 |
+
import os
|
5 |
+
|
6 |
+
"""Пример запуска
|
7 |
+
python -m scripts.fit_personality_clustering --clustering-path data/models --n-clusters 500
|
8 |
+
"""
|
9 |
+
|
10 |
+
PERSONACHAT_DATASET = "bavard/personachat_truecased"
|
11 |
+
|
12 |
+
|
def load_persona_chat_personalities(personachat_dataset):
    """Collect the unique persona sentences from a PersonaChat dataset.

    Args:
        personachat_dataset: HF dataset name, e.g. "bavard/personachat_truecased".

    Returns:
        list[str]: deduplicated persona sentences from the train and
        validation splits.  Order is unspecified (set union).
    """
    dataset = load_dataset(personachat_dataset)
    train_personalities = [sent for persona in dataset['train']['personality']
                           for sent in persona]
    # Bug fix: the original read the 'train' split twice; pull the second
    # batch from 'validation' (personachat_truecased has no 'test' split).
    test_personalities = [sent for persona in dataset['validation']['personality']
                          for sent in persona]
    personalities = list(set(train_personalities) | set(test_personalities))
    return personalities
21 |
+
|
22 |
+
|
def parse_args(args=None):
    """Parse command-line options for the clustering-fit script."""
    arg_parser = argparse.ArgumentParser(add_help=True, description="Class for personality clustering.")

    arg_parser.add_argument('-clustering-path', '--clustering-path', type=str,
                            help='Path to clustering data.')
    arg_parser.add_argument('-n-clusters', '--n-clusters', type=int, default=500,
                            help='The number of clusters to form.')
    arg_parser.add_argument('-model-name', '--model-name', type=str, default=None, required=False)
    return arg_parser.parse_args(args)
33 |
+
|
34 |
+
|
def main():
    """Fit a personality clustering model on PersonaChat and save it to disk."""
    cli_args = parse_args()
    personalities = load_persona_chat_personalities(PERSONACHAT_DATASET)
    print('Data loaded')
    clustering_model = PersonalityClustering(n_clusters=cli_args.n_clusters)
    print('Model fitting')
    clustering_model.fit(personalities)
    print('Model fitted')
    # Derive a descriptive default filename unless one was given explicitly.
    if cli_args.model_name is not None:
        model_name = cli_args.model_name
    else:
        model_name = f'personality_clustering_{clustering_model.n_clusters}_{clustering_model.model_name}_k-means.pkl'
    clustering_model.save(os.path.join(cli_args.clustering_path, model_name))
    print(f'{model_name} saved')


# Run as: python -m scripts.fit_personality_clustering --clustering-path ... --n-clusters ...
if __name__ == '__main__':
    main()
personalized-chat-bot/scripts/train_all.sh
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/bin/bash

# Sequentially trains personalized soft prompts for batches of persona ids
# with the 176B config. Earlier, already-completed batches are kept below,
# commented out, as a record of what has been run.
#python -m scripts.train_bloom_personachat --persona-ids 113 54 169 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
#python -m scripts.train_bloom_personachat --persona-ids 364 214 125 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
#python -m scripts.train_bloom_personachat --persona-ids 103 200 296 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
#python -m scripts.train_bloom_personachat --persona-ids 20 384 365 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
#python -m scripts.train_bloom_personachat --persona-ids 208 43 99 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
#python -m scripts.train_bloom_personachat --persona-ids 426 477 470 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
python -m scripts.train_bloom_personachat --persona-ids 470 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b

python -m scripts.train_bloom_personachat --persona-ids 329 402 382 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
personalized-chat-bot/scripts/train_bloom_personachat.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import torch.cuda
|
4 |
+
from datasets import load_dataset
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import transformers
|
8 |
+
from torch.utils.data import Subset
|
9 |
+
import wandb
|
10 |
+
import numpy as np
|
11 |
+
import gc
|
12 |
+
|
13 |
+
from models.personality_clustering import PersonalityClustering
|
14 |
+
from util.bloom_trainer import BloomTrainer
|
15 |
+
from util.data import PersonaChatDataset
|
16 |
+
from util.metrics import perplexity
|
17 |
+
|
18 |
+
from petals.client.remote_model import DistributedBloomForCausalLM
|
19 |
+
|
20 |
+
"""Пример запуска
|
21 |
+
python -m scripts.train_bloom_personachat --persona-ids 6 --config scripts/config.json --prompt-path data/models/
|
22 |
+
"""
|
23 |
+
|
24 |
+
DEFAULT_CLUSTERING_MODEL = './data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl'
|
25 |
+
MAX_VAL_DATA_SIZE = 4
|
26 |
+
|
27 |
+
|
def load_config(path):
    """Read a JSON training config file and expose its keys as attributes."""
    with open(path, 'r') as config_file:
        raw_text = config_file.read()
    return argparse.Namespace(**json.loads(raw_text))
32 |
+
|
33 |
+
|
def main():
    """Train a personalized soft prompt for each requested persona id.

    For every id in --persona-ids: builds train/val subsets keyed by the
    persona clustering, starts a fresh wandb run, instantiates a
    distributed BLOOM model (petals), trains a prompt with BloomTrainer,
    logs perplexity, and saves the prompt embedding to --prompt-path.
    """
    args = parse_args()
    persona_clustering = PersonalityClustering()
    persona_clustering.load(args.clustering_model_path)

    config = load_config(args.config)

    tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
    tokenizer.padding_side = config.PADDING_SIDE
    tokenizer.model_max_length = config.MODEL_MAX_LENGTH

    dataset = load_dataset(config.PERSONACHAT_DATASET_NAME)
    personachat_train_dataset = PersonaChatDataset(persona_clustering,
                                                   dataset['train'],
                                                   tokenizer)
    personachat_val_dataset = PersonaChatDataset(persona_clustering,
                                                 dataset['validation'],
                                                 tokenizer)

    for id in args.persona_ids:
        prompt_path = os.path.join(args.prompt_path, f'{id}_persona_prompt_embedding.pt')
        train_dataset = personachat_train_dataset[id]
        val_dataset = personachat_val_dataset[id]
        # If validation is too small, fall back to validating on the train
        # subset and record that the validation is not "honest".
        # NOTE(review): the literal 4 appears to mirror MAX_VAL_DATA_SIZE;
        # confirm whether it should reference that constant instead.
        honest_validation = True
        if len(val_dataset) < 4:
            val_dataset = personachat_train_dataset[id]
            honest_validation = False
        # To speed things up, cap the validation set at MAX_VAL_DATA_SIZE
        # randomly chosen samples.
        if len(val_dataset) > MAX_VAL_DATA_SIZE:
            subset_indexes = np.random.choice(len(val_dataset), MAX_VAL_DATA_SIZE, replace=False)
            val_dataset = Subset(val_dataset, subset_indexes)
        # train_dataset.shuffle()

        # One wandb run per persona id; reinit=True allows multiple runs in
        # a single process.
        wandb_run = wandb.init(
            project=args.wandb_project,
            config={
                'lr': config.LR,
                'batch_size': config.BATCH_SIZE,
                'persona_id': id,
                'device': config.DEVICE,
                'model_name': config.MODEL_NAME,
                'n_epoch': config.N_EPOCH,
                'honest_validation': honest_validation
            },
            name=f'id{id}',
            reinit=True
        )
        # With no INITIAL_PEERS configured, let petals discover peers itself;
        # otherwise bootstrap from the configured addresses.
        if len(config.INITIAL_PEERS) == 0:
            model = DistributedBloomForCausalLM.from_pretrained(
                config.MODEL_NAME,
                pre_seq_len=config.NUM_PREFIX_TOKENS,
                tuning_mode=config.TUNING_MODE
            ).to(config.DEVICE)
        else:
            model = DistributedBloomForCausalLM.from_pretrained(
                config.MODEL_NAME,
                initial_peers=config.INITIAL_PEERS,
                pre_seq_len=config.NUM_PREFIX_TOKENS,
                tuning_mode=config.TUNING_MODE
            ).to(config.DEVICE)

        trainer = BloomTrainer(model, config, train_dataset, val_dataset, wandb_run, prompt_path)
        trainer.train()
        eval_perplexity = trainer.evaluate(perplexity)
        trainer.save_model(prompt_path)
        wandb_run.log({'perplexity': eval_perplexity, 'model_path': prompt_path})

        # Free the (potentially large) model before the next persona.
        del model
        gc.collect()
        torch.cuda.empty_cache()
104 |
+
|
105 |
+
|
def parse_args(args=None):
    """Parse command-line arguments for the BLOOM PersonaChat training script."""
    cli = argparse.ArgumentParser(add_help=True,
                                  description="bloom training script")
    cli.add_argument('--persona-ids', type=int, nargs='+',
                     help='Ids of persona')
    cli.add_argument('-clustering-model-path', '--clustering-model-path', type=str,
                     default=DEFAULT_CLUSTERING_MODEL,
                     help='Path to clustering model')
    cli.add_argument('--config', type=str, help='Path to training config file')
    cli.add_argument('--prompt-path', type=str,
                     help='Path to dir with trained soft prompts')
    cli.add_argument('--wandb-project', type=str, default='test_bloom_personachat_176b_v3')
    return cli.parse_args(args)
120 |
+
|
121 |
+
|
# Script entry point: python -m scripts.train_bloom_personachat ...
if __name__ == '__main__':
    main()