mshook committed on
Commit
66f3409
·
verified ·
1 Parent(s): 9bef1a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -16
app.py CHANGED
@@ -10,22 +10,7 @@ import numpy as np
10
  import plotly.express as px
11
  import circuitsvis as cv
12
 
13
- """
14
- roneneldan/TinyStories-1M
15
- roneneldan/TinyStories-3M
16
- roneneldan/TinyStories-8M
17
- roneneldan/TinyStories-28M
18
- roneneldan/TinyStories-33M
19
- roneneldan/TinyStories-1Layer-21M
20
- roneneldan/TinyStories-2Layers-33M
21
- roneneldan/TinyStories-Instruct-1M
22
- roneneldan/TinyStories-Instruct-3M
23
- roneneldan/TinyStories-Instruct-8M
24
- roneneldan/TinyStories-Instruct-28M
25
- roneneldan/TinyStories-Instruct-33M
26
- roneneldan/TinyStories-Instuct-1Layer-21M
27
- roneneldan/TinyStories-Instruct-2Layers-33M
28
- """
29
 
30
 
31
  # Little bit of front end for model selector
@@ -62,6 +47,22 @@ model_name = st.sidebar.radio("Model (only use patching for\nsmall (<4L) models
62
  model = HookedTransformer.from_pretrained(model_name)
63
 
64
  def predict_next_token(prompt):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  logits = model(prompt)[0,-1]
66
  answer_index = logits.argmax()
67
  answer = model.tokenizer.decode(answer_index)
 
10
  import plotly.express as px
11
  import circuitsvis as cv
12
 
13
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  # Little bit of front end for model selector
 
47
  model = HookedTransformer.from_pretrained(model_name)
48
 
49
  def predict_next_token(prompt):
50
+ """
51
+ roneneldan/TinyStories-1M
52
+ roneneldan/TinyStories-3M
53
+ roneneldan/TinyStories-8M
54
+ roneneldan/TinyStories-28M
55
+ roneneldan/TinyStories-33M
56
+ roneneldan/TinyStories-1Layer-21M
57
+ roneneldan/TinyStories-2Layers-33M
58
+ roneneldan/TinyStories-Instruct-1M
59
+ roneneldan/TinyStories-Instruct-3M
60
+ roneneldan/TinyStories-Instruct-8M
61
+ roneneldan/TinyStories-Instruct-28M
62
+ roneneldan/TinyStories-Instruct-33M
63
+ roneneldan/TinyStories-Instuct-1Layer-21M
64
+ roneneldan/TinyStories-Instruct-2Layers-33M
65
+ """
66
  logits = model(prompt)[0,-1]
67
  answer_index = logits.argmax()
68
  answer = model.tokenizer.decode(answer_index)