# llm-decoders / app.py
# Author: kgourgou — commit c816679 ("Update app.py", verified)
import gradio as gr
import torch
import concurrent.futures
from transformers import AutoTokenizer, AutoModelForCausalLM
# Limit PyTorch's CPU intra-op parallelism to two threads for all inference below.
torch.set_num_threads(2)

# Load the tokenizer and causal-LM weights (GPT-2 serves as the example checkpoint)
# and switch the model to inference mode (nn.Module.eval() returns the module).
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()
def min_p_sampling(logits, pbase=0.1):
    """
    Draw one token id using min-p truncation (https://arxiv.org/abs/2407.01082).

    Every token whose probability falls below ``pbase`` times the probability
    of the most likely token is discarded; the survivors are renormalised and
    a single index is sampled from them.

    Args:
        logits (torch.Tensor): next-token logits; softmax is taken over the
            last dimension. The caller in this file passes a ``(1, vocab)``
            slice. The maximum is taken over the whole tensor, so only a
            single row of logits is meaningful.
        pbase (float): fraction of the top probability used as the cut-off.

    Returns:
        int: the sampled token index.
    """
    distribution = torch.softmax(logits, dim=-1)
    # The cut-off scales with the model's confidence: a peaked distribution
    # prunes aggressively, a flat one keeps many candidates.
    cutoff = pbase * distribution.max()
    keep = distribution >= cutoff
    # Nothing survives only in degenerate cases (e.g. pbase > 1); then fall
    # back to sampling from the full distribution.
    if not keep.any():
        keep = torch.ones_like(distribution, dtype=torch.bool)
    truncated = distribution * keep.float()
    renormalised = truncated / truncated.sum()
    return torch.multinomial(renormalised, num_samples=1).item()
def generate_laconic_completion(prompt: str, n: int = 5, max_length: int = 100) -> str:
    """
    Laconic decoding: sample ``n`` completions and return the shortest one.

    The candidates are drawn with sampling (``do_sample=True``), not greedy
    decoding, so that the ``n`` returned sequences can differ.

    Args:
        prompt: text to continue.
        n: number of sampled candidates.
        max_length: total length in tokens (prompt included) passed to
            ``model.generate``.

    Returns:
        The shortest decoded candidate (prompt included, special tokens
        removed). Length is measured in characters. Every candidate starts
        with the same prompt, so this effectively compares the continuations.
    """
    with torch.no_grad():
        encoded = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=max_length,
            num_return_sequences=n,
            do_sample=True,
            # Candidates that stop early are padded to the longest one. Use the
            # tokenizer's pad token if it defines one, else EOS (the value
            # generate() itself falls back to, with a warning, when unset).
            pad_token_id=(
                tokenizer.pad_token_id
                if tokenizer.pad_token_id is not None
                else tokenizer.eos_token_id
            ),
        )
    completions = [
        tokenizer.decode(sequence, skip_special_tokens=True) for sequence in outputs
    ]
    # min() keeps the first candidate among equally short ones.
    return min(completions, key=len)
def generate_with_confidence(input_ids, max_length):
    """
    Generate a sequence using greedy decoding while returning the scores.

    Args:
        input_ids (torch.Tensor): prompt token ids of shape ``(batch, seq)``,
            assumed unpadded. The caller in this file (``cot_decoding``)
            passes a single sequence, and every position is attended to.
        max_length (int): total sequence length (input included) for
            ``model.generate``.

    Returns:
        The ``model.generate`` output object (``return_dict_in_generate=True``)
        exposing ``sequences`` and the per-step ``scores``.
    """
    # Pass the mask and pad id explicitly rather than letting generate() infer
    # them. When neither is set it has to guess, and it warns that results may
    # be unreliable.
    attention_mask = torch.ones_like(input_ids)
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    return model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        do_sample=False,
        output_scores=True,
        return_dict_in_generate=True,
        pad_token_id=pad_token_id,
    )
