---
library_name: transformers
tags: []
---

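The inference example below relies on a `tokenizer`, `model`, and `device` that are not defined in this card. A minimal setup sketch, assuming the checkpoint is a standard `transformers` sequence-classification model; the repository id here is a placeholder:

```py
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hypothetical setup: replace "your-model-id" with this repository's id
model_id = "your-model-id"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()
```

With those in place, the example below runs as-is: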
```py
import torch

def scaling(x, min_x, max_x, r1, r2):
    # Min-max scale data x (n_samples x 1) from [min_x, max_x] to [r1, r2]
    return r1 + (x - min_x) * (r2 - r1) / (max_x - min_x)

def descaling(x_s, min_x, max_x, r1, r2):
    # Map scaled data x_s (n_samples x 1) from [r1, r2] back to [min_x, max_x]
    return min_x + (x_s - r1) * (max_x - min_x) / (r2 - r1)

# Inference example (assumes `tokenizer`, `model`, and `device` are set up as above)
with torch.no_grad():
    x = "They are equally important, absolutely, and just as real as each other."
    x = tokenizer([x], return_tensors="pt", add_special_tokens=True, padding=True)
    y_hat = model(**x.to(device)).logits
    y_hat = torch.tanh(y_hat).cpu()

# Map the tanh outputs in [-1, 1] back to the original 1-7 scale
l_hat = descaling(y_hat, 1, 7, -1, 1)[0].numpy()
print(l_hat)
# [C, O, E, A, S]
# [6.0583944 4.4941516 1.6538751 5.5261126 4.725995 ]
```
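
For reference, `scaling` is the inverse of `descaling`; given the `tanh` head and the `descaling(y_hat, 1, 7, -1, 1)` call above, the training targets were presumably 1-7 trait scores mapped into [-1, 1] with the same helper. A quick round-trip check (the scores below are illustrative, not model outputs):

```py
import numpy as np

y = np.array([[6.0, 4.5, 1.7, 5.5, 4.7]])                # illustrative 1-7 trait scores
y_scaled = scaling(y, 1, 7, -1, 1)                        # -> values in [-1, 1]
assert np.allclose(descaling(y_scaled, 1, 7, -1, 1), y)   # recovers the 1-7 scores
```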