taka-yamakoshi committed
Commit d82123b · 1 Parent(s): 75f767b
Files changed (1)
  1. app.py +4 -5
app.py CHANGED
@@ -5,8 +5,7 @@ import streamlit as st
 import matplotlib.pyplot as plt
 import seaborn as sns
 
-import torch
-import torch.nn.functional as F
+import jax
 
 from transformers import AlbertTokenizer
 
@@ -50,12 +49,12 @@ if __name__=='__main__':
     tokenizer,model = load_model()
     mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
 
-    input_ids = tokenizer('This is a sample sentence.',return_tensors='pt').input_ids
+    input_ids = tokenizer('This is a sample sentence.',return_tensors='np').input_ids
     input_ids[0][4] = mask_id
 
     with torch.no_grad():
         outputs = model(input_ids)
-    logprobs = F.log_softmax(outputs.logits, dim = -1)
+    logprobs = jax.nn.log_softmax(outputs.logits, axis = -1)
     st.write(logprobs.shape)
-    preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1).item() for probs in logprobs[0]]
+    preds = [np.choice(np.arange(len(probs)),p=np.exp(probs)) for probs in logprobs[0]]
     st.write([tokenizer.decode([token]) for token in preds])
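
Note that the new hunk leaves a few loose ends: `np` is referenced (`np.choice`, `np.arange`, `np.exp`) but numpy is never imported, NumPy's sampler is `np.random.choice` rather than `np.choice`, and the `with torch.no_grad():` context line is kept as unchanged context even though `import torch` was removed, so the script would raise a NameError at runtime. Below is a minimal runnable sketch of the intended sampling step; the commit's own `load_model()` is not shown, so the Flax masked-LM model class and the `albert-base-v2` checkpoint here are assumptions, and the no-grad context is dropped because Flax models don't track gradients by default.

import numpy as np
import jax
from transformers import AlbertTokenizer, FlaxAlbertForMaskedLM

# Hypothetical stand-in for the app's load_model(); the checkpoint name is an assumption.
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = FlaxAlbertForMaskedLM.from_pretrained('albert-base-v2')
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]

input_ids = tokenizer('This is a sample sentence.', return_tensors='np').input_ids
input_ids[0][4] = mask_id                    # NumPy arrays can be mutated in place

outputs = model(input_ids)                   # no torch.no_grad() needed for Flax
logprobs = jax.nn.log_softmax(outputs.logits, axis=-1)   # shape (1, seq_len, vocab)

preds = []
for token_logprobs in np.asarray(logprobs[0], dtype=np.float64):
    p = np.exp(token_logprobs)
    p /= p.sum()                             # renormalize so p passes np.random.choice's sum-to-1 check
    preds.append(int(np.random.choice(len(p), p=p)))  # sample one token id per position
print([tokenizer.decode([t]) for t in preds])

Sampling one id per position from the full distribution reproduces the removed `torch.multinomial(torch.exp(probs), num_samples=1)` behavior; `np.random.choice(len(p), p=p)` is the NumPy equivalent for a single draw.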