---
license: mit
language:
- en
metrics:
- accuracy
- mse
- f1
base_model:
- dmis-lab/biobert-base-cased-v1.2
- google-bert/bert-base-cased
pipeline_tag: text-classification
model-index:
- name: bert-causation-rating-dr1
  results:
  - task:
      type: text-classification
    dataset:
      name: rating_dr1
      type: dataset
    metrics:
    - name: off by 1 accuracy
      type: accuracy
      value: 71.1864406779661
    - name: mean squared error for ordinal data
      type: mse
      value: 0.7796610169491526
    - name: weighted F1 score
      type: f1
      value: 0.7164302155262606
    - name: Kendall's tau coefficient
      type: Kendall's tau
      value: 0.8013637826548397
    source:
      name: Keling Wang
      url: https://github.com/Keling-Wang
datasets:
- kelingwang/causation_strength_rating
---

# Model description

This `bert-causation-rating-dr1` model is a [biobert-base-cased-v1.2](https://huggingface.co/dmis-lab/biobert-base-cased-v1.2) model fine-tuned on a small set of manually annotated sentences with causation labels. The model classifies a sentence into one of several levels according to the strength of causation the sentence expresses.

Before tuning on this dataset, the `biobert-base-cased-v1.2` model was first fine-tuned on a dataset containing causation labels from a published paper, so this model starts from the pre-trained [`kelingwang/bert-causation-rating-pubmed`](https://huggingface.co/kelingwang/bert-causation-rating-pubmed) checkpoint. For more information, please see that model's page and my [GitHub page](https://github.com/Keling-Wang/causation_rating).

The sentences in the dataset were rated independently by two researchers. This `dr1` version is tuned on the labels assigned by Rater 1.

# Intended use and limitations

This model is primarily intended to rate the strength of causation expressed in a sentence extracted from a clinical guideline on diabetes mellitus management. It predicts strength of causation (SoC) labels from the text input as follows:

* -1: No correlation or relationship between variables is mentioned in the sentence.
* 0: The sentence expresses a correlational relationship but not causation.
* 1: The sentence expresses weak causation.
* 2: The sentence expresses moderate causation.
* 3: The sentence expresses strong causation.

*NOTE:* The model outputs five logits, one per class, and its class indices are 0-based, so the labels run from 0 to 4 (index 0 corresponds to SoC -1 and index 4 to SoC 3). For making predictions, using [this `python` module](https://github.com/Keling-Wang/causation_rating/blob/main/tests/prediction_from_pretrained.py) is recommended.

# Performance and hyperparameters

## Test metrics

This model achieves the following results on the test dataset, a 25% held-out stratified split of the entire dataset with `SEED=114514`:

* Loss: 5.2014
* Off-by-1 accuracy: 71.1864%
* Off-by-2 accuracy: 90.6780%
* MSE for ordinal data: 0.7797
* Weighted F1: 0.7164
* Kendall's tau: 0.8014

This performance is achieved with the following hyperparameters:

* Learning rate: 7.94278e-05
* Weight decay: 0.111616
* Warmup ratio: 0.301057
* Power of the polynomial learning rate scheduler: 2.619975
* Power of the distance measure used in the loss function, α: 2.0

## Hyperparameter tuning metrics

During the Bayesian optimization procedure for hyperparameter tuning, this model achieved a best target metric (off-by-1 accuracy) of *99.1147%*, obtained from a 4-fold cross-validation procedure with the best hyperparameters.

# Training settings

The following training configuration applies:

* Pre-trained model: `kelingwang/bert-causation-rating-pubmed`
* `seed`: 114514
* `batch_size`: 128
* `epoch`: 8
* `max_length` in `torch.utils.data.Dataset`: 128
* Loss function: the [OLL loss](https://aclanthology.org/2022.coling-1.407/) with a tunable hyperparameter α (the power of the distance measure used in the loss function).
* `lr`: 7.94278e-05
* `weight_decay`: 0.111616
* `warmup_ratio`: 0.301057
* `lr_scheduler_type`: polynomial
* `lr_scheduler_kwargs`: `{"power": 2.619975, "lr_end": 1e-8}`
* Power of the distance measure used in the loss function, α: 2.0

# Framework versions and devices

This model was trained on an NVIDIA P100 GPU provided by Kaggle. Framework versions are:

* python==3.10.14
* cuda==12.4
* NVIDIA-SMI==550.90.07
* torch==2.4.0
* transformers==4.45.1
* scikit-learn==1.2.2
* optuna==4.0.0
* nlpaug==1.1.11
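As a minimal sketch of the label mapping described under *Intended use and limitations*, the snippet below turns the model's five output logits into an SoC label. It assumes that class index `i` corresponds to SoC label `i - 1`; `SOC_LABELS`, `SOC_DESCRIPTIONS` and `logits_to_soc` are hypothetical names for illustration only and are not part of the linked prediction module, which remains the recommended way to make predictions.

```python
import math
from typing import Sequence

# Assumption: class index i (0..4) corresponds to SoC label i - 1 (-1..3).
SOC_LABELS = (-1, 0, 1, 2, 3)
SOC_DESCRIPTIONS = {
    -1: "no correlation or relationship between variables",
    0: "correlation but not causation",
    1: "weak causation",
    2: "moderate causation",
    3: "strong causation",
}


def softmax(logits: Sequence[float]) -> list[float]:
    """Numerically stable softmax over one row of logits."""
    shift = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logits_to_soc(logits: Sequence[float]) -> tuple[int, float]:
    """Return (SoC label, probability) of the highest-scoring class."""
    if len(logits) != len(SOC_LABELS):
        raise ValueError(f"expected {len(SOC_LABELS)} logits, got {len(logits)}")
    probs = softmax(logits)
    best_index = max(range(len(probs)), key=probs.__getitem__)
    return SOC_LABELS[best_index], probs[best_index]


if __name__ == "__main__":
    label, prob = logits_to_soc([-1.2, 0.3, 2.5, 0.8, -0.4])
    print(label, SOC_DESCRIPTIONS[label], round(prob, 3))
```

To get the logits themselves, one option (assuming the checkpoint is published on the Hugging Face Hub as `kelingwang/bert-causation-rating-dr1`) is to load it with `AutoTokenizer` and `AutoModelForSequenceClassification` from `transformers`, tokenize with `truncation=True, max_length=128` to match the training setting, and pass `outputs.logits[0].tolist()` to `logits_to_soc`.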
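The OLL loss listed under *Training settings* penalizes probability mass placed on classes far from the true one. Our reading of the linked COLING 2022 paper is that, for a true 0-based class index y and predicted class probabilities p_i, the per-example loss sums the term −log(1 − p_i) weighted by |y − i|^α over all classes i. The pure-Python sketch below implements that reading; it is not the author's training code, and the `eps` clamp and batch averaging are our own assumptions.

```python
import math
from typing import Sequence


def oll_loss(probs: Sequence[float], target: int, alpha: float = 2.0,
             eps: float = 1e-12) -> float:
    """Ordinal log loss for one example (sketch of the OLL formulation).

    probs  -- predicted probabilities over the 0-based class indices
    target -- true 0-based class index
    alpha  -- power applied to the distance |target - i|
    """
    loss = 0.0
    for i, p in enumerate(probs):
        distance = abs(target - i)
        if distance == 0:
            continue  # the true class has zero distance, hence zero weight
        # Clamp 1 - p away from zero so log() stays finite when p == 1.
        loss -= math.log(max(1.0 - p, eps)) * distance ** alpha
    return loss


def batch_oll_loss(batch_probs: Sequence[Sequence[float]],
                   targets: Sequence[int], alpha: float = 2.0) -> float:
    """Mean OLL loss over a batch (averaging is an assumption)."""
    if not targets or len(batch_probs) != len(targets):
        raise ValueError("batch_probs and targets must be non-empty and equal length")
    total = sum(oll_loss(p, t, alpha) for p, t in zip(batch_probs, targets))
    return total / len(targets)
```

With α = 2.0, as used for this model, placing probability 0.5 on a class two steps away from the true class costs four times as much as placing it on an adjacent class, which is how the loss encodes the ordering of the SoC labels.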
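The `polynomial` scheduler in *Training settings* warms the learning rate up linearly and then decays it polynomially towards `lr_end`. The sketch below follows our understanding of the schedule computed by `transformers.get_polynomial_decay_schedule_with_warmup`, with warmup steps taken as `ceil(total_steps * warmup_ratio)` as `TrainingArguments` does when `warmup_ratio` is set. `TOTAL_STEPS = 240` is a hypothetical value, since the card does not state the number of optimizer steps, and `polynomial_lr` is an illustrative helper, not a library function.

```python
import math

# Hyperparameters from the training settings; TOTAL_STEPS is hypothetical.
LR_INIT = 7.94278e-05
LR_END = 1e-8
POWER = 2.619975
WARMUP_RATIO = 0.301057
TOTAL_STEPS = 240  # hypothetical: depends on dataset size, batch size and epochs


def polynomial_lr(step: int, total_steps: int = TOTAL_STEPS,
                  lr_init: float = LR_INIT, lr_end: float = LR_END,
                  power: float = POWER, warmup_ratio: float = WARMUP_RATIO) -> float:
    """Learning rate at an optimizer step: linear warmup, then polynomial decay."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear warmup from 0 up to lr_init.
        return lr_init * step / max(1, warmup_steps)
    if step > total_steps:
        return lr_end
    # Polynomial decay from lr_init down to lr_end over the remaining steps.
    decay_steps = total_steps - warmup_steps
    pct_remaining = 1 - (step - warmup_steps) / decay_steps
    return (lr_init - lr_end) * pct_remaining ** power + lr_end


if __name__ == "__main__":
    for s in (0, 36, 73, 150, 240):
        print(s, f"{polynomial_lr(s):.3e}")
```

Because the tuned power (about 2.62) is well above 1, the learning rate falls quickly right after the warmup peak and spends the later steps close to `lr_end`.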