Spaces:
Runtime error
Runtime error
taka-yamakoshi
commited on
Commit
·
d82123b
1
Parent(s):
75f767b
jax
Browse files
app.py
CHANGED
@@ -5,8 +5,7 @@ import streamlit as st
|
|
5 |
import matplotlib.pyplot as plt
|
6 |
import seaborn as sns
|
7 |
|
8 |
-
import
|
9 |
-
import torch.nn.functional as F
|
10 |
|
11 |
from transformers import AlbertTokenizer
|
12 |
|
@@ -50,12 +49,12 @@ if __name__=='__main__':
|
|
50 |
tokenizer,model = load_model()
|
51 |
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
|
52 |
|
53 |
-
input_ids = tokenizer('This is a sample sentence.',return_tensors='
|
54 |
input_ids[0][4] = mask_id
|
55 |
|
56 |
with torch.no_grad():
|
57 |
outputs = model(input_ids)
|
58 |
-
logprobs =
|
59 |
st.write(logprobs.shape)
|
60 |
-
preds = [
|
61 |
st.write([tokenizer.decode([token]) for token in preds])
|
|
|
5 |
import matplotlib.pyplot as plt
|
6 |
import seaborn as sns
|
7 |
|
8 |
+
import jax
|
|
|
9 |
|
10 |
from transformers import AlbertTokenizer
|
11 |
|
|
|
49 |
tokenizer,model = load_model()
|
50 |
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
|
51 |
|
52 |
+
input_ids = tokenizer('This is a sample sentence.',return_tensors='np').input_ids
|
53 |
input_ids[0][4] = mask_id
|
54 |
|
55 |
with torch.no_grad():
|
56 |
outputs = model(input_ids)
|
57 |
+
logprobs = jax.nn.log_softmax(outputs.logits, axis = -1)
|
58 |
st.write(logprobs.shape)
|
59 |
+
preds = [np.choice(np.arange(len(probs)),p=np.exp(probs)) for probs in logprobs[0]]
|
60 |
st.write([tokenizer.decode([token]) for token in preds])
|