Spaces:
Running
Running
Lev McKinney
commited on
Commit
β’
25369b9
1
Parent(s):
3d7c4c3
Basic app using gpt2 working
Browse files
app.py
CHANGED
@@ -1,17 +1,89 @@
|
|
1 |
-
|
|
|
2 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
-
from tuned_lens.nn import TunedLens
|
4 |
from tuned_lens.plotting import plot_lens
|
5 |
-
|
6 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
LENS_PATH = '<PATH TO LENS>'
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
-
iface = gr.Interface(fn=plot_lens_outputs, inputs="text", outputs=gr.outputs.Plot(type="auto"))
|
17 |
iface.launch()
|
|
|
1 |
+
import torch
|
2 |
+
from tuned_lens.nn.lenses import TunedLens, LogitLens
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
4 |
from tuned_lens.plotting import plot_lens
|
|
|
5 |
import gradio as gr
|
6 |
+
from plotly import graph_objects as go
|
7 |
+
|
8 |
+
device = torch.device("cpu")
|
9 |
+
print(f"Using device {device} for inference")
|
10 |
+
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
11 |
+
model = model.to(device)
|
12 |
+
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
13 |
+
tuned_lens = TunedLens.load("lens/gpt2", map_location=device)
|
14 |
+
logit_lens = LogitLens(model)
|
15 |
+
|
16 |
+
lens_options_dict = {
|
17 |
+
"Tuned Lens": tuned_lens,
|
18 |
+
"Logit Lens": logit_lens,
|
19 |
+
}
|
20 |
+
statistic_options_dict = {
|
21 |
+
"Entropy": "entropy",
|
22 |
+
"Cross Entropy": "ce",
|
23 |
+
"Forward KL": "forward_kl",
|
24 |
+
}
|
25 |
+
|
26 |
+
|
27 |
+
def make_plot(lens, text, statistic, token_cutoff):
|
28 |
+
input_ids = tokenizer.encode(text, return_tensors="pt")
|
29 |
+
|
30 |
+
if len(input_ids[0]) == 0:
|
31 |
+
return go.Figure(layout=dict(title="Please enter some text."))
|
32 |
+
|
33 |
+
if token_cutoff < 1:
|
34 |
+
return go.Figure(layout=dict(title="Please provide valid token cut off."))
|
35 |
+
|
36 |
+
fig = plot_lens(
|
37 |
+
model,
|
38 |
+
tokenizer,
|
39 |
+
lens_options_dict[lens],
|
40 |
+
layer_stride=1,
|
41 |
+
input_ids=input_ids,
|
42 |
+
start_pos=max(len(input_ids[0]) - token_cutoff, 0),
|
43 |
+
statistic=statistic_options_dict[statistic],
|
44 |
+
)
|
45 |
+
fig.update_layout(template="plotly_dark")
|
46 |
+
|
47 |
+
# Update the colorscale of the heatmap trace
|
48 |
+
for trace in fig.data:
|
49 |
+
if trace.type == "heatmap":
|
50 |
+
trace.update(colorscale="Inferno")
|
51 |
+
|
52 |
+
return fig
|
53 |
+
|
54 |
+
|
55 |
+
preamble = """
|
56 |
+
# The Tuned Lens π
|
57 |
+
|
58 |
+
A tuned lens allows us to peak at the iterative computations a transformer uses to compute the next token.
|
59 |
+
|
60 |
+
A lens into a transformer with n layers allows you to replace the last $m$ layers of the model with an [affine transformation](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) (we call these affine translators).
|
61 |
+
|
62 |
+
This essentially skips over these last few layers and lets you see the best prediction that can be made from the model's representations, i.e. the residual stream, at layer $n - m$. Since the representations may be rotated, shifted, or stretched from layer to layer it's useful to train the len's affine adapters specifically on each layer. This training is what differentiates this method from simpler approaches that decode the residual stream of the network directly using the unembeding layer i.e. the logit lens. We explain this process in [the paper](https://arxiv.org/abs/2303.08112).
|
63 |
+
"""
|
64 |
|
|
|
65 |
|
66 |
+
with gr.Blocks() as iface:
|
67 |
+
gr.Markdown(preamble)
|
68 |
+
with gr.Column():
|
69 |
+
text = gr.Textbox(
|
70 |
+
value="the iterative computations a transformer uses to compute the next",
|
71 |
+
label="Input Text",
|
72 |
+
)
|
73 |
+
with gr.Row():
|
74 |
+
lens_options = gr.Dropdown(
|
75 |
+
list(lens_options_dict.keys()), value="Tuned Lens", label="Select Lens"
|
76 |
+
)
|
77 |
+
statistic = gr.Dropdown(
|
78 |
+
list(statistic_options_dict.keys()),
|
79 |
+
value="Entropy",
|
80 |
+
label="Select Statistic",
|
81 |
+
)
|
82 |
+
token_cutoff = gr.Slider(
|
83 |
+
maximum=20, minimum=2, value=10, label="Token Cut Off"
|
84 |
+
)
|
85 |
+
examine_btn = gr.Button(value="Examine")
|
86 |
+
plot = gr.Plot()
|
87 |
+
examine_btn.click(make_plot, [lens_options, text, statistic, token_cutoff], plot)
|
88 |
|
|
|
89 |
iface.launch()
|