lyangas
commited on
Commit
·
6931ba0
1
Parent(s):
4c9b947
missed files
Browse files- helpers/__init__.py +0 -0
- helpers/data_processor.py +180 -0
- helpers/firebase.py +148 -0
- helpers/gcloud.py +98 -0
- helpers/required_classes.py +177 -0
- helpers/trainer_classifiers.py +240 -0
- helpers/trainer_embedder.py +58 -0
helpers/__init__.py
ADDED
File without changes
|
helpers/data_processor.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
from sklearn.model_selection import train_test_split
|
4 |
+
from random import choices
|
5 |
+
|
6 |
+
|
7 |
+
def log(*args):
|
8 |
+
print(*args, flush=True)
|
9 |
+
|
10 |
+
def create_group(code):
|
11 |
+
"""
|
12 |
+
Creating group column, transforming an input string
|
13 |
+
Parameters:
|
14 |
+
code (str): string with ICD-10 code name
|
15 |
+
Returns:
|
16 |
+
group(str): string with ICD-10 group name
|
17 |
+
"""
|
18 |
+
|
19 |
+
group = code.split('.')[0]
|
20 |
+
return group
|
21 |
+
|
22 |
+
def df_creation(texts, labels,
|
23 |
+
all_classes, prompt_column_name,
|
24 |
+
code_column_name, group_column_name):
|
25 |
+
"""
|
26 |
+
Creates a DataFrame from medical reports, their corresponding ICD-10 codes, and class information.
|
27 |
+
|
28 |
+
Parameters:
|
29 |
+
texts (List[str]): A list of strings, where each string is a medical report.
|
30 |
+
labels (List[str]): A list of strings, where each string is an ICD-10 code name
|
31 |
+
relevant to the corresponding text in 'texts'.
|
32 |
+
all_classes (List[str]): A list of all ICD-10 code names from the initial dataset.
|
33 |
+
prompt_column_name (str): The column name in the DataFrame for the prompts.
|
34 |
+
code_column_name (str): The column name in the DataFrame for the codes.
|
35 |
+
group_column_name (str): The column name in the DataFrame for the groups.
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
pandas.DataFrame: A DataFrame where each row contains the text of the report,
|
39 |
+
its corresponding ICD-10 code, and the group category derived
|
40 |
+
from the code.
|
41 |
+
"""
|
42 |
+
|
43 |
+
df = pd.DataFrame()
|
44 |
+
df[prompt_column_name] = texts
|
45 |
+
df[code_column_name] = [all_classes[c] for c in labels]
|
46 |
+
df[group_column_name] = [all_classes[c].split('.')[0] for c in labels]
|
47 |
+
return df
|
48 |
+
|
49 |
+
def select_random_rows(df_test, balance_column, random_n):
|
50 |
+
"""
|
51 |
+
Selects a random, balanced subset of rows from a DataFrame based on a specified column.
|
52 |
+
|
53 |
+
This function aims to create a balanced DataFrame by randomly selecting a specified number of rows
|
54 |
+
from each unique value in the balance column. It's particularly useful in scenarios where you
|
55 |
+
need a balanced sample from a dataset for testing or validation purposes.
|
56 |
+
|
57 |
+
Parameters:
|
58 |
+
df_test (pandas.DataFrame): The DataFrame to select rows from.
|
59 |
+
balance_column (str): The name of the column used to balance the data. The function will
|
60 |
+
select rows such that each unique value in this column is equally represented.
|
61 |
+
random_n (int): The number of rows to select for each unique value in the balance column.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
pandas.DataFrame: A new DataFrame containing a balanced, random subset of rows.
|
65 |
+
"""
|
66 |
+
|
67 |
+
classes = df_test[balance_column].unique()
|
68 |
+
balanced_data = []
|
69 |
+
for class_name in classes:
|
70 |
+
balanced_data += choices(df_test[df_test[balance_column]==class_name].to_dict('records'), k=random_n)
|
71 |
+
|
72 |
+
df = pd.DataFrame(balanced_data)
|
73 |
+
return df
|
74 |
+
|
75 |
+
def extract_valuable_data(path_to_raw_csv, prompt_column_name,
|
76 |
+
code_column_name, path_to_processed_csv, min_text_len, min_samples_per_cls):
|
77 |
+
"""
|
78 |
+
Extracts and processes valuable data from a raw CSV file based on specified criteria.
|
79 |
+
|
80 |
+
This function loads data from a CSV file, filters out rows based on non-null values in specified columns,
|
81 |
+
removes codes with a low number of associated prompts, filters for prompt length, creates a new 'group'
|
82 |
+
column, and saves the processed data to a new CSV file.
|
83 |
+
|
84 |
+
Parameters:
|
85 |
+
path_to_raw_csv (str): The file path to the raw CSV data file.
|
86 |
+
prompt_column_name (str): The column name in the CSV file for prompts.
|
87 |
+
code_column_name (str): The column name in the CSV file for codes.
|
88 |
+
path_to_processed_csv (str): The file path where the processed CSV data will be saved.
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
pandas.DataFrame: A DataFrame containing the processed dataset.
|
92 |
+
"""
|
93 |
+
|
94 |
+
df = pd.read_csv(path_to_raw_csv)
|
95 |
+
log(path_to_raw_csv, prompt_column_name, code_column_name, path_to_processed_csv, min_text_len, min_samples_per_cls)
|
96 |
+
|
97 |
+
df = df[df[prompt_column_name].notna() & df[code_column_name].notna()]
|
98 |
+
log(f"New data is loaded. New data has {len(df)} reports.")
|
99 |
+
log(f"New data contains {len(df['code'].unique())} unique codes.")
|
100 |
+
|
101 |
+
# Leave data for codes where more than min_samples_per_cls prompts.
|
102 |
+
unique_values = df['code'].value_counts()
|
103 |
+
values_to_remove = unique_values[unique_values <= min_samples_per_cls].index
|
104 |
+
df = df[~df['code'].isin(values_to_remove)]
|
105 |
+
|
106 |
+
# leave prompts that are longer that min_text_len characters
|
107 |
+
df = df[df[prompt_column_name].str.len() >= min_text_len]
|
108 |
+
|
109 |
+
# Creating GROUP column in dataset
|
110 |
+
df['group'] = df['code'].apply(create_group)
|
111 |
+
|
112 |
+
log(f"New data is processed. Processed data has {len(df)} reports.")
|
113 |
+
log(f"Processed dataset contains {len(df['code'].unique())} codes.")
|
114 |
+
log(f"Processed dataset contains {len(df['group'].unique())} groups.")
|
115 |
+
|
116 |
+
# Saving processed dataset
|
117 |
+
df.to_csv(path_to_processed_csv, index=False)
|
118 |
+
log(f"Processed dataset is saved to {path_to_processed_csv}.")
|
119 |
+
return df
|
120 |
+
|
121 |
+
|
122 |
+
def balance_data(df, prompt_column_name, code_column_name,
|
123 |
+
group_column_name,random_n, test_size, path_to_train_csv,
|
124 |
+
path_to_csv_test_codes, path_to_csv_test_groups):
|
125 |
+
"""
|
126 |
+
Balances and splits a dataset into training and test sets, then saves these sets to CSV files.
|
127 |
+
|
128 |
+
This function takes a DataFrame and performs stratified splitting based on the specified 'code_column_name'
|
129 |
+
to create balanced training and test datasets. It then saves the training dataset and two versions of
|
130 |
+
the test dataset (one for codes and one for groups) to separate CSV files.
|
131 |
+
|
132 |
+
Parameters:
|
133 |
+
df (pandas.DataFrame): The DataFrame to be processed and split.
|
134 |
+
prompt_column_name (str): The column name in the DataFrame for the prompts.
|
135 |
+
code_column_name (str): The column name in the DataFrame for the codes.
|
136 |
+
group_column_name (str): The column name in the DataFrame for the groups.
|
137 |
+
random_n (int): The number of rows to be randomly selected in test datasets for each unique code or group.
|
138 |
+
test_size (float): The proportion of the dataset to include in the test split.
|
139 |
+
path_to_train_csv (str): The file path where the training dataset CSV will be saved.
|
140 |
+
path_to_csv_test_codes (str): The file path where the test dataset for codes CSV will be saved.
|
141 |
+
path_to_csv_test_groups (str): The file path where the test dataset for groups CSV will be saved.
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
None
|
145 |
+
"""
|
146 |
+
|
147 |
+
texts = np.array(df[prompt_column_name])
|
148 |
+
labels = np.array(df[code_column_name])
|
149 |
+
groups = np.array(df[group_column_name])
|
150 |
+
|
151 |
+
all_classes = np.unique(labels).tolist()
|
152 |
+
labels = [all_classes.index(l) for l in labels]
|
153 |
+
log('='*50)
|
154 |
+
log(f"texts={len(texts)} labels={len(labels)} uniq_labels={len(np.unique(labels))} test_size={test_size}")
|
155 |
+
log('='*50)
|
156 |
+
texts_train, texts_test, labels_train, labels_test = train_test_split(
|
157 |
+
texts, labels, test_size=test_size, random_state=42, stratify=labels
|
158 |
+
)
|
159 |
+
|
160 |
+
log(f"Train dataset len={len(texts_train)}")
|
161 |
+
log(f"Test dataset len={len(texts_test)}")
|
162 |
+
log(f"Count of classes={len(np.unique(labels))}")
|
163 |
+
|
164 |
+
# Creating TRAIN and TEST dataset
|
165 |
+
df_train = df_creation(texts_train, labels_train, all_classes,
|
166 |
+
prompt_column_name, code_column_name, group_column_name)
|
167 |
+
df_train.to_csv(path_to_train_csv, index=False)
|
168 |
+
log(f"TRAIN dataset is saved to {path_to_train_csv}")
|
169 |
+
|
170 |
+
# Creating test datasets for codes and groups
|
171 |
+
df_test = df_creation(texts_test, labels_test, all_classes,
|
172 |
+
prompt_column_name, code_column_name, group_column_name)
|
173 |
+
|
174 |
+
df_test_codes = df_test # select_random_rows(df_test, code_column_name, random_n)
|
175 |
+
df_test_codes.to_csv(path_to_csv_test_codes, index=False)
|
176 |
+
log(f"TEST dataset for codes is saved to {path_to_csv_test_codes}")
|
177 |
+
|
178 |
+
df_test_groups = df_test # select_random_rows(df_test, group_column_name, random_n)
|
179 |
+
df_test_groups.to_csv(path_to_csv_test_groups, index=False)
|
180 |
+
log(f"TEST dataset for groups is saved to {path_to_csv_test_groups}")
|
helpers/firebase.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import firebase_admin
|
2 |
+
from firebase_admin import credentials
|
3 |
+
from firebase_admin import firestore
|
4 |
+
|
5 |
+
|
6 |
+
class FirebaseClient:
|
7 |
+
def __init__(self, path_to_certificate):
|
8 |
+
# Initialize Firebase Admin SDK
|
9 |
+
cred = credentials.Certificate(path_to_certificate) # Path to your service account key JSON file
|
10 |
+
firebase_admin.initialize_app(cred)
|
11 |
+
|
12 |
+
# Initialize Firestore database
|
13 |
+
self.db = firestore.client()
|
14 |
+
|
15 |
+
def add_task(self, task_data):
|
16 |
+
"""
|
17 |
+
Add a new task to Firestore.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
task_data (dict): Dictionary containing task data.
|
21 |
+
Example: {'title': 'Task Title', 'description': 'Task Description', 'status': 'pending'}
|
22 |
+
"""
|
23 |
+
# Add task data to Firestore
|
24 |
+
doc_ref = self.db.collection('tasks').document()
|
25 |
+
doc_ref.set(task_data)
|
26 |
+
return doc_ref.id
|
27 |
+
|
28 |
+
def get_task_by_status(self, status):
|
29 |
+
# Reference to the tasks collection
|
30 |
+
tasks_ref = self.db.collection('tasks')
|
31 |
+
|
32 |
+
# Query tasks with status 'pending'
|
33 |
+
query = tasks_ref.where('status', '==', status)
|
34 |
+
|
35 |
+
# Get documents that match the query
|
36 |
+
pending_tasks = query.stream()
|
37 |
+
|
38 |
+
# Convert documents to dictionaries
|
39 |
+
pending_tasks_data = []
|
40 |
+
for doc in pending_tasks:
|
41 |
+
task_data = doc.to_dict()
|
42 |
+
task_data['id'] = doc.id
|
43 |
+
pending_tasks_data.append(task_data)
|
44 |
+
|
45 |
+
return pending_tasks_data
|
46 |
+
|
47 |
+
def get_all_tasks(self):
|
48 |
+
"""
|
49 |
+
Retrieve all tasks from Firestore.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
list: A list containing dictionaries, each representing a task.
|
53 |
+
"""
|
54 |
+
# Reference to the 'tasks' collection
|
55 |
+
tasks_ref = self.db.collection('tasks')
|
56 |
+
|
57 |
+
# Get all documents in the collection
|
58 |
+
docs = tasks_ref.stream()
|
59 |
+
|
60 |
+
# Initialize an empty list to store tasks
|
61 |
+
tasks = []
|
62 |
+
|
63 |
+
# Iterate over each document and add it to the tasks list
|
64 |
+
for doc in docs:
|
65 |
+
doc_dict = doc.to_dict()
|
66 |
+
doc_dict['id'] = doc.id
|
67 |
+
tasks.append(doc_dict)
|
68 |
+
|
69 |
+
return tasks
|
70 |
+
|
71 |
+
def update(self, task_id, data):
|
72 |
+
"""
|
73 |
+
Reserve a task by a worker.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
task_id (str): ID of the task to be reserved.
|
77 |
+
worker_id (str): ID of the worker reserving the task.
|
78 |
+
"""
|
79 |
+
# Reference to the task document
|
80 |
+
task_ref = self.db.collection('tasks').document(task_id)
|
81 |
+
|
82 |
+
# Update the task document to indicate it has been reserved by the worker
|
83 |
+
task_ref.update(data)
|
84 |
+
|
85 |
+
def delete_task(self, task_id):
|
86 |
+
"""
|
87 |
+
Delete a task from Firestore by its ID.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
task_id (str): ID of the task to be deleted.
|
91 |
+
"""
|
92 |
+
# Reference to the task document
|
93 |
+
task_ref = self.db.collection('tasks').document(task_id)
|
94 |
+
|
95 |
+
# Delete the task document
|
96 |
+
task_ref.delete()
|
97 |
+
|
98 |
+
def get_task_by_id(self, task_id):
|
99 |
+
"""
|
100 |
+
Retrieve a task from Firestore by its ID.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
task_id (str): ID of the task to be retrieved.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
dict or None: Dictionary containing the task data if found, None otherwise.
|
107 |
+
"""
|
108 |
+
# Reference to the task document
|
109 |
+
task_ref = self.db.collection('tasks').document(task_id)
|
110 |
+
|
111 |
+
# Retrieve the task document
|
112 |
+
task_doc = task_ref.get()
|
113 |
+
|
114 |
+
# Check if the task document exists
|
115 |
+
if task_doc.exists:
|
116 |
+
return task_doc.to_dict()
|
117 |
+
else:
|
118 |
+
return None
|
119 |
+
|
120 |
+
def find_tasks_by_status(self, status):
|
121 |
+
"""
|
122 |
+
Find all tasks in Firestore with the specified status.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
status (str): Status value to filter tasks by.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
list: List of dictionaries containing task data.
|
129 |
+
"""
|
130 |
+
# Reference to the 'tasks' collection
|
131 |
+
tasks_ref = self.db.collection('tasks')
|
132 |
+
|
133 |
+
# Query tasks with the specified status
|
134 |
+
query = tasks_ref.where('status', '==', status)
|
135 |
+
|
136 |
+
# Get documents that match the query
|
137 |
+
docs = query.stream()
|
138 |
+
|
139 |
+
# Initialize an empty list to store tasks
|
140 |
+
tasks = []
|
141 |
+
|
142 |
+
# Iterate over each document and add it to the tasks list
|
143 |
+
for doc in docs:
|
144 |
+
task = doc.to_dict()
|
145 |
+
task['id'] = doc.id
|
146 |
+
tasks.append(task)
|
147 |
+
|
148 |
+
return tasks
|
helpers/gcloud.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from google.cloud import storage
|
3 |
+
from tqdm import tqdm
|
4 |
+
from googleapiclient import discovery
|
5 |
+
import requests
|
6 |
+
|
7 |
+
|
8 |
+
service = discovery.build('compute', 'v1')
|
9 |
+
storage_client = storage.Client()
|
10 |
+
|
11 |
+
def download_csv_from_gcloud(bucket_name, object_name, destination_file_path):
|
12 |
+
"""Download a file from Google Cloud Storage."""
|
13 |
+
|
14 |
+
bucket = storage_client.bucket(bucket_name)
|
15 |
+
blob = bucket.blob(object_name)
|
16 |
+
|
17 |
+
# Download the file to a local path
|
18 |
+
blob.download_to_filename(destination_file_path)
|
19 |
+
print(f"File {object_name} downloaded to {destination_file_path}")
|
20 |
+
|
21 |
+
def upload_folder_to_gcloud(bucket_name, source_folder_path, destination_folder_name):
|
22 |
+
"""Uploads all files in a folder to the Google Cloud Storage bucket."""
|
23 |
+
# Instantiates a client
|
24 |
+
# storage_client = storage.Client()
|
25 |
+
|
26 |
+
# Gets the bucket
|
27 |
+
print(f"bucket_name={bucket_name}, source_folder_path={source_folder_path}, destination_folder_name={destination_folder_name}", flush=True)
|
28 |
+
bucket = storage_client.bucket(bucket_name)
|
29 |
+
|
30 |
+
# Walk through the folder and upload each file
|
31 |
+
for root, _, files in os.walk(source_folder_path):
|
32 |
+
for file_name in files:
|
33 |
+
# Construct the local file path
|
34 |
+
local_file_path = os.path.join(root, file_name)
|
35 |
+
|
36 |
+
# Construct the destination blob name
|
37 |
+
destination_blob_name = os.path.join(destination_folder_name, os.path.relpath(local_file_path, source_folder_path))
|
38 |
+
print(f"destination_blob_name={destination_blob_name}")
|
39 |
+
# Upload the file
|
40 |
+
blob = bucket.blob(destination_blob_name)
|
41 |
+
blob.upload_from_filename(local_file_path)
|
42 |
+
|
43 |
+
print(f"File {local_file_path} uploaded to {destination_blob_name}.")
|
44 |
+
|
45 |
+
|
46 |
+
def download_folder(bucket_name, folder_name, destination_directory):
|
47 |
+
"""
|
48 |
+
Download the contents of a folder from a Google Cloud Storage bucket to a local directory.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
bucket_name (str): Name of the Google Cloud Storage bucket.
|
52 |
+
folder_name (str): Name of the folder in the bucket to download.
|
53 |
+
destination_directory (str): Local directory to save the downloaded files.
|
54 |
+
"""
|
55 |
+
|
56 |
+
# Get the bucket
|
57 |
+
bucket = storage_client.get_bucket(bucket_name)
|
58 |
+
|
59 |
+
# List objects in the folder
|
60 |
+
blobs = bucket.list_blobs(prefix=folder_name)
|
61 |
+
|
62 |
+
# Ensure destination directory exists
|
63 |
+
os.makedirs(destination_directory, exist_ok=True)
|
64 |
+
|
65 |
+
# Iterate over each object in the folder
|
66 |
+
for blob in tqdm(blobs, desc=f'Downloading {folder_name}'):
|
67 |
+
# Determine local file path
|
68 |
+
local_file_path = os.path.join(destination_directory, os.path.relpath(blob.name, folder_name))
|
69 |
+
|
70 |
+
# Ensure local directory exists
|
71 |
+
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
|
72 |
+
|
73 |
+
# Download the object to a local file
|
74 |
+
blob.download_to_filename(local_file_path)
|
75 |
+
|
76 |
+
|
77 |
+
def start_vm(project, zone, instance):
|
78 |
+
request = service.instances().start(project=project, zone=zone, instance=instance)
|
79 |
+
response = request.execute()
|
80 |
+
return response
|
81 |
+
|
82 |
+
def stop_vm(project, zone, instance):
|
83 |
+
request = service.instances().stop(project=project, zone=zone, instance=instance)
|
84 |
+
response = request.execute()
|
85 |
+
return response
|
86 |
+
|
87 |
+
def get_current_instance_name():
|
88 |
+
# URL for the metadata server
|
89 |
+
METADATA_URL = "http://metadata.google.internal/computeMetadata/v1/instance/name"
|
90 |
+
HEADERS = {"Metadata-Flavor": "Google"}
|
91 |
+
try:
|
92 |
+
response = requests.get(METADATA_URL, headers=HEADERS)
|
93 |
+
response.raise_for_status() # Raise an error for bad status codes
|
94 |
+
instance_name = response.text
|
95 |
+
return instance_name
|
96 |
+
except requests.exceptions.RequestException as e:
|
97 |
+
print(f"Error fetching instance name: {e}")
|
98 |
+
return None
|
helpers/required_classes.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import List
|
3 |
+
import pandas as pd
|
4 |
+
import torch
|
5 |
+
import xgboost as xgb
|
6 |
+
from transformers import AutoTokenizer, BertForSequenceClassification
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
|
10 |
+
class BertEmbedder:
|
11 |
+
def __init__(self, tokenizer_path:str, model_path:str, cut_head:bool=False):
|
12 |
+
"""
|
13 |
+
cut_head = True if the model have classifier head
|
14 |
+
"""
|
15 |
+
self.embedder = BertForSequenceClassification.from_pretrained(model_path)
|
16 |
+
self.max_length = self.embedder.config.max_position_embeddings
|
17 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, max_length=self.max_length)
|
18 |
+
|
19 |
+
if cut_head:
|
20 |
+
self.embedder = self.embedder.bert
|
21 |
+
|
22 |
+
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
23 |
+
print(f"Used device for BERT: {self.device }", flush=True)
|
24 |
+
self.embedder.to(self.device)
|
25 |
+
|
26 |
+
def __call__(self, text: str):
|
27 |
+
encoded_input = self.tokenizer(text,
|
28 |
+
return_tensors='pt',
|
29 |
+
max_length=self.max_length,
|
30 |
+
padding=True,
|
31 |
+
truncation=True).to(self.device)
|
32 |
+
model_output = self.embedder(**encoded_input)
|
33 |
+
text_embed = model_output.pooler_output[0].cpu()
|
34 |
+
return text_embed
|
35 |
+
|
36 |
+
def batch_predict(self, texts: List[str]):
|
37 |
+
encoded_input = self.tokenizer(texts,
|
38 |
+
return_tensors='pt',
|
39 |
+
max_length=self.max_length,
|
40 |
+
padding=True,
|
41 |
+
truncation=True).to(self.device)
|
42 |
+
model_output = self.embedder(**encoded_input)
|
43 |
+
texts_embeds = model_output.pooler_output.cpu()
|
44 |
+
return texts_embeds
|
45 |
+
|
46 |
+
|
47 |
+
class PredictModel:
|
48 |
+
def __init__(self, embedder, classifier_code, classifier_group, batch_size=8):
|
49 |
+
self.batch_size = batch_size
|
50 |
+
self.embedder = embedder
|
51 |
+
self.classifier_code = classifier_code
|
52 |
+
self.classifier_group = classifier_group
|
53 |
+
|
54 |
+
def _texts2vecs(self, texts, logging=False):
|
55 |
+
embeds = []
|
56 |
+
batches_texts = np.array_split(texts, len(texts) // self.batch_size)
|
57 |
+
if logging:
|
58 |
+
iterator = tqdm(batches_texts)
|
59 |
+
else:
|
60 |
+
iterator = batches_texts
|
61 |
+
for batch_texts in iterator:
|
62 |
+
batch_texts = batch_texts.tolist()
|
63 |
+
embeds += self.embedder.batch_predict(batch_texts).tolist()
|
64 |
+
embeds = np.array(embeds)
|
65 |
+
return embeds
|
66 |
+
|
67 |
+
def fit(self, texts: List[str], labels: List[str], logging: bool=False):
|
68 |
+
if logging:
|
69 |
+
print('Start text2vec transform')
|
70 |
+
embeds = self._texts2vecs(texts, logging)
|
71 |
+
if logging:
|
72 |
+
print('Start codes-classifier fitting')
|
73 |
+
self.classifier_code.fit(embeds, labels)
|
74 |
+
labels = [l.split('.')[0] for l in labels]
|
75 |
+
if logging:
|
76 |
+
print('Start groups-classifier fitting')
|
77 |
+
self.classifier_group.fit(embeds, labels)
|
78 |
+
|
79 |
+
def predict_code(self, texts: List[str], log: bool=False):
|
80 |
+
if log:
|
81 |
+
print('Start text2vec transform')
|
82 |
+
embeds = self._texts2vecs(texts, log)
|
83 |
+
if log:
|
84 |
+
print('Start classifier prediction')
|
85 |
+
prediction = self.classifier_code.predict(embeds)
|
86 |
+
return prediction
|
87 |
+
|
88 |
+
def predict_group(self, texts: List[str], logging: bool=False):
|
89 |
+
if logging:
|
90 |
+
print('Start text2vec transform')
|
91 |
+
embeds = self._texts2vecs(texts, logging)
|
92 |
+
if logging:
|
93 |
+
print('Start classifier prediction')
|
94 |
+
prediction = self.classifier_group.predict(embeds)
|
95 |
+
return prediction
|
96 |
+
|
97 |
+
class CustomXGBoost:
|
98 |
+
def __init__(self, use_gpu):
|
99 |
+
if use_gpu:
|
100 |
+
self.model = xgb.XGBClassifier(tree_method="gpu_hist")
|
101 |
+
else:
|
102 |
+
self.model = xgb.XGBClassifier()
|
103 |
+
self.classes_ = None
|
104 |
+
|
105 |
+
def fit(self, X, y, **kwargs):
|
106 |
+
self.classes_ = np.unique(y).tolist()
|
107 |
+
y = [self.classes_.index(l) for l in y]
|
108 |
+
self.model.fit(X, y, **kwargs)
|
109 |
+
|
110 |
+
def predict_proba(self, X):
|
111 |
+
pred = self.model.predict_proba(X)
|
112 |
+
return pred
|
113 |
+
|
114 |
+
def predict(self, X):
|
115 |
+
preds = self.model.predict_proba(X)
|
116 |
+
return np.array([self.classes_[p] for p in np.argmax(preds, axis=1)])
|
117 |
+
|
118 |
+
class SimpleModel:
|
119 |
+
def __init__(self):
|
120 |
+
self.classes_ = None
|
121 |
+
|
122 |
+
def fit(self, X, y):
|
123 |
+
print(y[0])
|
124 |
+
self.classes_ = [y[0]]
|
125 |
+
|
126 |
+
def predict_proba(self, X):
|
127 |
+
return np.array([[1.0]] * len(X))
|
128 |
+
|
129 |
+
def balance_dataset(labels_train_for_group, vecs_train_for_group, balance=None, logging=True):
|
130 |
+
if balance == 'remove':
|
131 |
+
min_len = -1
|
132 |
+
for code_l in np.unique(labels_train_for_group):
|
133 |
+
cur_len = sum(labels_train_for_group==code_l)
|
134 |
+
if logging:
|
135 |
+
print(code_l, cur_len)
|
136 |
+
if min_len > cur_len or min_len==-1:
|
137 |
+
min_len = cur_len
|
138 |
+
if logging:
|
139 |
+
print('min_len is', min_len)
|
140 |
+
df_train_group = pd.DataFrame()
|
141 |
+
df_train_group['labels'] = labels_train_for_group
|
142 |
+
df_train_group['vecs'] = vecs_train_for_group.tolist()
|
143 |
+
df_train_group = df_train_group.groupby('labels', as_index=False).apply(lambda array: array.loc[np.random.choice(array.index, min_len, False),:])
|
144 |
+
labels_train_for_group = df_train_group['labels'].values
|
145 |
+
vecs_train_for_group = [np.array(v) for v in df_train_group['vecs'].values]
|
146 |
+
|
147 |
+
elif balance == 'duplicate':
|
148 |
+
df_train_group = pd.DataFrame()
|
149 |
+
df_train_group['labels'] = labels_train_for_group
|
150 |
+
df_train_group['vecs'] = vecs_train_for_group.tolist()
|
151 |
+
max_len = 0
|
152 |
+
for code_data in df_train_group.groupby('labels'):
|
153 |
+
cur_len = len(code_data[1])
|
154 |
+
if logging:
|
155 |
+
print(code_data[0], cur_len)
|
156 |
+
if max_len < cur_len:
|
157 |
+
max_len = cur_len
|
158 |
+
if logging:
|
159 |
+
print('max_len is ', max_len)
|
160 |
+
labels_train_for_group = []
|
161 |
+
vecs_train_for_group = []
|
162 |
+
for code_data in df_train_group.groupby('labels'):
|
163 |
+
cur_len = len(code_data[1])
|
164 |
+
cur_labels = code_data[1]['labels'].values.tolist()
|
165 |
+
cur_vecs = code_data[1]['vecs'].values.tolist()
|
166 |
+
while cur_len < max_len:
|
167 |
+
cur_len *= 2
|
168 |
+
cur_labels += cur_labels
|
169 |
+
cur_vecs += cur_vecs
|
170 |
+
cur_labels = cur_labels[:max_len]
|
171 |
+
cur_vecs = cur_vecs[:max_len]
|
172 |
+
labels_train_for_group += cur_labels
|
173 |
+
vecs_train_for_group += cur_vecs
|
174 |
+
|
175 |
+
labels_train_for_group = np.array(labels_train_for_group)
|
176 |
+
vecs_train_for_group = np.array(vecs_train_for_group)
|
177 |
+
return labels_train_for_group, vecs_train_for_group
|
helpers/trainer_classifiers.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from sklearn.metrics import accuracy_score, fbeta_score, confusion_matrix, ConfusionMatrixDisplay
|
5 |
+
from sklearn.utils.class_weight import compute_sample_weight
|
6 |
+
import pickle as pkl
|
7 |
+
from tqdm import tqdm
|
8 |
+
import time
|
9 |
+
import os
|
10 |
+
import shutil
|
11 |
+
import json
|
12 |
+
from copy import deepcopy
|
13 |
+
|
14 |
+
from helpers.required_classes import *
|
15 |
+
|
16 |
+
|
17 |
+
def log(*args):
|
18 |
+
print(*args, flush=True)
|
19 |
+
|
20 |
+
def train_code_classifier(vecs_train_codes, vecs_test_for_groups,
|
21 |
+
labels_train_codes, labels_test_groups_codes, labels_test_groups_groups,
|
22 |
+
labels_train_groups,
|
23 |
+
models_folder, group_name, balance=None, logging=True, use_gpu=True):
|
24 |
+
"""
|
25 |
+
balance - is a type of balancing dataset:
|
26 |
+
remove - remove items per class until amount texts per clas is not the same as minimum amount
|
27 |
+
duplicate - duplicate items per class until amount texts per clas is not the same as maximum amount
|
28 |
+
weight - weighted training model
|
29 |
+
None - without any balancing method
|
30 |
+
"""
|
31 |
+
|
32 |
+
log(f"training model for codes classifiers in group {group_name}")
|
33 |
+
|
34 |
+
# create / remove folder
|
35 |
+
experiment_path = f"{models_folder}/{group_name}"
|
36 |
+
if not os.path.exists(experiment_path):
|
37 |
+
os.makedirs(experiment_path, exist_ok=True)
|
38 |
+
else:
|
39 |
+
shutil.rmtree(experiment_path)
|
40 |
+
os.makedirs(experiment_path, exist_ok=True)
|
41 |
+
|
42 |
+
labels_train_for_group = labels_train_codes[labels_train_groups==group_name]
|
43 |
+
if logging:
|
44 |
+
log(f"e.g. labels in the group: {labels_train_for_group[:3]} cng of codes: {len(np.unique(labels_train_for_group))} cnt of texts: {len(labels_train_for_group)}")
|
45 |
+
|
46 |
+
# prepare train labels
|
47 |
+
if len(np.unique(labels_train_for_group)) < 2:
|
48 |
+
# if group have only one code inside
|
49 |
+
code_name = labels_train_for_group[0]
|
50 |
+
if logging:
|
51 |
+
log(f'group {group_name} have only one code inside {code_name}')
|
52 |
+
simple_clf = SimpleModel()
|
53 |
+
simple_clf.fit([], [code_name])
|
54 |
+
pkl.dump(simple_clf, open(f"{experiment_path}/{group_name}_code_clf.pkl", 'wb'))
|
55 |
+
return {"f1_score": 'one_cls', "accuracy": 'one_cls'}
|
56 |
+
|
57 |
+
sample_weights = compute_sample_weight(
|
58 |
+
class_weight='balanced',
|
59 |
+
y=labels_train_for_group
|
60 |
+
)
|
61 |
+
|
62 |
+
# prepare other data
|
63 |
+
vecs_train_for_group = vecs_train_codes[labels_train_groups==group_name]
|
64 |
+
vecs_test_for_group = vecs_test_for_groups[labels_test_groups_groups==group_name]
|
65 |
+
labels_test_for_group = labels_test_groups_codes[labels_test_groups_groups==group_name]
|
66 |
+
|
67 |
+
labels_train_for_group, vecs_train_for_group = balance_dataset(
|
68 |
+
labels_train_for_group, vecs_train_for_group, balance=balance
|
69 |
+
)
|
70 |
+
|
71 |
+
fit_start_time = time.time()
|
72 |
+
model = CustomXGBoost(use_gpu)
|
73 |
+
|
74 |
+
if balance == 'weight':
|
75 |
+
model.fit(vecs_train_for_group, labels_train_for_group, sample_weight=sample_weights)
|
76 |
+
else:
|
77 |
+
model.fit(vecs_train_for_group, labels_train_for_group)
|
78 |
+
|
79 |
+
pkl.dump(model, open(f"{experiment_path}/{group_name}_code_clf.pkl", 'wb'))
|
80 |
+
if logging:
|
81 |
+
log(f'Trained in {time.time() - fit_start_time}s')
|
82 |
+
|
83 |
+
pred_start_time = time.time()
|
84 |
+
predictions_group = model.predict(vecs_test_for_group)
|
85 |
+
scores = {
|
86 |
+
"f1_score": fbeta_score(labels_test_for_group, predictions_group, beta=1, average='macro'),
|
87 |
+
"accuracy": accuracy_score(labels_test_for_group, predictions_group)
|
88 |
+
}
|
89 |
+
if logging:
|
90 |
+
log(scores, f'Predicted in {time.time() - pred_start_time}s')
|
91 |
+
with open(f"{experiment_path}/{group_name}_scores.json", 'w') as f:
|
92 |
+
f.write(json.dumps(scores))
|
93 |
+
|
94 |
+
conf_matrix = confusion_matrix(labels_test_for_group, predictions_group)
|
95 |
+
disp_code = ConfusionMatrixDisplay(confusion_matrix=conf_matrix,
|
96 |
+
display_labels=model.classes_, )
|
97 |
+
fig, ax = plt.subplots(figsize=(5,5))
|
98 |
+
disp_code.plot(ax=ax)
|
99 |
+
plt.xticks(rotation=90)
|
100 |
+
plt.savefig(f"{experiment_path}/{group_name}_matrix.png")
|
101 |
+
|
102 |
+
return scores
|
103 |
+
|
104 |
+
def train_codes_for_groups(vecs_train_codes, vecs_test_groups,
|
105 |
+
labels_train_codes, labels_test_groups_codes, labels_test_groups_groups,
|
106 |
+
labels_train_groups,
|
107 |
+
output_path, logging, use_gpu=True):
|
108 |
+
all_scores = []
|
109 |
+
for group_name in tqdm(np.unique(labels_train_groups)):
|
110 |
+
row = {'group': group_name}
|
111 |
+
for balanced_method in ['weight']: # [None, 'remove', 'weight', 'duplicate']:
|
112 |
+
if logging:
|
113 |
+
log('\n', '-'*50)
|
114 |
+
scores = train_code_classifier(vecs_train_codes, vecs_test_groups,
|
115 |
+
labels_train_codes, labels_test_groups_codes, labels_test_groups_groups,
|
116 |
+
labels_train_groups,
|
117 |
+
output_path, group_name, balanced_method, logging, use_gpu)
|
118 |
+
scores = {f"{balanced_method}_{k}": v for k, v in scores.items()}
|
119 |
+
row.update(scores)
|
120 |
+
all_scores.append(row)
|
121 |
+
|
122 |
+
df = pd.DataFrame(all_scores)
|
123 |
+
columns = df.columns.tolist()
|
124 |
+
columns.remove('group')
|
125 |
+
mean_scores = {'group': 'MEAN'}
|
126 |
+
for score_name in columns:
|
127 |
+
mean_score = df[df[score_name] != 'one_cls'][score_name].mean()
|
128 |
+
mean_scores.update({score_name: float(mean_score)})
|
129 |
+
df = pd.concat([df, pd.DataFrame([mean_scores])], ignore_index=True)
|
130 |
+
return df
|
131 |
+
|
132 |
+
def make_experiment_classifier(vecs_train_codes, vecs_test_codes, vecs_test_group,
|
133 |
+
labels_train_codes, labels_test_codes,
|
134 |
+
labels_test_groups, labels_train_groups,
|
135 |
+
sample_weights_codes, sample_weights_groups,
|
136 |
+
texts_test_codes, texts_test_groups,
|
137 |
+
experiment_name, classifier_model_code, classifier_model_group, experiment_path, balance=None):
|
138 |
+
# train different models as base model for group and codes
|
139 |
+
|
140 |
+
log(f'Model: {experiment_name}')
|
141 |
+
# create / remove experiment folder
|
142 |
+
experiment_path = f"{experiment_path}/{experiment_name}"
|
143 |
+
if not os.path.exists(experiment_path):
|
144 |
+
os.makedirs(experiment_path, exist_ok=True)
|
145 |
+
else:
|
146 |
+
shutil.rmtree(experiment_path)
|
147 |
+
os.makedirs(experiment_path, exist_ok=True)
|
148 |
+
|
149 |
+
# fit the models
|
150 |
+
cls_codes = deepcopy(classifier_model_code)
|
151 |
+
cls_groups = deepcopy(classifier_model_group)
|
152 |
+
|
153 |
+
labels_train_codes_balanced, vecs_train_codes_balanced = balance_dataset(
|
154 |
+
labels_train_codes, vecs_train_codes, balance=balance
|
155 |
+
)
|
156 |
+
labels_train_groups_balanced, vecs_train_codes_balanced = balance_dataset(
|
157 |
+
labels_train_groups, vecs_train_codes, balance=balance
|
158 |
+
)
|
159 |
+
|
160 |
+
log('start training base model')
|
161 |
+
if balance == 'weight':
|
162 |
+
try:
|
163 |
+
start_time = time.time()
|
164 |
+
cls_codes.fit(vecs_train_codes_balanced, labels_train_codes_balanced, sample_weight=sample_weights_codes)
|
165 |
+
log(f'codes classify trained in {(time.time() - start_time) / 60}m')
|
166 |
+
start_time = time.time()
|
167 |
+
cls_groups.fit(vecs_train_codes_balanced, labels_train_groups_balanced, sample_weight=sample_weights_groups)
|
168 |
+
log(f'groups classify trained in {(time.time() - start_time) / 60}m')
|
169 |
+
except Exception as e:
|
170 |
+
log(str(e))
|
171 |
+
start_time = time.time()
|
172 |
+
cls_codes.fit(vecs_train_codes_balanced, labels_train_codes_balanced)
|
173 |
+
log(f'codes classify trained in {(time.time() - start_time) / 60}m')
|
174 |
+
start_time = time.time()
|
175 |
+
cls_groups.fit(vecs_train_codes_balanced, labels_train_groups_balanced)
|
176 |
+
log(f'groups classify trained in {(time.time() - start_time) / 60}m')
|
177 |
+
else:
|
178 |
+
start_time = time.time()
|
179 |
+
cls_codes.fit(vecs_train_codes_balanced, labels_train_codes_balanced)
|
180 |
+
log(f'codes classify trained in {(time.time() - start_time) / 60}m')
|
181 |
+
start_time = time.time()
|
182 |
+
cls_groups.fit(vecs_train_codes_balanced, labels_train_groups_balanced)
|
183 |
+
log(f'groups classify trained in {(time.time() - start_time) / 60}m')
|
184 |
+
|
185 |
+
pkl.dump(cls_codes, open(f"{experiment_path}/{experiment_name}_codes.pkl", 'wb'))
|
186 |
+
pkl.dump(cls_groups, open(f"{experiment_path}/{experiment_name}_groups.pkl", 'wb'))
|
187 |
+
|
188 |
+
# inference the model
|
189 |
+
predictions_code = cls_codes.predict(vecs_test_codes)
|
190 |
+
predictions_group = cls_groups.predict(vecs_test_group)
|
191 |
+
scores = {
|
192 |
+
"f1_score_code": fbeta_score(labels_test_codes, predictions_code, beta=1, average='macro'),
|
193 |
+
"f1_score_group": fbeta_score(labels_test_groups, predictions_group, beta=1, average='macro'),
|
194 |
+
"accuracy_code": accuracy_score(labels_test_codes, predictions_code),
|
195 |
+
"accuracy_group": accuracy_score(labels_test_groups, predictions_group)
|
196 |
+
}
|
197 |
+
with open(f"{experiment_path}/{experiment_name}_scores.json", 'w') as f:
|
198 |
+
f.write(json.dumps(scores))
|
199 |
+
|
200 |
+
conf_matrix = confusion_matrix(labels_test_codes, predictions_code)
|
201 |
+
disp_code = ConfusionMatrixDisplay(confusion_matrix=conf_matrix,
|
202 |
+
display_labels=cls_codes.classes_, )
|
203 |
+
fig, ax = plt.subplots(figsize=(20,20))
|
204 |
+
disp_code.plot(ax=ax)
|
205 |
+
plt.xticks(rotation=90)
|
206 |
+
plt.savefig(f"{experiment_path}/{experiment_name}_codes_matrix.png")
|
207 |
+
|
208 |
+
conf_matrix = confusion_matrix(labels_test_groups, predictions_group)
|
209 |
+
disp_group = ConfusionMatrixDisplay(confusion_matrix=conf_matrix,
|
210 |
+
display_labels=cls_groups.classes_, )
|
211 |
+
|
212 |
+
fig, ax = plt.subplots(figsize=(20,20))
|
213 |
+
disp_group.plot(ax=ax)
|
214 |
+
plt.xticks(rotation=90)
|
215 |
+
plt.savefig(f"{experiment_path}/{experiment_name}_groups_matrix.png")
|
216 |
+
|
217 |
+
pd.DataFrame({'codes': predictions_code, 'truth': labels_test_codes, 'text': texts_test_codes}).to_csv(f"{experiment_path}/{experiment_name}_pred_codes.csv")
|
218 |
+
pd.DataFrame({'groups': predictions_group, 'truth': labels_test_groups, 'text': texts_test_groups}).to_csv(f"{experiment_path}/{experiment_name}_pred_groups.csv")
|
219 |
+
return predictions_code, predictions_group, scores
|
220 |
+
|
221 |
+
def train_base_clfs(classifiers, vecs_train_codes, vecs_test_codes, vecs_test_group,
|
222 |
+
labels_train_codes, labels_test_codes,
|
223 |
+
labels_test_groups_codes, labels_test_groups_groups, labels_train_groups,
|
224 |
+
sample_weights_codes, sample_weights_groups,
|
225 |
+
texts_test_codes, texts_test_groups, output_path):
|
226 |
+
results = ''
|
227 |
+
for experiment_data in classifiers:
|
228 |
+
for balanced_method in ['weight']:
|
229 |
+
exp_name = experiment_data['name']
|
230 |
+
cls_model = experiment_data['model']
|
231 |
+
_, _, scores = make_experiment_classifier(vecs_train_codes, vecs_test_codes, vecs_test_group,
|
232 |
+
labels_train_codes, labels_test_codes,
|
233 |
+
labels_test_groups_groups, labels_train_groups,
|
234 |
+
sample_weights_codes, sample_weights_groups,
|
235 |
+
texts_test_codes, texts_test_groups,
|
236 |
+
exp_name, cls_model, cls_model, output_path, balance=None)
|
237 |
+
res = f"\n\n{exp_name} balanced by: {balanced_method} scores: {scores}"
|
238 |
+
results += res
|
239 |
+
log(res)
|
240 |
+
return results
|
helpers/trainer_embedder.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
|
4 |
+
from transformers import TrainingArguments, Trainer
|
5 |
+
from transformers import EarlyStoppingCallback
|
6 |
+
import pickle as pkl
|
7 |
+
from datetime import datetime
|
8 |
+
|
9 |
+
|
10 |
+
class Dataset(torch.utils.data.Dataset):
|
11 |
+
def __init__(self, encodings, labels=None):
|
12 |
+
self.encodings = encodings
|
13 |
+
self.labels = labels
|
14 |
+
|
15 |
+
def __getitem__(self, idx):
|
16 |
+
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
|
17 |
+
item["labels"] = torch.tensor(self.labels[idx])
|
18 |
+
return item
|
19 |
+
|
20 |
+
def __len__(self):
|
21 |
+
return len(self.encodings["input_ids"])
|
22 |
+
|
23 |
+
def compute_metrics(p):
|
24 |
+
pred, labels = p
|
25 |
+
pred = np.argmax(pred, axis=1)
|
26 |
+
|
27 |
+
accuracy = accuracy_score(y_true=labels, y_pred=pred)
|
28 |
+
recall = recall_score(y_true=labels, y_pred=pred, average='macro', zero_division=0)
|
29 |
+
precision = precision_score(y_true=labels, y_pred=pred, average='macro', zero_division=0)
|
30 |
+
f1 = f1_score(y_true=labels, y_pred=pred, average="macro", zero_division=0)
|
31 |
+
|
32 |
+
return {"eval_accuracy": accuracy, "eval_precision": precision, "eval_recall": recall, "eval_f1": f1}
|
33 |
+
|
34 |
+
def train(model, train_dataset, val_dataset, output_dir, save_steps, num_train_epochs=10):
|
35 |
+
args = TrainingArguments(
|
36 |
+
output_dir=output_dir,
|
37 |
+
overwrite_output_dir=True,
|
38 |
+
evaluation_strategy="steps",
|
39 |
+
eval_steps=save_steps,
|
40 |
+
per_device_train_batch_size=16,
|
41 |
+
per_device_eval_batch_size=16,
|
42 |
+
num_train_epochs=num_train_epochs,
|
43 |
+
seed=0,
|
44 |
+
save_steps=save_steps,
|
45 |
+
save_total_limit=2,
|
46 |
+
load_best_model_at_end=True,
|
47 |
+
metric_for_best_model='eval_f1'
|
48 |
+
)
|
49 |
+
trainer = Trainer(
|
50 |
+
model=model,
|
51 |
+
args=args,
|
52 |
+
train_dataset=train_dataset,
|
53 |
+
eval_dataset=val_dataset,
|
54 |
+
compute_metrics=compute_metrics,
|
55 |
+
callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
|
56 |
+
)
|
57 |
+
|
58 |
+
res = trainer.train()
|