|
import os |
|
import sys |
|
import utils |
|
import datasets |
|
import eval_utils |
|
from constants import DIALECTS_WITH_LABELS |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
# CLI arguments: Hub model id, the exact model revision (commit SHA), and the
# name of the inference function in eval_utils to dispatch to.
model_name, commit_id, inference_function = sys.argv[1], sys.argv[2], sys.argv[3]
|
|
|
# Mark this (model, commit, function) entry as "in_progress" in the shared
# predictions queue before starting work.
predictions_repo = os.environ["PREDICTIONS_DATASET_NAME"]
utils.update_model_queue(
    repo_id=predictions_repo,
    model_name=model_name,
    commit_id=commit_id,
    inference_function=inference_function,
    status="in_progress",
)
|
|
|
try:
    # Load tokenizer and classifier pinned to the exact revision under evaluation.
    tokenizer = AutoTokenizer.from_pretrained(model_name, revision=commit_id)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, revision=commit_id
    )

    dataset_name = os.environ["DATASET_NAME"]
    dataset = datasets.load_dataset(dataset_name)["test"]

    sentences = dataset["sentence"]
    # NOTE(review): `labels` is never used below; kept because indexing every
    # dialect column doubles as a schema check on the dataset — confirm intent.
    labels = {dialect: dataset[dialect] for dialect in DIALECTS_WITH_LABELS}

    # Resolve the inference function once: fails fast on a bad function name
    # and hoists the per-iteration getattr/len lookups out of the loop.
    infer = getattr(eval_utils, inference_function)
    total = len(sentences)

    predictions = []
    for i, sentence in enumerate(sentences):
        predictions.append(infer(model, tokenizer, sentence))
        print(
            f"Inference progress ({model_name}, {inference_function}): {round(100 * (i + 1) / total, 1)}%"
        )

    utils.upload_predictions(
        os.environ["PREDICTIONS_DATASET_NAME"],
        predictions,
        model_name,
        commit_id,
        inference_function,
    )

    # Fixed: was an f-string with no placeholders (ruff F541).
    print("Inference completed!")

except Exception as e:
    # Top-level boundary: any failure is recorded in the queue so the entry is
    # not left stuck in "in_progress". The exception is swallowed deliberately;
    # the "failed (online)" status carries the failure signal downstream.
    print(f"An error occurred during inference of {model_name}: {e}")
    utils.update_model_queue(
        repo_id=os.environ["PREDICTIONS_DATASET_NAME"],
        model_name=model_name,
        commit_id=commit_id,
        inference_function=inference_function,
        status="failed (online)",
    )
|
|