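"""
Stream beam-search hypotheses to files with MultiBeamTextStreamer.

Each beam's partial text is rewritten to beam_<idx>.txt in the output
directory as generation proceeds; finished beams are saved under
closed_beams/. Example invocation (the script name is illustrative):

    python beam_search_to_files.py --num_beams 3 --max_tokens 256
"""
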
import argparse
import os

from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM


class BeamOutputManager:
    """Manages file handlers for beam outputs"""

    def __init__(self, output_dir: str, num_beams: int):
        self.output_dir = output_dir
        self.num_beams = num_beams
        self.counter = 0
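
        # Create the output directory and a subdirectory for completed beams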
        os.makedirs(output_dir, exist_ok=True)
        self.closed_beams_dir = os.path.join(output_dir, "closed_beams")
        os.makedirs(self.closed_beams_dir, exist_ok=True)
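
        # Track each beam's latest text and which beams are still active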
        self.beam_texts = {i: "" for i in range(num_beams)}
        self.active_beams = set(range(num_beams))
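
        # Start each beam with an empty output file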
        for beam_idx in range(num_beams):
            filename = os.path.join(output_dir, f'beam_{beam_idx}.txt')
            with open(filename, 'w', encoding='utf-8') as f:
                f.write('')

    def write_to_beam(self, beam_idx: int, text: str):
        """Write text to the specified beam's file"""
        if 0 <= beam_idx < self.num_beams and beam_idx in self.active_beams:
            self.beam_texts[beam_idx] = text
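
            # Overwrite the beam's file with the full updated text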
            filename = os.path.join(self.output_dir, f'beam_{beam_idx}.txt')
            with open(filename, 'w', encoding='utf-8') as f:
                f.write(self.beam_texts[beam_idx])

    def finalize_beam(self, final_text: str):
        """
        Handle a completed beam by creating a new file in the closed_beams directory.

        Args:
            final_text (str): The complete text generated by the finished beam

        Returns:
            str: Path of the file the completed beam was saved to
        """
        self.counter += 1
        filename = os.path.join(self.closed_beams_dir, f'completed_beam_{self.counter}.txt')

        with open(filename, 'w', encoding='utf-8') as f:
            f.write(final_text)

        return filename


def setup_model_and_tokenizer(model_name):
    """
    Initialize the model and tokenizer.

    Args:
        model_name (str): Name of the model to use

    Returns:
        tuple: (model, tokenizer)
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto"
    )
    return model, tokenizer


def generate_with_beam_search(model, tokenizer, user_prompt, output_dir, num_beams=5, max_new_tokens=512):
    """
    Generate responses using beam search and write results to files.

    Args:
        model: The language model
        tokenizer: The tokenizer
        user_prompt (str): Input prompt
        output_dir (str): Directory to save beam outputs
        num_beams (int): Number of beams to use
        max_new_tokens (int): Maximum number of new tokens to generate
    """
    output_manager = BeamOutputManager(output_dir, num_beams)

    def on_beam_update(beam_idx: int, new_text: str):
        """Handler for beam updates - write new text to file"""
        output_manager.write_to_beam(beam_idx, new_text)

    def on_beam_finished(final_text: str):
        """Handler for completed beams - create final output file"""
        final_path = output_manager.finalize_beam(final_text)
        print(f"\nCompleted beam saved to: {final_path}")
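
    # Build the chat-formatted prompt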
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt}
    ]
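
    # Apply the model's chat template, then tokenize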
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
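
    # Route each beam's partial text to the callbacks defined above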
    streamer = MultiBeamTextStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        on_beam_update=on_beam_update,
        on_beam_finished=on_beam_finished,
        skip_prompt=True
    )
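
    # Run beam search; the streamer emits updates while generation runs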
    model.generate(
        **model_inputs,
        num_beams=num_beams,
        num_return_sequences=num_beams,
        max_new_tokens=max_new_tokens,
        output_scores=True,
        return_dict_in_generate=True,
        early_stopping=True,
        streamer=streamer
    )


def main():
    parser = argparse.ArgumentParser(description='Language Model Text Generation with Beam Search')
    parser.add_argument('--model', type=str, default='Qwen/Qwen2.5-0.5B-Instruct',
                        help='Name of the model to use')
    parser.add_argument('--num_beams', type=int, default=5,
                        help='Number of beams for beam search')
    parser.add_argument('--max_tokens', type=int, default=512,
                        help='Maximum number of new tokens to generate')
    parser.add_argument('--output_dir', type=str, default='beam_outputs',
                        help='Directory to save beam outputs')

    args = parser.parse_args()

    model, tokenizer = setup_model_and_tokenizer(args.model)
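
    # Interactive loop: one beam-search run per prompt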
    while True:
        prompt = input("\nEnter your prompt (or 'quit' to exit): ")
        if prompt.lower() == 'quit':
            break

        generate_with_beam_search(
            model,
            tokenizer,
            prompt,
            args.output_dir,
            num_beams=args.num_beams,
            max_new_tokens=args.max_tokens
        )
        print(f"\nOutputs written to: {args.output_dir}/beam_*.txt")


if __name__ == "__main__":
    main()