import argparse
import os

from transformers import AutoModelForCausalLM, AutoTokenizer, MultiBeamTextStreamer


class BeamOutputManager:
    """Manages file handlers for beam outputs"""

    def __init__(self, output_dir: str, num_beams: int):
        self.output_dir = output_dir
        self.num_beams = num_beams
        self.counter = 0

        # Create main output directory and closed beams directory
        os.makedirs(output_dir, exist_ok=True)
        self.closed_beams_dir = os.path.join(output_dir, "closed_beams")
        os.makedirs(self.closed_beams_dir, exist_ok=True)

        # Store complete text for each beam
        self.beam_texts = {i: "" for i in range(num_beams)}
        self.active_beams = set(range(num_beams))

        # Initialize empty files
        for beam_idx in range(num_beams):
            filename = os.path.join(output_dir, f'beam_{beam_idx}.txt')
            with open(filename, 'w', encoding='utf-8') as f:
                f.write('')

    def write_to_beam(self, beam_idx: int, text: str):
        """Write text to the specified beam's file"""
        if 0 <= beam_idx < self.num_beams and beam_idx in self.active_beams:
            # Update stored text
            self.beam_texts[beam_idx] = text

            # Write complete text to file
            filename = os.path.join(self.output_dir, f'beam_{beam_idx}.txt')
            with open(filename, 'w', encoding='utf-8') as f:
                f.write(self.beam_texts[beam_idx])

    def finalize_beam(self, final_text: str):
        """
        Handle a completed beam by creating a new file in the closed_beams directory.

        Args:
            final_text (str): The complete text generated by the finished beam
        """
        # Use an incrementing counter to keep completed-beam filenames unique
        self.counter += 1
        filename = os.path.join(self.closed_beams_dir, f'completed_beam_{self.counter}.txt')

        # Write the final text to the completed beam file
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(final_text)

        return filename


def setup_model_and_tokenizer(model_name):
    """
    Initialize the model and tokenizer.

    Args:
        model_name (str): Name of the model to use

    Returns:
        tuple: (model, tokenizer)
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto"
    )
    return model, tokenizer


def generate_with_beam_search(model, tokenizer, user_prompt, output_dir, num_beams=5, max_new_tokens=512):
    """
    Generate responses using beam search and write results to files.

    Args:
        model: The language model
        tokenizer: The tokenizer
        user_prompt (str): Input prompt
        output_dir (str): Directory to save beam outputs
        num_beams (int): Number of beams to use
        max_new_tokens (int): Maximum number of new tokens to generate
    """
    # Initialize the output manager
    output_manager = BeamOutputManager(output_dir, num_beams)

    def on_beam_update(beam_idx: int, new_text: str):
        """Handler for beam updates - write new text to file"""
        output_manager.write_to_beam(beam_idx, new_text)

    def on_beam_finished(final_text: str):
        """Handler for completed beams - create final output file"""
        final_path = output_manager.finalize_beam(final_text)
        print(f"\nCompleted beam saved to: {final_path}")

    # Build the chat messages
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt}
    ]

    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Prepare inputs
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Initialize the streamer with per-beam callbacks: on_beam_update fires as a
    # beam's text changes, on_beam_finished fires when a hypothesis completes
    streamer = MultiBeamTextStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        on_beam_update=on_beam_update,
        on_beam_finished=on_beam_finished,
        skip_prompt=True
    )

    # Generate with beam search; the return value is discarded because all
    # output is delivered incrementally through the streamer callbacks
    model.generate(
        **model_inputs,
        num_beams=num_beams,
        num_return_sequences=num_beams,
        max_new_tokens=max_new_tokens,
        output_scores=True,
        return_dict_in_generate=True,
        early_stopping=True,
        streamer=streamer
    )


def main():
    # Setup command line arguments
    parser = argparse.ArgumentParser(description='Language Model Text Generation with Beam Search')
    parser.add_argument('--model', type=str, default='Qwen/Qwen2.5-0.5B-Instruct',
                        help='Name of the model to use')
    parser.add_argument('--num_beams', type=int, default=5,
                        help='Number of beams for beam search')
    parser.add_argument('--max_tokens', type=int, default=512,
                        help='Maximum number of new tokens to generate')
    parser.add_argument('--output_dir', type=str, default='beam_outputs',
                        help='Directory to save beam outputs')

    args = parser.parse_args()

    # Initialize model and tokenizer
    model, tokenizer = setup_model_and_tokenizer(args.model)

    # Interactive loop
    while True:
        prompt = input("\nEnter your prompt (or 'quit' to exit): ")
        if prompt.strip().lower() == 'quit':
            break

        generate_with_beam_search(
            model,
            tokenizer,
            prompt,
            args.output_dir,
            num_beams=args.num_beams,
            max_new_tokens=args.max_tokens
        )
        print(f"\nOutputs written to: {args.output_dir}/beam_*.txt")


if __name__ == "__main__":
    main()
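
# Example invocation (the script filename and flag values are illustrative):
#   python beam_stream_demo.py --model Qwen/Qwen2.5-0.5B-Instruct \
#       --num_beams 3 --max_tokens 256 --output_dir beam_outputs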