import transformers
import torch

# Model and tokenizer initialization
model_path_name = "SicariusSicariiStuff/LLAMA-3_8B_Unaligned_BETA"  # Replace with your model path

# Initialize the pipeline
pipeline = transformers.pipeline(
    "text-generation",
    model=model_path_name,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",  # Adjust to 'cuda' if needed
)

# Prepare the message list
message_list = [
    [
        {'role': 'system', 'content': "You are an AI assistant."},
        {'role': 'user', 'content': "Who are you?"}
    ]
]

# Apply the chat template or manually format the prompts
try:
    prompts = [
        pipeline.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        for messages in message_list
    ]
except AttributeError:
    # Fallback: manually format the prompts if `apply_chat_template` is unsupported.
    # Note: these are ChatML-style markers; Llama-3 models normally expect
    # <|start_header_id|>...<|eot_id|> formatting, so adjust to your model's actual template.
    prompts = [
        f"<|im_start|>system\n{msg[0]['content']}<|im_end|>\n"
        f"<|im_start|>user\n{msg[1]['content']}<|im_end|>\n<|im_start|>assistant\n"
        for msg in message_list
    ]

# Debugging: Print prompts
print("Formatted Prompts:", prompts)

# Validate the tokenizer's EOS and PAD token IDs
eos_token_id = pipeline.tokenizer.eos_token_id or 50256  # 50256 is GPT-2's EOS; last-resort fallback only
pad_token_id = eos_token_id  # Reuse EOS as PAD for consistency
print("EOS Token ID:", eos_token_id)

# Ensure a pad token exists before padding (Llama tokenizers often lack one)
if pipeline.tokenizer.pad_token_id is None:
    pipeline.tokenizer.pad_token_id = pad_token_id

# Tokenize the prompts (optional debugging step)
tokens = pipeline.tokenizer(prompts, padding=True, return_tensors="pt")
print("Tokenized Input:", tokens)

# Generate the output
try:
    outputs = pipeline(
        prompts,
        max_new_tokens=100,  # Kept small for debugging; increase for full-length responses
        do_sample=True,
        temperature=0.5,
        top_p=0.5,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
    )
    print("Outputs:", outputs)
except Exception as e:
    print("Error during generation:", str(e))