# MLADI / background_inference.py
# (Hugging Face file-viewer metadata, kept as comments: author AMR-KELEG,
#  commit "Skip the results of incomplete evaluations", SHA 013e3f5,
#  "raw / history / blame", size 1.82 kB)
import os
import sys
import utils
import datasets
import eval_utils
from constants import DIALECTS_WITH_LABELS
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Command-line arguments: the model repo id, the commit SHA (revision) to
# evaluate, and the name of the inference function defined in `eval_utils`.
# Fail fast with a usage message instead of an opaque IndexError.
if len(sys.argv) != 4:
    sys.exit(
        "Usage: python background_inference.py "
        "<model_name> <commit_id> <inference_function>"
    )
model_name = sys.argv[1]
commit_id = sys.argv[2]
inference_function = sys.argv[3]

# Mark this evaluation as running in the shared queue dataset so other
# workers (and the UI) can see it is in progress.
utils.update_model_queue(
    repo_id=os.environ["PREDICTIONS_DATASET_NAME"],
    model_name=model_name,
    commit_id=commit_id,
    inference_function=inference_function,
    status="in_progress",
)
try:
    # Load the tokenizer and model pinned to the exact commit under review.
    tokenizer = AutoTokenizer.from_pretrained(model_name, revision=commit_id)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, revision=commit_id
    )

    # Load the test split of the evaluation dataset.
    dataset_name = os.environ["DATASET_NAME"]
    dataset = datasets.load_dataset(dataset_name)["test"]
    sentences = dataset["sentence"]
    # NOTE(review): `labels` is never used below — the column accesses may be
    # intentional schema validation; confirm before removing.
    labels = {dialect: dataset[dialect] for dialect in DIALECTS_WITH_LABELS}

    # Hoist loop invariants: resolve the inference callable once instead of
    # per sentence, and cache the sentence count for the progress report.
    infer = getattr(eval_utils, inference_function)
    n_sentences = len(sentences)

    predictions = []
    for i, sentence in enumerate(sentences):
        predictions.append(infer(model, tokenizer, sentence))
        print(
            f"Inference progress ({model_name}, {inference_function}): "
            f"{round(100 * (i + 1) / n_sentences, 1)}%"
        )

    # Store the predictions in a private dataset
    utils.upload_predictions(
        os.environ["PREDICTIONS_DATASET_NAME"],
        predictions,
        model_name,
        commit_id,
        inference_function,
    )
    print("Inference completed!")
except Exception as e:
    # Top-level boundary handler: log the failure and flip the queue entry to
    # "failed (online)" so the job is not left stuck at "in_progress".
    print(f"An error occurred during inference of {model_name}: {e}")
    utils.update_model_queue(
        repo_id=os.environ["PREDICTIONS_DATASET_NAME"],
        model_name=model_name,
        commit_id=commit_id,
        inference_function=inference_function,
        status="failed (online)",
    )