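"""Stream beam-search hypotheses to per-beam text files as they are generated.

Each active beam is mirrored to beam_<idx>.txt in the output directory, and
every finished hypothesis is archived under closed_beams/. Example invocation
(assuming this script is saved as beam_search_demo.py; the filename is
illustrative):

    python beam_search_demo.py --model Qwen/Qwen2.5-0.5B-Instruct --num_beams 3
"""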
import argparse
import os

from transformers import AutoModelForCausalLM, AutoTokenizer, MultiBeamTextStreamer

class BeamOutputManager:
    """Manages file handlers for beam outputs."""

    def __init__(self, output_dir: str, num_beams: int):
        self.output_dir = output_dir
        self.num_beams = num_beams
        self.counter = 0

        # Create the main output directory and the closed-beams directory
        os.makedirs(output_dir, exist_ok=True)
        self.closed_beams_dir = os.path.join(output_dir, "closed_beams")
        os.makedirs(self.closed_beams_dir, exist_ok=True)

        # Store the complete text for each beam
        self.beam_texts = {i: "" for i in range(num_beams)}
        self.active_beams = set(range(num_beams))

        # Initialize empty files
        for beam_idx in range(num_beams):
            filename = os.path.join(output_dir, f'beam_{beam_idx}.txt')
            with open(filename, 'w', encoding='utf-8') as f:
                f.write('')

    def write_to_beam(self, beam_idx: int, text: str):
        """Write text to the specified beam's file."""
        if 0 <= beam_idx < self.num_beams and beam_idx in self.active_beams:
            # Update the stored text
            self.beam_texts[beam_idx] = text
            # Overwrite the file with the complete text so far
            filename = os.path.join(self.output_dir, f'beam_{beam_idx}.txt')
            with open(filename, 'w', encoding='utf-8') as f:
                f.write(self.beam_texts[beam_idx])
    def finalize_beam(self, final_text: str):
        """
        Handle a completed beam by writing its text to a new file in the
        closed_beams directory.

        Args:
            final_text (str): The complete text generated by the finished beam

        Returns:
            str: Path of the file the completed beam was written to
        """
        # Use an incrementing counter to give each completed-beam file a unique name
        self.counter += 1
        filename = os.path.join(self.closed_beams_dir, f'completed_beam_{self.counter}.txt')

        # Write the final text to the completed-beam file
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(final_text)

        return filename
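
# Minimal standalone sketch of BeamOutputManager, independent of generation
# (directory name and texts below are illustrative):
#
#     manager = BeamOutputManager('demo_outputs', num_beams=2)
#     manager.write_to_beam(0, 'partial hypothesis...')
#     path = manager.finalize_beam('a finished hypothesis')
#     # demo_outputs/beam_0.txt now mirrors the partial text;
#     # `path` points at demo_outputs/closed_beams/completed_beam_1.txt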

def setup_model_and_tokenizer(model_name):
    """
    Initialize the model and tokenizer.

    Args:
        model_name (str): Name of the model to use

    Returns:
        tuple: (model, tokenizer)
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto"
    )
    return model, tokenizer

def generate_with_beam_search(model, tokenizer, user_prompt, output_dir, num_beams=5, max_new_tokens=512):
    """
    Generate responses using beam search and write results to files.

    Args:
        model: The language model
        tokenizer: The tokenizer
        user_prompt (str): Input prompt
        output_dir (str): Directory to save beam outputs
        num_beams (int): Number of beams to use
        max_new_tokens (int): Maximum number of new tokens to generate
    """
    # Initialize the output manager
    output_manager = BeamOutputManager(output_dir, num_beams)

    def on_beam_update(beam_idx: int, new_text: str):
        """Handler for beam updates - write the updated text to the beam's file."""
        output_manager.write_to_beam(beam_idx, new_text)

    def on_beam_finished(final_text: str):
        """Handler for completed beams - create the final output file."""
        final_path = output_manager.finalize_beam(final_text)
        print(f"\nCompleted beam saved to: {final_path}")

    # Build the chat messages
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt}
    ]

    # Apply the chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Prepare inputs
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Initialize the streamer with the handlers
    streamer = MultiBeamTextStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        on_beam_update=on_beam_update,
        on_beam_finished=on_beam_finished,
        skip_prompt=True
    )
    # Generate with beam search; the streamer delivers text to the handlers
    # incrementally, so the return value is not used here
    model.generate(
        **model_inputs,
        num_beams=num_beams,
        num_return_sequences=num_beams,
        max_new_tokens=max_new_tokens,
        output_scores=True,
        return_dict_in_generate=True,
        early_stopping=True,
        streamer=streamer
    )
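
# Direct invocation sketch, bypassing the CLI below (prompt and directory are
# illustrative):
#
#     model, tokenizer = setup_model_and_tokenizer("Qwen/Qwen2.5-0.5B-Instruct")
#     generate_with_beam_search(model, tokenizer, "Explain beam search.",
#                               "beam_outputs", num_beams=3, max_new_tokens=128)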

def main():
    # Set up command-line arguments
    parser = argparse.ArgumentParser(description='Language Model Text Generation with Beam Search')
    parser.add_argument('--model', type=str, default='Qwen/Qwen2.5-0.5B-Instruct',
                        help='Name of the model to use')
    parser.add_argument('--num_beams', type=int, default=5,
                        help='Number of beams for beam search')
    parser.add_argument('--max_tokens', type=int, default=512,
                        help='Maximum number of new tokens to generate')
    parser.add_argument('--output_dir', type=str, default='beam_outputs',
                        help='Directory to save beam outputs')
    args = parser.parse_args()
    # Initialize the model and tokenizer
    model, tokenizer = setup_model_and_tokenizer(args.model)

    # Interactive loop
    while True:
        prompt = input("\nEnter your prompt (or 'quit' to exit): ")
        if prompt.strip().lower() == 'quit':
            break

        generate_with_beam_search(
            model,
            tokenizer,
            prompt,
            args.output_dir,
            num_beams=args.num_beams,
            max_new_tokens=args.max_tokens
        )
        print(f"\nOutputs written to: {args.output_dir}/beam_*.txt")

if __name__ == "__main__":
    main()