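"""Stream beam-search generation to per-beam text files.

Uses MultiBeamTextStreamer to mirror each active beam into
<output_dir>/beam_<idx>.txt as tokens arrive, and copies each finished beam to
<output_dir>/closed_beams/completed_beam_<n>.txt.
"""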
import argparse
import os
from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM


class BeamOutputManager:
    """Manages file handlers for beam outputs."""

    def __init__(self, output_dir: str, num_beams: int):
        self.output_dir = output_dir
        self.num_beams = num_beams
        self.counter = 0

        # Create the main output directory and the closed-beams directory
        os.makedirs(output_dir, exist_ok=True)
        self.closed_beams_dir = os.path.join(output_dir, "closed_beams")
        os.makedirs(self.closed_beams_dir, exist_ok=True)

        # Store the complete text for each beam
        self.beam_texts = {i: "" for i in range(num_beams)}
        self.active_beams = set(range(num_beams))

        # Initialize empty files
        for beam_idx in range(num_beams):
            filename = os.path.join(output_dir, f'beam_{beam_idx}.txt')
            with open(filename, 'w', encoding='utf-8') as f:
                f.write('')
    def write_to_beam(self, beam_idx: int, text: str):
        """Write text to the specified beam's file."""
        if 0 <= beam_idx < self.num_beams and beam_idx in self.active_beams:
            # Update stored text
            self.beam_texts[beam_idx] = text
            # Write the complete text to the beam's file
            filename = os.path.join(self.output_dir, f'beam_{beam_idx}.txt')
            with open(filename, 'w', encoding='utf-8') as f:
                f.write(self.beam_texts[beam_idx])
    def finalize_beam(self, final_text: str):
        """
        Handle a completed beam by creating a new file in the closed_beams directory.

        Args:
            final_text (str): The complete text generated by the finished beam

        Returns:
            str: Path of the file the completed beam was written to
        """
        # Use an incrementing counter to give each completed beam a unique filename
        self.counter += 1
        filename = os.path.join(self.closed_beams_dir, f'completed_beam_{self.counter}.txt')
        # Write the final text to the completed-beam file
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(final_text)
        return filename
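
# A minimal, illustrative standalone use of BeamOutputManager (not executed by
# this script; the directory name and texts are made up). Each write_to_beam
# call rewrites the beam's file with the full text so far:
#
#   manager = BeamOutputManager("demo_outputs", num_beams=2)
#   manager.write_to_beam(0, "The quick")
#   manager.write_to_beam(0, "The quick brown fox")  # overwrites beam_0.txt
#   path = manager.finalize_beam("The quick brown fox jumps.")
#   # -> demo_outputs/closed_beams/completed_beam_1.txt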


def setup_model_and_tokenizer(model_name):
    """
    Initialize the model and tokenizer.

    Args:
        model_name (str): Name of the model to use

    Returns:
        tuple: (model, tokenizer)
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto"
    )
    return model, tokenizer
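
# Note: device_map="auto" relies on the `accelerate` package being installed,
# and torch_dtype="auto" loads weights in the dtype recorded in the checkpoint.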


def generate_with_beam_search(model, tokenizer, user_prompt, output_dir, num_beams=5, max_new_tokens=512):
    """
    Generate responses using beam search and write results to files.

    Args:
        model: The language model
        tokenizer: The tokenizer
        user_prompt (str): Input prompt
        output_dir (str): Directory to save beam outputs
        num_beams (int): Number of beams to use
        max_new_tokens (int): Maximum number of new tokens to generate
    """
    # Initialize the output manager
    output_manager = BeamOutputManager(output_dir, num_beams)

    def on_beam_update(beam_idx: int, new_text: str):
        """Handler for beam updates - write the beam's current text to its file"""
        output_manager.write_to_beam(beam_idx, new_text)

    def on_beam_finished(final_text: str):
        """Handler for completed beams - create the final output file"""
        final_path = output_manager.finalize_beam(final_text)
        print(f"\nCompleted beam saved to: {final_path}")

    # Build the chat messages
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt}
    ]

    # Apply the chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Prepare inputs
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Initialize the streamer with the beam handlers
    streamer = MultiBeamTextStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        on_beam_update=on_beam_update,
        on_beam_finished=on_beam_finished,
        skip_prompt=True
    )

    # Generate with beam search
    model.generate(
        **model_inputs,
        num_beams=num_beams,
        num_return_sequences=num_beams,
        max_new_tokens=max_new_tokens,
        output_scores=True,
        return_dict_in_generate=True,
        early_stopping=True,
        streamer=streamer
    )
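
# Illustrative direct call (model/tokenizer as returned by
# setup_model_and_tokenizer; the prompt and directory are placeholders):
#
#   generate_with_beam_search(model, tokenizer, "Explain beam search.",
#                             "beam_outputs", num_beams=3, max_new_tokens=128)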


def main():
    # Set up command-line arguments
    parser = argparse.ArgumentParser(description='Language Model Text Generation with Beam Search')
    parser.add_argument('--model', type=str, default='Qwen/Qwen2.5-0.5B-Instruct',
                        help='Name of the model to use')
    parser.add_argument('--num_beams', type=int, default=5,
                        help='Number of beams for beam search')
    parser.add_argument('--max_tokens', type=int, default=512,
                        help='Maximum number of new tokens to generate')
    parser.add_argument('--output_dir', type=str, default='beam_outputs',
                        help='Directory to save beam outputs')
    args = parser.parse_args()

    # Initialize model and tokenizer
    model, tokenizer = setup_model_and_tokenizer(args.model)

    # Interactive loop
    while True:
        prompt = input("\nEnter your prompt (or 'quit' to exit): ")
        if prompt.lower() == 'quit':
            break
        generate_with_beam_search(
            model,
            tokenizer,
            prompt,
            args.output_dir,
            num_beams=args.num_beams,
            max_new_tokens=args.max_tokens
        )
        print(f"\nOutputs written to: {args.output_dir}/beam_*.txt")


if __name__ == "__main__":
    main()
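
# Example invocation (the script filename is whatever this file is saved as):
#
#   python beam_search_streaming.py --model Qwen/Qwen2.5-0.5B-Instruct \
#       --num_beams 3 --max_tokens 256 --output_dir beam_outputs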