File size: 1,147 Bytes
f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import functools
import gc
import json
import os
import pickle
import sys
import warnings
from enum import Enum
from itertools import chain
from typing import Dict, List, Optional, Union
import numpy as np
import optuna
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader
from transformers import (
AdamW,
BatchEncoding,
BertConfig,
BertModel,
DataCollatorForTokenClassification,
SpecialTokensMixin,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
get_scheduler,
)
from transformers.utils import logging, to_py_obj
from .collators import DataCollatorForMultitaskCellClassification
# Package-local modules (relative imports; the .collators import above belongs to this group too)
from .data import get_data_loader, preload_and_process_data
from .model import GeneformerMultiTask
from .optuna_utils import create_optuna_study
from .utils import save_model
|