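"""
Generate text with a trained GPTWithMoE model using MCTS-based decoding.

Loads a checkpoint, encodes a prompt with the GPT-2 tokenizer, and decodes
a continuation via mcts_decode_single from q_star.
"""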
import torch
from q_star import GPTWithMoE, GPTConfig, mcts_decode_single


def generate_text_with_mcts(
    model: GPTWithMoE,
    tokenizer,  # Tokenizer to encode and decode text
    prompt: str,
    max_length: int = 50,
    num_simulations: int = 100,
    c_puct: float = 1.0,
    top_k: int = 10,
    device: str = "cuda",
):
    """
    Generate text using the GPTWithMoE model and MCTS-based decoding.

    Args:
        model (GPTWithMoE): The trained model.
        tokenizer: The tokenizer for text encoding and decoding.
        prompt (str): The initial text prompt.
        max_length (int): Maximum length of the generated text.
        num_simulations (int): Number of MCTS simulations for each decoding step.
        c_puct (float): Exploration parameter for MCTS; higher values favor exploration.
        top_k (int): Top-k tokens to consider during MCTS expansion.
        device (str): Device to use for computation.

    Returns:
        str: The generated text.
    """
    model.eval()
    model.to(device)

    # Encode the prompt into input_ids
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Use MCTS to decode the sequence; no gradients are needed at inference time
    with torch.no_grad():
        generated_ids = mcts_decode_single(
            model=model,
            input_ids=input_ids,
            max_length=max_length,
            num_simulations=num_simulations,
            c_puct=c_puct,
            top_k=top_k,
        )

    # Decode the generated IDs back to text. mcts_decode_single is assumed to
    # return token ids with a possible leading batch dimension, so squeeze it
    # away before decoding (a no-op if the tensor is already 1-D).
    generated_text = tokenizer.decode(generated_ids.squeeze(0).tolist(), skip_special_tokens=True)
    return generated_text


if __name__ == "__main__":
    from transformers import GPT2Tokenizer

    # Define the device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize the tokenizer (adapt as per your model's tokenizer)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
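    # GPT-2 defines no pad token by default, so reuse the EOS token for padding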
    tokenizer.pad_token = tokenizer.eos_token

    # Load the trained model
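    # Note: vocab_size=50304 pads GPT-2's 50257-token vocabulary up to a
    # multiple of 64, a common throughput optimization; it must match the
    # vocabulary size the checkpoint was trained with.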
    config = GPTConfig(vocab_size=50304, block_size=512, n_layer=6, n_head=4, n_embd=256)
    model = GPTWithMoE(config, num_experts=3, expert_layers=3, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device=device)
    model.load_state_dict(torch.load("C:\\Users\\Admin\\MODELS\\moe_mcts_new.pt", map_location=device))

    # Generate text using a prompt
    prompt = "Once upon a time in a distant galaxy,"
    generated_text = generate_text_with_mcts(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_length=100,
        num_simulations=50,
        c_puct=1.5,
        top_k=5,
        device=device,
    )

    print("Generated Text:")
    print(generated_text)