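"""Stream beam-search hypotheses to per-beam text files with MultiBeamTextStreamer.

Each active beam is mirrored to beam_<idx>.txt in the output directory, and
every finished hypothesis is archived under closed_beams/.
"""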
import argparse
import os
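# NOTE: MultiBeamTextStreamer is not part of every transformers release;
# this import assumes a build that provides it.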
from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM
class BeamOutputManager:
"""Manages file handlers for beam outputs"""
def __init__(self, output_dir: str, num_beams: int):
self.output_dir = output_dir
self.num_beams = num_beams
self.counter = 0
# Create main output directory and closed beams directory
os.makedirs(output_dir, exist_ok=True)
self.closed_beams_dir = os.path.join(output_dir, "closed_beams")
os.makedirs(self.closed_beams_dir, exist_ok=True)
# Store complete text for each beam
self.beam_texts = {i: "" for i in range(num_beams)}
self.active_beams = set(range(num_beams))
# Initialize empty files
for beam_idx in range(num_beams):
filename = os.path.join(output_dir, f'beam_{beam_idx}.txt')
with open(filename, 'w', encoding='utf-8') as f:
f.write('')
def write_to_beam(self, beam_idx: int, text: str):
"""Write text to the specified beam's file"""
if 0 <= beam_idx < self.num_beams and beam_idx in self.active_beams:
# Update stored text
self.beam_texts[beam_idx] = text
# Write complete text to file
filename = os.path.join(self.output_dir, f'beam_{beam_idx}.txt')
with open(filename, 'w', encoding='utf-8') as f:
f.write(self.beam_texts[beam_idx])
def finalize_beam(self, final_text: str):
"""
Handle a completed beam by creating a new file in the closed_beams directory.
Args:
final_text (str): The complete text generated by the finished beam
"""
        # Use a monotonically increasing counter so completed-beam filenames stay unique
self.counter += 1
filename = os.path.join(self.closed_beams_dir, f'completed_beam_{self.counter}.txt')
# Write the final text to the completed beam file
with open(filename, 'w', encoding='utf-8') as f:
f.write(final_text)
return filename
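# Illustrative standalone usage of BeamOutputManager (hypothetical values):
#   manager = BeamOutputManager("beam_outputs", num_beams=2)
#   manager.write_to_beam(0, "partial hypothesis")       # rewrites beam_outputs/beam_0.txt
#   path = manager.finalize_beam("finished hypothesis")  # -> beam_outputs/closed_beams/completed_beam_1.txt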
def setup_model_and_tokenizer(model_name):
"""
Initialize the model and tokenizer.
Args:
model_name (str): Name of the model to use
Returns:
tuple: (model, tokenizer)
"""
tokenizer = AutoTokenizer.from_pretrained(model_name)
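    # device_map="auto" requires the `accelerate` package to be installed.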
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto"
)
return model, tokenizer
def generate_with_beam_search(model, tokenizer, user_prompt, output_dir, num_beams=5, max_new_tokens=512):
"""
Generate responses using beam search and write results to files.
Args:
model: The language model
tokenizer: The tokenizer
user_prompt (str): Input prompt
output_dir (str): Directory to save beam outputs
num_beams (int): Number of beams to use
max_new_tokens (int): Maximum number of new tokens to generate
"""
# Initialize the output manager
output_manager = BeamOutputManager(output_dir, num_beams)
def on_beam_update(beam_idx: int, new_text: str):
"""Handler for beam updates - write new text to file"""
output_manager.write_to_beam(beam_idx, new_text)
def on_beam_finished(final_text: str):
"""Handler for completed beams - create final output file"""
final_path = output_manager.finalize_beam(final_text)
print(f"\nCompleted beam saved to: {final_path}")
# Create messages format
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": user_prompt}
]
# Apply chat template
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
# Prepare inputs
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
# Initialize streamer with handlers
streamer = MultiBeamTextStreamer(
tokenizer=tokenizer,
num_beams=num_beams,
on_beam_update=on_beam_update,
on_beam_finished=on_beam_finished,
skip_prompt=True
)
# Generate with beam search
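    # (the return value is discarded; the streamer callbacks write all output to disk)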
model.generate(
**model_inputs,
num_beams=num_beams,
num_return_sequences=num_beams,
max_new_tokens=max_new_tokens,
output_scores=True,
return_dict_in_generate=True,
early_stopping=True,
streamer=streamer
)
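    # After generation the output directory holds beam_0.txt .. beam_{num_beams-1}.txt
    # with the latest text of each live beam, plus closed_beams/completed_beam_*.txt
    # for every hypothesis that finished.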
def main():
# Setup command line arguments
parser = argparse.ArgumentParser(description='Language Model Text Generation with Beam Search')
parser.add_argument('--model', type=str, default='Qwen/Qwen2.5-0.5B-Instruct',
help='Name of the model to use')
parser.add_argument('--num_beams', type=int, default=5,
help='Number of beams for beam search')
parser.add_argument('--max_tokens', type=int, default=512,
help='Maximum number of new tokens to generate')
parser.add_argument('--output_dir', type=str, default='beam_outputs',
help='Directory to save beam outputs')
args = parser.parse_args()
# Initialize model and tokenizer
model, tokenizer = setup_model_and_tokenizer(args.model)
# Interactive loop
while True:
prompt = input("\nEnter your prompt (or 'quit' to exit): ")
        if prompt.strip().lower() == 'quit':
break
generate_with_beam_search(
model,
tokenizer,
prompt,
args.output_dir,
num_beams=args.num_beams,
max_new_tokens=args.max_tokens
)
print(f"\nOutputs written to: {args.output_dir}/beam_*.txt")
if __name__ == "__main__":
main()
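# Example invocation (defaults shown explicitly; substitute this file's actual name):
#   python <this_script>.py --model Qwen/Qwen2.5-0.5B-Instruct --num_beams 5 \
#       --max_tokens 512 --output_dir beam_outputs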