def compute_answer_confidence(outputs, use_probabilities=False):
    """
    Average the top-1 vs top-2 margin over the generated tokens.

    For each generation step in ``outputs.scores``, take the first sequence
    of the batch (index ``[0]``). Subtract its second-highest score from its
    highest score, then average these margins over all steps.

    Args:
        outputs: object with a ``scores`` attribute, a sequence of per-step
            score tensors such as ``model.generate(..., output_scores=True,
            return_dict_in_generate=True)`` returns.
        use_probabilities (bool): when False (default), the margin is taken
            between raw scores (logits), as before. When True, each step's
            scores are softmax-normalised first. This gives the
            probability-gap confidence used in "Chain-of-Thought Reasoning
            without Prompting" (Wang & Zhou, 2024).

    Returns:
        float: the mean margin, or 0.0 when no tokens were generated.
    """
    diffs = []
    for score in outputs.scores:
        step = score[0]
        if use_probabilities:
            step = torch.softmax(step, dim=-1)
        top2 = torch.topk(step, 2).values
        diffs.append((top2[0] - top2[1]).item())
    return sum(diffs) / len(diffs) if diffs else 0.0
def cot_decoding(prompt, k=5, max_length=100):
    """
    Chain-of-Thought decoding: branch on the top-k first tokens, continue each
    branch greedily, and return the branch with the highest confidence.

    Args:
        prompt (str): text to continue.
        k (int): number of alternative first tokens to explore.
        max_length (int): at most this many tokens are generated after the
            branching token.

    Returns:
        str: decoded text (prompt included, special tokens removed) of the
        branch whose ``compute_answer_confidence`` score is highest. On ties,
        the earliest (highest-ranked first token) branch wins.
    """
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt")
    with torch.no_grad():
        next_token_logits = model(prompt_ids).logits[0, -1, :]
    first_tokens = torch.topk(next_token_logits, k).indices

    candidates = []
    for first_token in first_tokens:
        # Force this alternative as the first generated token, then let
        # greedy decoding finish the branch.
        branch_ids = torch.cat([prompt_ids, first_token.view(1, 1)], dim=1)
        branch = generate_with_confidence(
            branch_ids, max_length=branch_ids.shape[1] + max_length
        )
        candidates.append(
            (
                tokenizer.decode(branch.sequences[0], skip_special_tokens=True),
                compute_answer_confidence(branch),
            )
        )

    best_text, _ = max(candidates, key=lambda pair: pair[1])
    return best_text
def generate_completion(prompt, strategy, params, max_length=100):
    """
    Generate one complete answer with ``model.generate``.

    Args:
        prompt (str): text to continue.
        strategy (str): display name of the decoding method. It does not
            affect generation; it is accepted so all strategies can be
            submitted uniformly.
        params (dict): keyword arguments forwarded to ``model.generate``
            (e.g. ``{"do_sample": True, "top_k": 100}``). Entries here
            override the defaults below, so a strategy may set its own
            ``max_length`` or ``pad_token_id``.
        max_length (int): default total length in tokens (prompt included).

    Returns:
        str: the first returned sequence, decoded with special tokens removed.
    """
    # Merge defaults under the strategy's params. Passing max_length both
    # explicitly and via **params would raise "got multiple values for
    # keyword argument 'max_length'".
    generate_kwargs = {
        "max_length": max_length,
        # Use the tokenizer's pad token if it has one, else EOS (the value
        # generate() falls back to, with a warning, when unset).
        "pad_token_id": (
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        ),
        **params,
    }
    with torch.no_grad():
        encoded = tokenizer(prompt, return_tensors="pt")
        output_ids = model.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            **generate_kwargs,
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
def generate_min_p_completion(prompt, pbase=0.1, max_length=100):
    """
    Sample a continuation token by token with min-p sampling, reusing the
    model's key/value cache between steps.

    Args:
        prompt (str): text to continue.
        pbase (float): min-p base fraction forwarded to ``min_p_sampling``.
        max_length (int): total length budget in tokens, prompt included.
            If the prompt already meets or exceeds it, no tokens are generated.

    Returns:
        str: decoded prompt plus continuation, special tokens removed.
        Generation stops early once the EOS token is sampled.
    """
    token_ids = tokenizer.encode(prompt, return_tensors="pt")
    cache = None
    remaining = max_length - token_ids.size(1)
    with torch.no_grad():
        for _ in range(remaining):
            if cache is None:
                # First step: run the whole prompt to build the cache.
                step = model(token_ids)
            else:
                # Later steps: only the newest token is new; the rest is cached.
                step = model(token_ids[:, -1:], past_key_values=cache)
            cache = step.past_key_values
            chosen = min_p_sampling(step.logits[:, -1, :], pbase=pbase)
            token_ids = torch.cat([token_ids, torch.tensor([[chosen]])], dim=-1)
            if chosen == tokenizer.eos_token_id:
                break
    return tokenizer.decode(token_ids[0], skip_special_tokens=True)
