File size: 1,985 Bytes
c5e73ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import json
import matplotlib.pyplot as plt
import numpy as np
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)

# Load the JSON datasets whose prompt lengths will be compared
file_paths = ["ShareGPT_V3_filtered.json", "ShareGPT_V3_filtered_500.json"]

# Each dataset is labelled by its path with the last 5 characters
# (the ".json" extension of the paths above) cut off.
names = [path[:-5] for path in file_paths]


def _load_json(path):
    """Return the decoded contents of the UTF-8 encoded JSON file at *path*."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


# data_lists[i] holds the parsed content of file_paths[i]; every file is read
# before anything is printed.
data_lists = [_load_json(path) for path in file_paths]

# Report the number of top-level entries found in each dataset.
for name, entries in zip(names, data_lists):
    print(f"{name}: {len(entries)}")

# Get prompt lengths using tokenizer
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

# The prompt of a record is the "value" of its first conversation turn.
# Records whose "conversations" list is empty have no prompt, so they are
# skipped instead of aborting the whole run with an IndexError.
# The loop variables deliberately avoid reusing the name `data_lists`, which
# the original comprehension shadowed.
all_prompts = [
    [
        record["conversations"][0]["value"]
        for record in data_list
        if record["conversations"]
    ]
    for data_list in data_lists
]

# Tokenize every dataset's prompts with a single call per dataset;
# `input_ids` is then read as one token-id sequence per prompt.
all_token_ids_per_prompts = [tokenizer(prompts).input_ids for prompts in all_prompts]

# Prompt length = number of token ids produced for that prompt.
all_prompt_lens = [
    [len(token_ids) for token_ids in token_ids_per_prompt]
    for token_ids_per_prompt in all_token_ids_per_prompts
]

# Plotting the histograms
# One figure per dataset, written to "<name>_distribution.png".
for name, prompt_lens in zip(names, all_prompt_lens):
    if not prompt_lens:
        # min()/max() raise ValueError on an empty sequence, and there is
        # nothing to draw anyway.
        print(f"{name}: no prompts, skipping histogram")
        continue
    # Bin edges are every integer from min to max + 1 inclusive, so each
    # length k gets its own bin [k, k + 1). Stopping the edges at max would
    # make [max - 1, max] the last bin; because the last histogram bin is
    # closed on both sides, the two largest lengths would be merged (and a
    # dataset with a single distinct length would get no bin at all).
    plt.hist(
        prompt_lens,
        bins=range(min(prompt_lens), max(prompt_lens) + 2),
        edgecolor="black",
    )
    plt.xlabel("Prompt Length (number of tokens)")
    plt.ylabel("Frequency")
    plt.title(f"Histogram of {name}")
    plt.savefig(f"{name}_distribution.png")
    # Close the figure so the next dataset starts from a clean canvas.
    plt.close()

# Plotting the CDF
# All datasets are drawn as step curves on one shared figure.
for name, prompt_lens in zip(names, all_prompt_lens):
    # np.unique returns the distinct lengths already sorted in ascending
    # order, together with how often each one occurs.
    values, counts = np.unique(prompt_lens, return_counts=True)
    # Accumulate the integer counts first and normalise afterwards, so the
    # final CDF value is exactly 1.0 instead of a sum of rounded fractions.
    cumulative_frequencies = np.cumsum(counts) / len(prompt_lens)
    plt.step(values, cumulative_frequencies, where="post", label=name)

plt.title("Cumulative Distribution Function (CDF) Overlayed")
plt.xlabel("Prompt Length (number of tokens)")
plt.ylabel("Cumulative Probability")
# The curves carry a label per dataset; without a legend they cannot be told
# apart in the overlaid figure.
plt.legend()
# The output file has always been named after the last dataset (the original
# code relied on the loop variable leaking out of the loop above). Keep that
# path for compatibility, but make the choice explicit.
plt.savefig(f"{names[-1]}_cdf.png")
plt.close()