# File size: 635 Bytes
# Commit: 3c8fdbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def model_predict(model, tokenizer, sentences, device=None, max_length=512):
    """
    Predict the class labels of one or more sentences with a classification model.

    Args:
        model: Model (transformers); calling it on the tokenized batch must
            return an object with a ``.logits`` tensor holding one row of
            class scores per input sentence.
        tokenizer: Tokenizer (transformers tokenizer); its output batch must
            support ``.to(device)``.
        sentences: A single sentence (str) or the sentences to predict
            (list / tuple / ndarray of str).
        device: Device to run inference on. Defaults to the device of the
            model's first parameter, or CPU if that cannot be determined.
        max_length: Maximum number of tokens per sentence; longer inputs are
            truncated. Defaults to 512.
    Returns:
        predictions: Predicted label(s) as a string -- the class index
            (e.g. ``"1"``) when exactly one sentence is given, otherwise the
            string form of the list of class indices (e.g. ``"[1, 0, 2]"``).
    """
    import numpy as np
    import torch

    # HF tokenizers accept str or list/tuple of str, but not numpy arrays.
    if isinstance(sentences, np.ndarray):
        sentences = sentences.tolist()

    if device is None:
        # Put the inputs on the same device as the weights; a mismatch would
        # fail inside the forward pass.
        params = model.parameters() if hasattr(model, "parameters") else iter(())
        first_param = next(iter(params), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    inputs = tokenizer(
        sentences,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=max_length,
    ).to(device)

    # Classify sentences (inference only, so no autograd graph is needed).
    with torch.no_grad():
        logits = model(**inputs).logits

    # Take the argmax over the class axis so each sentence gets its own label;
    # a flat argmax over the whole tensor would mix rows together in a batch.
    labels = logits.argmax(dim=-1).cpu().tolist()

    if len(labels) == 1:
        return str(labels[0])
    return str(labels)