from typing import List, Optional

import networkx as nx
import torch

import llm_transparency_tool.routes.contributions as contributions
from llm_transparency_tool.models.transparent_llm import TransparentLlm


class GraphBuilder:
    """
    Constructs the contributions graph with edges given one by one. The resulting
    graph is a networkx graph that can be accessed via the `graph` field. It
    contains the following types of nodes:

    - X0_<token>: the original token.
    - A<layer>_<token>: the residual stream after the attention block at the given
      layer for the given token.
    - M<layer>_<token>: the FFN (MLP) block.
    - I<layer>_<token>: the residual stream after the FFN block.
    """
|
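    # Illustrative sketch of how the builder is used (hypothetical example, the
    # weights are made up):
    #
    #     builder = GraphBuilder(n_layers=2, n_tokens=3)
    #     builder.add_attention_edge(layer=0, token_from=0, token_to=1, w=0.4)
    #     builder.add_residual_to_attn(layer=0, token=1, w=0.6)
    #     builder.get_output_node(1)  # "I1_1": the residual stream after the last layer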
|
    def __init__(self, n_layers: int, n_tokens: int):
        self._n_layers = n_layers
        self._n_tokens = n_tokens

        self.graph = nx.DiGraph()
        for layer in range(n_layers):
            for token in range(n_tokens):
                self.graph.add_node(f"A{layer}_{token}")
                self.graph.add_node(f"I{layer}_{token}")
                self.graph.add_node(f"M{layer}_{token}")
        for token in range(n_tokens):
            self.graph.add_node(f"X0_{token}")

    def get_output_node(self, token: int):
        return f"I{self._n_layers - 1}_{token}"

    def _add_edge(self, u: str, v: str, weight: float):
        # Contributions that land on the same pair of nodes (for example, attention
        # from a token to itself and the residual connection) share one edge, so
        # their weights are accumulated.
        if self.graph.has_edge(u, v):
            self.graph[u][v]["weight"] += weight
        else:
            self.graph.add_edge(u, v, weight=weight)

    def add_attention_edge(self, layer: int, token_from: int, token_to: int, w: float):
        self._add_edge(
            f"I{layer-1}_{token_from}" if layer > 0 else f"X0_{token_from}",
            f"A{layer}_{token_to}",
            w,
        )

    def add_residual_to_attn(self, layer: int, token: int, w: float):
        self._add_edge(
            f"I{layer-1}_{token}" if layer > 0 else f"X0_{token}",
            f"A{layer}_{token}",
            w,
        )

    def add_ffn_edge(self, layer: int, token: int, w: float):
        self._add_edge(f"A{layer}_{token}", f"M{layer}_{token}", w)
        self._add_edge(f"M{layer}_{token}", f"I{layer}_{token}", w)

    def add_residual_to_ffn(self, layer: int, token: int, w: float):
        self._add_edge(f"A{layer}_{token}", f"I{layer}_{token}", w)

|
|
@torch.no_grad()
def build_full_graph(
    model: TransparentLlm,
    batch_i: int = 0,
    renormalizing_threshold: Optional[float] = None,
) -> nx.Graph:
    """
    Build the contribution graph for all blocks of the model and all tokens.

    model: The transparent llm that has already run inference.
    batch_i: Which sentence to use from the batch that was given to the model.
    renormalizing_threshold: If specified, thresholding will be applied to the
        contributions: all contributions below the threshold are erased and the
        rest are renormalized.
    """
|
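    # Hedged usage sketch (assumes `model` is any TransparentLlm implementation for
    # which inference has already been run on at least one sentence; the threshold
    # value below is only an example, not a recommended setting):
    #
    #     graph = build_full_graph(model, batch_i=0, renormalizing_threshold=0.01)
    #     graph["X0_0"]["A0_0"]["weight"]  # combined attention + residual input to A0_0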
    n_layers = model.model_info().n_layers
    n_tokens = model.tokens()[batch_i].shape[0]

    builder = GraphBuilder(n_layers, n_tokens)

    for layer in range(n_layers):
        c_attn, c_resid_attn = contributions.get_attention_contributions(
            resid_pre=model.residual_in(layer)[batch_i].unsqueeze(0),
            resid_mid=model.residual_after_attn(layer)[batch_i].unsqueeze(0),
            decomposed_attn=model.decomposed_attn(batch_i, layer).unsqueeze(0),
        )
        if renormalizing_threshold is not None:
            c_attn, c_resid_attn = contributions.apply_threshold_and_renormalize(
                renormalizing_threshold, c_attn, c_resid_attn
            )
        # The inputs above were sliced to `batch_i` and re-batched with unsqueeze(0),
        # so the contribution tensors have a batch dimension of exactly 1: index it
        # with 0, not with batch_i.
        for token_from in range(n_tokens):
            for token_to in range(n_tokens):
                # Sum the per-head contributions of token_from into token_to.
                c = c_attn[0, token_to, token_from].sum().item()
                builder.add_attention_edge(layer, token_from, token_to, c)
        for token in range(n_tokens):
            builder.add_residual_to_attn(
                layer, token, c_resid_attn[0, token].item()
            )
|
        c_ffn, c_resid_ffn = contributions.get_mlp_contributions(
            resid_mid=model.residual_after_attn(layer)[batch_i].unsqueeze(0),
            resid_post=model.residual_out(layer)[batch_i].unsqueeze(0),
            mlp_out=model.ffn_out(layer)[batch_i].unsqueeze(0),
        )
        if renormalizing_threshold is not None:
            c_ffn, c_resid_ffn = contributions.apply_threshold_and_renormalize(
                renormalizing_threshold, c_ffn, c_resid_ffn
            )
        # Again, index 0 because the batch dimension of the contributions is 1.
        for token in range(n_tokens):
            builder.add_ffn_edge(layer, token, c_ffn[0, token].item())
            builder.add_residual_to_ffn(
                layer, token, c_resid_ffn[0, token].item()
            )

    return builder.graph

|
|
def build_paths_to_predictions(
    graph: nx.Graph,
    n_layers: int,
    n_tokens: int,
    starting_tokens: List[int],
    threshold: float,
) -> List[nx.Graph]:
    """
    Given the full graph, this function returns only the trees leading to the
    specified tokens. Edges with weight below `threshold` are ignored.
    """
|
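    # Hedged usage sketch (continues the hypothetical `graph`, `n_layers` and
    # `n_tokens` from build_full_graph above; the threshold is only an example):
    #
    #     trees = build_paths_to_predictions(
    #         graph, n_layers, n_tokens, starting_tokens=[n_tokens - 1], threshold=0.04
    #     )
    #     trees[0].number_of_edges()  # size of the tree feeding the last token's output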
    builder = GraphBuilder(n_layers, n_tokens)

    # Reverse the graph so that the DFS below can walk from an output node back
    # towards the input tokens, following only sufficiently strong edges.
    rgraph = graph.reverse()
    search_graph = nx.subgraph_view(
        rgraph, filter_edge=lambda u, v: rgraph[u][v]["weight"] > threshold
    )

    result = []
    for start in starting_tokens:
        assert 0 <= start < n_tokens
        edges = nx.edge_dfs(search_graph, source=builder.get_output_node(start))
        tree = search_graph.edge_subgraph(edges)
        # Reverse back so that the edges of the returned tree point in the
        # original direction, from inputs towards the output.
        result.append(tree.reverse())

    return result
|
|