Cielciel's picture
Cielciel/aift-model-review-multiple-label-classification
bbc5ecf
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import functools
import os

from datasets import load_dataset, load_metric, ReadInstruction
from google.cloud import storage
from transformers import AutoTokenizer

from trainer import metadata
@functools.lru_cache(maxsize=None)
def _get_tokenizer(model_name):
    """Load the fast tokenizer for ``model_name`` once per process and reuse it.

    ``preprocess_function`` is applied batch-by-batch via ``Dataset.map``;
    without this cache ``from_pretrained`` would be re-run for every batch.
    The cache is keyed on the model name, so its size is tiny.
    """
    return AutoTokenizer.from_pretrained(model_name, use_fast=True)


def preprocess_function(examples):
    """Tokenize a batch of examples and map their labels to integer ids.

    Intended to be used with ``Dataset.map(preprocess_function, batched=True)``.

    Args:
      examples: a batch mapping column names to lists of values. It must
        contain a ``'text'`` column and may contain a ``'label'`` column.

    Returns:
      The tokenizer output for ``examples['text']``, padded and truncated to
      ``metadata.MAX_SEQ_LENGTH`` tokens. When ``metadata.TARGET_LABELS`` is
      not None and the batch has a ``'label'`` column, the result also holds
      ``'label'``: each label looked up in ``metadata.TARGET_LABELS``.

    Raises:
      KeyError: if a label is missing from ``metadata.TARGET_LABELS``.
    """
    tokenizer = _get_tokenizer(metadata.PRETRAINED_MODEL_NAME)

    # Fixed-length encodings, so every example yields the same number of tokens.
    result = tokenizer(
        examples['text'],
        padding='max_length',
        max_length=metadata.MAX_SEQ_LENGTH,
        truncation=True,
    )

    # NOTE(review): TARGET_LABELS is presumably a label -> id dict defined in
    # trainer/metadata. The previous note said it carries an extra -1 entry
    # because -1 labels appear during preprocessing even though the dataset's
    # unique() does not report them -- confirm against metadata.
    label_to_id = metadata.TARGET_LABELS
    if label_to_id is not None and "label" in examples:
        result["label"] = [label_to_id[label] for label in examples["label"]]
    return result
def load_data(args):
    """Load ``metadata.DATASET_NAME`` and tokenize it with ``preprocess_function``.

    Args:
      args: arguments passed to the training script (not read here).

    Returns:
      A ``(train_dataset, test_dataset)`` tuple: the tokenized ``"train"`` and
      ``"test"`` splits.
    """
    raw_splits = load_dataset(metadata.DATASET_NAME)
    # batched=True hands preprocess_function whole batches (dicts of lists);
    # load_from_cache_file=True lets `datasets` reuse a previously computed
    # result of the same transform instead of recomputing it.
    tokenized = raw_splits.map(
        preprocess_function,
        batched=True,
        load_from_cache_file=True,
    )
    return tokenized["train"], tokenized["test"]
def save_model(args):
    """Upload locally saved model files to Google Cloud Storage, if requested.

    The model files must already exist in ``/tmp/<args.model_name>``. When
    ``args.job_dir`` starts with ``gs://``, every regular file directly inside
    that directory is uploaded to
    ``gs://<bucket>/<prefix>/<args.model_name>/``. Subdirectories are not
    traversed. For any other ``job_dir``, nothing is uploaded and only the
    local location is printed.

    Args:
      args: namespace with ``job_dir`` (``gs://bucket[/prefix]`` to upload,
        anything else to keep files local) and ``model_name`` (name of the
        directory under ``/tmp`` and of the destination folder).

    Raises:
      ValueError: if ``args.job_dir`` starts with ``gs://`` but names no bucket.
    """
    scheme = 'gs://'
    local_path = os.path.join('/tmp', args.model_name)

    if not args.job_dir.startswith(scheme):
        print(f"Saved model files at {local_path}")
        print("To save model files in GCS bucket, please specify job_dir starting with gs://")
        return

    # "gs://bucket/some/prefix/" -> bucket "bucket", prefix "some/prefix"
    bucket_name, _, object_prefix = args.job_dir[len(scheme):].partition('/')
    if not bucket_name:
        # Fail early with a clear message instead of handing "" to the client.
        raise ValueError(f"job_dir {args.job_dir!r} does not name a GCS bucket")
    object_prefix = object_prefix.rstrip('/')
    if object_prefix:
        model_path = f"{object_prefix}/{args.model_name}"
    else:
        model_path = f"{args.model_name}"

    bucket = storage.Client().bucket(bucket_name)
    for file_name in os.listdir(local_path):
        local_file = os.path.join(local_path, file_name)
        if not os.path.isfile(local_file):
            continue  # only top-level regular files are uploaded
        bucket.blob(f"{model_path}/{file_name}").upload_from_filename(local_file)
    print(f"Saved model files in gs://{bucket_name}/{model_path}")