import argparse
import os

from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM


class BeamOutputManager:
    """Manages file handlers for beam outputs.

    One text file per beam lives under ``output_dir`` and is rewritten in
    full on every streamer update; finished beams are archived as
    counter-numbered files under ``output_dir/closed_beams``.
    """

    def __init__(self, output_dir: str, num_beams: int):
        self.output_dir = output_dir
        self.num_beams = num_beams
        # Monotonic counter used to give each completed-beam file a unique name.
        self.counter = 0

        # Create main output directory and closed beams directory
        os.makedirs(output_dir, exist_ok=True)
        self.closed_beams_dir = os.path.join(output_dir, "closed_beams")
        os.makedirs(self.closed_beams_dir, exist_ok=True)

        # Store complete text for each beam
        self.beam_texts = {i: "" for i in range(num_beams)}
        self.active_beams = set(range(num_beams))

        # Initialize empty files so beam_*.txt exist even before the first update
        for beam_idx in range(num_beams):
            filename = os.path.join(output_dir, f'beam_{beam_idx}.txt')
            with open(filename, 'w', encoding='utf-8') as f:
                f.write('')

    def write_to_beam(self, beam_idx: int, text: str):
        """Write text to the specified beam's file.

        Out-of-range or inactive beam indices are silently ignored so that
        streamer callbacks never raise mid-generation.

        Args:
            beam_idx (int): Index of the beam to update.
            text (str): Full accumulated text for that beam (not a delta).
        """
        if 0 <= beam_idx < self.num_beams and beam_idx in self.active_beams:
            # Update stored text
            self.beam_texts[beam_idx] = text
            # Overwrite (not append): the streamer supplies the complete
            # accumulated text on every update.
            filename = os.path.join(self.output_dir, f'beam_{beam_idx}.txt')
            with open(filename, 'w', encoding='utf-8') as f:
                f.write(self.beam_texts[beam_idx])

    def finalize_beam(self, final_text: str):
        """
        Handle a completed beam by creating a new file in the closed_beams directory.

        Args:
            final_text (str): The complete text generated by the finished beam

        Returns:
            str: Path of the file the completed beam was written to.
        """
        # Counter-based filename (incremented per completed beam) to ensure
        # uniqueness within this run.  NOTE: the previous comment claimed this
        # was timestamp-based, which did not match the code.
        self.counter += 1
        filename = os.path.join(self.closed_beams_dir, f'completed_beam_{self.counter}.txt')

        # Write the final text to the completed beam file
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(final_text)

        return filename


def setup_model_and_tokenizer(model_name: str):
    """
    Initialize the model and tokenizer.

    Args:
        model_name (str): Name of the model to use

    Returns:
        tuple: (model, tokenizer)
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto"
    )
    return model, tokenizer


def generate_with_beam_search(model, tokenizer, user_prompt, output_dir, num_beams=5, max_new_tokens=512):
    """
    Generate responses using beam search and write results to files.

    Args:
        model: The language model
        tokenizer: The tokenizer
        user_prompt (str): Input prompt
        output_dir (str): Directory to save beam outputs
        num_beams (int): Number of beams to use
        max_new_tokens (int): Maximum number of new tokens to generate
    """
    # Initialize the output manager
    output_manager = BeamOutputManager(output_dir, num_beams)

    def on_beam_update(beam_idx: int, new_text: str):
        """Handler for beam updates - write new text to file"""
        output_manager.write_to_beam(beam_idx, new_text)

    def on_beam_finished(final_text: str):
        """Handler for completed beams - create final output file"""
        final_path = output_manager.finalize_beam(final_text)
        print(f"\nCompleted beam saved to: {final_path}")

    # Create messages format
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt}
    ]

    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Prepare inputs
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Initialize streamer with handlers; skip_prompt keeps the chat prompt
    # out of the per-beam output files.
    streamer = MultiBeamTextStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        on_beam_update=on_beam_update,
        on_beam_finished=on_beam_finished,
        skip_prompt=True
    )

    # Generate with beam search
    model.generate(
        **model_inputs,
        num_beams=num_beams,
        num_return_sequences=num_beams,
        max_new_tokens=max_new_tokens,
        output_scores=True,
        return_dict_in_generate=True,
        early_stopping=True,
        streamer=streamer
    )


def main():
    # Setup command line arguments
    parser = argparse.ArgumentParser(description='Language Model Text Generation with Beam Search')
    parser.add_argument('--model', type=str, default='Qwen/Qwen2.5-0.5B-Instruct',
                        help='Name of the model to use')
    parser.add_argument('--num_beams', type=int, default=5,
                        help='Number of beams for beam search')
    parser.add_argument('--max_tokens', type=int, default=512,
                        help='Maximum number of new tokens to generate')
    parser.add_argument('--output_dir', type=str, default='beam_outputs',
                        help='Directory to save beam outputs')
    args = parser.parse_args()

    # Initialize model and tokenizer
    model, tokenizer = setup_model_and_tokenizer(args.model)

    # Interactive loop
    while True:
        prompt = input("\nEnter your prompt (or 'quit' to exit): ")
        if prompt.lower() == 'quit':
            break
        generate_with_beam_search(
            model,
            tokenizer,
            prompt,
            args.output_dir,
            num_beams=args.num_beams,
            max_new_tokens=args.max_tokens
        )
        print(f"\nOutputs written to: {args.output_dir}/beam_*.txt")


if __name__ == "__main__":
    main()