import concurrent.futures

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the model and tokenizer (using GPT-2 as an example).
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
torch.set_num_threads(2)

def min_p_sampling(logits, pbase=0.1):
    """
    Perform min-p sampling on the logits, as described in
    https://arxiv.org/abs/2407.01082.

    Args:
        logits (torch.Tensor): Logits for the next token, either 1D or of
            shape [1, vocab_size].
        pbase (float): Base probability used to scale pmax into the threshold.

    Returns:
        int: The sampled token index.
    """
    # Convert logits to probabilities.
    probs = torch.softmax(logits, dim=-1)
    # 1. Find the maximum probability.
    pmax = probs.max()
    # 2. Compute the dynamic threshold.
    pscaled = pbase * pmax
    # 3. Keep only tokens with probability >= pscaled.
    mask = probs >= pscaled
    # In the unlikely event that no token meets the threshold, fall back to the full distribution.
    if mask.sum() == 0:
        mask = torch.ones_like(probs, dtype=torch.bool)
    probs_filtered = probs * mask.float()
    # 4. Renormalize and sample.
    probs_normalized = probs_filtered / probs_filtered.sum()
    sampled_index = torch.multinomial(probs_normalized, num_samples=1)
    return sampled_index.item()

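# Illustrative sketch (toy numbers, not used by the app): with pbase = 0.1 the
# threshold scales with how peaked the distribution is, so a confident
# distribution keeps only a few tokens while a flat one keeps many.
#
#   toy_logits = torch.tensor([4.0, 3.5, 1.0, -2.0])
#   probs = torch.softmax(toy_logits, dim=-1)  # ~[0.60, 0.37, 0.03, 0.00]
#   threshold = 0.1 * probs.max()              # ~0.06
#   probs >= threshold                         # keeps only the first two tokens
#   min_p_sampling(toy_logits, pbase=0.1)      # returns 0 or 1; indices 2 and 3 are filtered out
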
def generate_laconic_completion(prompt: str, n: int = 5, max_length: int = 100):
    # Sample n completions and return the shortest one ("laconic" decoding).
    with torch.no_grad():
        # Encode the prompt and get the attention mask.
        encoded = tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]
        # Sample n sequences in a single generate call.
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=n,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        completions = [
            tokenizer.decode(output, skip_special_tokens=True) for output in outputs
        ]
        return min(completions, key=len)

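# Hypothetical usage (not executed at import time): sampling several completions
# and keeping the shortest biases the output toward answers that terminate early, e.g.
#   generate_laconic_completion("Q: What is 2 + 2?\nA:", n=5, max_length=60)
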
def generate_with_confidence(input_ids, max_length):
    """
    Generate a sequence using greedy decoding while returning the per-step scores.
    """
    outputs = model.generate(
        input_ids,
        max_length=max_length,
        do_sample=False,
        output_scores=True,
        return_dict_in_generate=True,
    )
    return outputs

def compute_answer_confidence(outputs):
    """
    Compute the answer confidence over the generated tokens.
    For each generated token, take the difference between the top-1 and top-2
    logits; return the average difference across steps.
    """
    diffs = []
    for score in outputs.scores:
        # Get the top-2 logit values for this step.
        top2 = torch.topk(score[0], 2)
        diff = top2.values[0] - top2.values[1]
        diffs.append(diff.item())
    return sum(diffs) / len(diffs) if diffs else 0.0

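# Worked example with made-up numbers: if the per-step top-2 logits are
# (9.1, 6.3), (5.0, 4.8) and (7.2, 3.2), the margins are 2.8, 0.2 and 4.0, so the
# answer confidence is their mean, 7.0 / 3 ≈ 2.33. Larger margins mean the model
# was less torn between its best and second-best token at each step.
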
def cot_decoding(prompt, k=5, max_length=100):
    """
    Perform Chain-of-Thought (CoT) decoding by branching on the top-k candidate
    first tokens and keeping the most confident continuation.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # Get logits for the next token.
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits[0, -1, :]
    # Get the top-k candidate first tokens.
    topk = torch.topk(logits, k)
    candidate_tokens = topk.indices
    paths = []
    for token in candidate_tokens:
        # Append the candidate token to the prompt.
        new_input_ids = torch.cat([input_ids, token.view(1, 1)], dim=1)
        # Greedily generate a full sequence, keeping the per-step scores.
        gen_outputs = generate_with_confidence(
            new_input_ids, max_length=new_input_ids.shape[1] + max_length
        )
        # Decode the generated sequence.
        generated_text = tokenizer.decode(
            gen_outputs.sequences[0], skip_special_tokens=True
        )
        # Compute the answer confidence for this path.
        confidence = compute_answer_confidence(gen_outputs)
        paths.append({"text": generated_text, "confidence": confidence})
    return max(paths, key=lambda x: x["confidence"])["text"]

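# Illustrative call (assumed prompt, not run by the app): branch on the k most
# likely first tokens, then keep the continuation whose greedy path has the
# largest average top-1 vs. top-2 margin, e.g.
#   cot_decoding("Q: I have 3 apples and eat one. How many are left?\nA:", k=5)
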
def generate_completion(prompt, strategy, params):
    """
    Generate a completion with model.generate using the given parameters.
    `strategy` names the decoding method for bookkeeping; it does not affect generation.
    """
    with torch.no_grad():
        # Encode the prompt and get the attention mask.
        encoded = tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]
        # Generate the output.
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=100,
            pad_token_id=tokenizer.eos_token_id,
            **params,
        )
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)

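# Hypothetical usage: the caller passes a dict of model.generate keyword arguments, e.g.
#   generate_completion("The capital of France is", "Top-k Sampling",
#                       {"do_sample": True, "top_k": 100})
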
def generate_min_p_completion(prompt, pbase=0.1, max_length=100):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    past = None
    with torch.no_grad():
        for _ in range(max_length - input_ids.size(1)):
            # Once the KV cache exists, only the last token needs to be fed in.
            outputs = (
                model(input_ids[:, -1:], past_key_values=past)
                if past is not None
                else model(input_ids)
            )
            past = outputs.past_key_values
            logits = outputs.logits[:, -1, :]
            next_token = min_p_sampling(logits, pbase=pbase)
            input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=-1)
            if next_token == tokenizer.eos_token_id:
                break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

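# Note on the loop above: after the first step it feeds only the newest token back
# into the model and reuses past_key_values (the KV cache), so each step avoids
# re-encoding the full prefix. Hypothetical usage:
#   generate_min_p_completion("Once upon a time", pbase=0.1, max_length=60)
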
def generate_all(prompt):
    """
    Run multiple decoding strategies concurrently and yield updates as each completes.
    """
    # Define each decoding strategy and its parameters.
    methods = {
        "Greedy": {"type": "default", "params": {"do_sample": False}},
        "Top-k Sampling": {
            "type": "default",
            "params": {"do_sample": True, "top_k": 100},
        },
        "Top-p Sampling": {
            "type": "default",
            "params": {"do_sample": True, "top_p": 0.95},
        },
        "Beam Search": {
            "type": "default",
            "params": {"num_beams": 5, "early_stopping": True},
        },
        "Eta Sampling": {
            "type": "default",
            "params": {"do_sample": True, "eta_cutoff": 0.3},
        },
        "Epsilon Sampling": {
            "type": "default",
            "params": {"do_sample": True, "epsilon_cutoff": 0.2},
        },
        "Min-p Sampling": {"type": "min_p", "pbase": 0.1},
        # Laconic decoding samples its own sequences inside generate_laconic_completion,
        # so it gets its own type rather than going through generate_completion.
        "laconic": {"type": "laconic"},
        "COT Decoding": {
            "type": "cot_decoding",
            "params": {"k": 5, "max_length": 100},
        },
    }
    # Define the display order.
    method_order = [
        "Greedy",
        "Top-k Sampling",
        "Top-p Sampling",
        "Beam Search",
        "Min-p Sampling",
        "Eta Sampling",
        "Epsilon Sampling",
        "laconic",
        "COT Decoding",
    ]
    results = {method: None for method in methods}
    # Yield an initial placeholder state.
    yield tuple("Processing..." for _ in method_order)
    # Use a thread pool to run each generation concurrently.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_method = {}
        for method, info in methods.items():
            if info["type"] == "min_p":
                future = executor.submit(
                    generate_min_p_completion, prompt, info["pbase"]
                )
            elif info["type"] == "laconic":
                future = executor.submit(generate_laconic_completion, prompt)
            elif info["type"] == "cot_decoding":
                future = executor.submit(cot_decoding, prompt, **info["params"])
            else:  # "default": a plain model.generate call with the given params.
                future = executor.submit(
                    generate_completion, prompt, method, info["params"]
                )
            future_to_method[future] = method
        # As each future completes, store its result and yield the current state.
        for future in concurrent.futures.as_completed(future_to_method):
            method = future_to_method[future]
            try:
                result = future.result()
            except Exception as exc:
                result = f"Error: {exc}"
            results[method] = result
            # Yield results in the pre-defined order; pending methods show "Processing...".
            yield tuple(
                results[m] if results[m] is not None else "Processing..."
                for m in method_order
            )

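# Note: generate_all is a generator; Gradio treats each yielded tuple as a
# streaming update, so every textbox flips from "Processing..." to its result
# as soon as the corresponding future finishes.
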
# Create the Gradio interface.
interface = gr.Interface(
    fn=generate_all,
    inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt"),
    outputs=[
        gr.Textbox(label="Greedy"),
        gr.Textbox(label="Top-k Sampling"),
        gr.Textbox(label="Top-p Sampling"),
        gr.Textbox(label="Beam Search"),
        gr.Textbox(label="Min-p Sampling (as in https://arxiv.org/abs/2407.01082)"),
        gr.Textbox(label="Eta Sampling"),
        gr.Textbox(label="Epsilon Sampling"),
        gr.Textbox(
            label="Laconic decoding (Alex Dimakis, 2025; described in a Twitter/X thread)"
        ),
        gr.Textbox(
            label="CoT Decoding (Chain-of-Thought Reasoning without Prompting, Wang & Zhou, 2024)"
        ),
    ],
    title="Decoding Methods Comparison",
    description="Each decoding method's final answer appears as soon as it finishes. Model used: GPT-2.",
)

if __name__ == "__main__":
    interface.launch()