File size: 1,817 Bytes
5a3355b c147e35 5a3355b c147e35 5a3355b c147e35 013e3f5 5a3355b c147e35 5a3355b c147e35 84916fc 5a3355b c147e35 5a3355b c147e35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import os
import sys
import traceback

import datasets
from transformers import AutoTokenizer, AutoModelForSequenceClassification

import eval_utils
import utils
from constants import DIALECTS_WITH_LABELS
# Command-line contract: <model repo id> <model revision> <eval_utils function name>.
model_name, commit_id, inference_function = sys.argv[1], sys.argv[2], sys.argv[3]

# Mark this (model, revision, inference function) job as "in_progress" via
# utils.update_model_queue before any model or dataset is loaded; the target
# repo id is taken from the PREDICTIONS_DATASET_NAME environment variable.
_queue_entry = {
    "model_name": model_name,
    "commit_id": commit_id,
    "inference_function": inference_function,
}
utils.update_model_queue(
    repo_id=os.environ["PREDICTIONS_DATASET_NAME"],
    **_queue_entry,
    status="in_progress",
)
# Run inference over the test split and upload the predictions. Any Exception
# raised along the way is reported and recorded in the queue as
# "failed (online)".
try:
    # Resolve the inference function once, up front: a bad name then fails
    # before the model and dataset are downloaded, and the lookup is not
    # repeated for every sentence.
    # NOTE(review): the name comes straight from sys.argv, so any attribute of
    # eval_utils can be invoked; consider validating it against an allow-list.
    infer = getattr(eval_utils, inference_function)

    tokenizer = AutoTokenizer.from_pretrained(model_name, revision=commit_id)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, revision=commit_id
    )

    # Load the dataset; only the "sentence" column is needed for inference,
    # so the label columns are not materialized here.
    dataset_name = os.environ["DATASET_NAME"]
    dataset = datasets.load_dataset(dataset_name)["test"]
    sentences = dataset["sentence"]

    predictions = []
    n_sentences = len(sentences)
    # Count from 1 so the progress percentage reaches 100.0 on the last item.
    for i, sentence in enumerate(sentences, start=1):
        predictions.append(infer(model, tokenizer, sentence))
        progress = round(100 * i / n_sentences, 1)
        print(
            f"Inference progress ({model_name}, {inference_function}): {progress}%"
        )

    # Store the predictions in a private dataset
    utils.upload_predictions(
        os.environ["PREDICTIONS_DATASET_NAME"],
        predictions,
        model_name,
        commit_id,
        inference_function,
    )
    print("Inference completed!")
except Exception as e:
    print(f"An error occurred during inference of {model_name}: {e}")
    # The one-line message above drops the exception type and location; dump
    # the full traceback to stderr so the failure can actually be diagnosed.
    traceback.print_exc()
    utils.update_model_queue(
        repo_id=os.environ["PREDICTIONS_DATASET_NAME"],
        model_name=model_name,
        commit_id=commit_id,
        inference_function=inference_function,
        status="failed (online)",
    )
    # NOTE(review): the script still exits with status 0 after this handler,
    # so the queue status is the only failure signal; consider sys.exit(1) if
    # the launcher inspects the return code.
|