def generate_all(prompt):
    """
    Run multiple decoding strategies concurrently and yield updates as each completes.

    Yields:
        tuple[str, ...]: one string per method, ordered as ``method_order``.
        Methods still running show "Processing..."; a method whose runner
        raised shows "Error: <exception>".

    Raises:
        ValueError: if a method entry names an unknown runner ``type``. This
            is a configuration error in the table below.
    """
    # Each decoding strategy: "type" selects the runner function; the
    # remaining keys are that runner's arguments.
    methods = {
        "Greedy": {"type": "default", "params": {"do_sample": False}},
        "Top-k Sampling": {
            "type": "default",
            "params": {"do_sample": True, "top_k": 100},
        },
        "Top-p Sampling": {
            "type": "default",
            "params": {"do_sample": True, "top_p": 0.95},
        },
        "Beam Search": {
            "type": "default",
            "params": {"num_beams": 5, "early_stopping": True},
        },
        "Eta Sampling": {
            "type": "default",
            "params": {"do_sample": True, "eta_cutoff": 0.3},
        },
        "Epsilon Sampling": {
            "type": "default",
            "params": {"do_sample": True, "epsilon_cutoff": 0.2},
        },
        "Min-p Sampling": {"type": "min_p", "pbase": 0.1},
        # Previously typed "default", which routed laconic through
        # generate_completion and showed only the first of 5 plain samples;
        # its own runner (shortest of n samples) was never reached.
        "laconic": {
            "type": "laconic",
            "params": {"n": 5, "max_length": 100},
        },
        "COT Decoding": {
            "type": "cot_decoding",
            "params": {"k": 5, "max_length": 100},
        },
    }
    # Display order; must match the order of the Gradio output boxes.
    method_order = [
        "Greedy",
        "Top-k Sampling",
        "Top-p Sampling",
        "Beam Search",
        "Min-p Sampling",
        "Eta Sampling",
        "Epsilon Sampling",
        "laconic",
        "COT Decoding",
    ]
    results = {method: None for method in methods}
    # Yield an initial placeholder state.
    yield tuple("Processing..." for _ in method_order)
    # Use a thread pool to run each generation concurrently.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_method = {}
        for method, info in methods.items():
            kind = info["type"]
            if kind == "default":
                future = executor.submit(
                    generate_completion, prompt, method, info["params"]
                )
            elif kind == "min_p":
                future = executor.submit(
                    generate_min_p_completion, prompt, info["pbase"]
                )
            elif kind == "laconic":
                future = executor.submit(
                    generate_laconic_completion, prompt, **info["params"]
                )
            elif kind == "cot_decoding":
                future = executor.submit(cot_decoding, prompt, **info["params"])
            else:
                # Without this, `future` would still hold the previous
                # method's job and get mapped to this method by mistake.
                raise ValueError(f"Unknown decoding type {kind!r} for {method!r}")
            future_to_method[future] = method
        # As each future completes, update its result and yield the current state.
        for future in concurrent.futures.as_completed(future_to_method):
            method = future_to_method[future]
            try:
                result = future.result()
            except Exception as exc:
                # Show the failure in that method's box instead of aborting
                # the whole comparison.
                result = f"Error: {exc}"
            results[method] = result
            # Yield the results in the pre-defined order; pending methods show "Processing..."
            yield tuple(
                results[m] if results[m] is not None else "Processing..."
                for m in method_order
            )
# Build the Gradio UI. The output boxes are listed in the same order as the
# tuples yielded by generate_all (its `method_order`).
interface = gr.Interface(
    fn=generate_all,
    inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt"),
    outputs=[
        gr.Textbox(label=label)
        for label in (
            "Greedy",
            "Top-k Sampling",
            "Top-p Sampling",
            "Beam Search",
            "Min-p Sampling (as in https://arxiv.org/abs/2407.01082)",
            "Eta Sampling",
            "Epsilon Sampling",
            "laconic decoding (by Alex Dimakis, 2025, search for twitter thread)",
            "COT Decoding (Chain-of-Thought Reasoning without Prompting, Wang, Zhou, 2024)",
        )
    ],
    title="Decoding Methods Comparison",
    description="Each decoding method's final answer is printed as soon as it is done. Model used: GPT-2.",
)

if __name__ == "__main__":
    # Start the web app only when executed as a script, not on import.
    interface.launch()