File size: 3,998 Bytes
20076b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import json
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Literal, Optional

from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope


if TYPE_CHECKING:
    from ..hparams import DataArguments


@dataclass
class DatasetAttr:
    """Description of a single dataset: where to load it from and how to read it.

    Only ``load_from`` is required; every other field has a default.
    """

    # Source kind of the dataset.
    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
    # Identifier of the dataset for the chosen source (also used as the repr).
    dataset_name: Optional[str] = None
    dataset_sha1: Optional[str] = None
    subset: Optional[str] = None
    folder: Optional[str] = None
    ranking: Optional[bool] = False
    formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"

    # Name of the column holding the system prompt.
    system: Optional[str] = None

    # Column names for the "alpaca" formatting.
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None

    # Column names for the "sharegpt" formatting.
    messages: Optional[str] = "conversations"
    tools: Optional[str] = None

    # Tag names for the "sharegpt" formatting.
    role_tag: Optional[str] = "from"
    content_tag: Optional[str] = "value"
    user_tag: Optional[str] = "human"
    assistant_tag: Optional[str] = "gpt"
    observation_tag: Optional[str] = "observation"
    function_tag: Optional[str] = "function_call"

    def __repr__(self) -> str:
        """Return the dataset name as the textual representation.

        ``__repr__`` must return a ``str``; ``dataset_name`` defaults to ``None``,
        so it is converted explicitly to avoid a ``TypeError`` in ``repr()``.
        """
        return str(self.dataset_name)


def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
    """Build the list of dataset attributes requested by ``data_args``.

    The comma-separated ``data_args.dataset`` string is split into names, and each
    name is looked up in the JSON config file ``DATA_CONFIG`` located inside
    ``data_args.dataset_dir``.

    Side effect: when ``data_args.interleave_probs`` is a comma-separated string,
    it is replaced in place by the corresponding list of floats.

    Args:
        data_args: Arguments providing ``dataset``, ``dataset_dir`` and
            ``interleave_probs``.

    Returns:
        One ``DatasetAttr`` per requested dataset, in the requested order. The
        list is empty when ``data_args.dataset`` is ``None``.

    Raises:
        ValueError: If datasets are requested but the config cannot be read, or
            a requested dataset is not defined in the config.
    """
    dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
    try:
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
            dataset_info = json.load(f)
    except Exception as err:
        # An unreadable config is only fatal when some dataset was actually requested.
        if data_args.dataset is not None:
            raise ValueError(
                "Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
            ) from err
        dataset_info = None

    # Parse only while the value is still a string, so that calling this function
    # again with the same (already converted) arguments does not fail.
    if isinstance(data_args.interleave_probs, str):
        data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]

    dataset_list: List[DatasetAttr] = []
    for name in dataset_names:
        if name not in dataset_info:
            raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))

        info = dataset_info[name]
        has_hf_url = "hf_hub_url" in info
        has_ms_url = "ms_hub_url" in info

        if has_hf_url or has_ms_url:
            # Use the ModelScope hub when it is enabled and available, or when it
            # is the only hub URL given; otherwise use the Hugging Face hub.
            if (use_modelscope() and has_ms_url) or (not has_hf_url):
                dataset_attr = DatasetAttr("ms_hub", dataset_name=info["ms_hub_url"])
            else:
                dataset_attr = DatasetAttr("hf_hub", dataset_name=info["hf_hub_url"])
        elif "script_url" in info:
            dataset_attr = DatasetAttr("script", dataset_name=info["script_url"])
        else:
            dataset_attr = DatasetAttr(
                "file",
                dataset_name=info["file_name"],
                dataset_sha1=info.get("file_sha1", None),
            )

        dataset_attr.subset = info.get("subset", None)
        dataset_attr.folder = info.get("folder", None)
        dataset_attr.ranking = info.get("ranking", False)
        dataset_attr.formatting = info.get("formatting", "alpaca")

        if "columns" in info:
            if dataset_attr.formatting == "alpaca":
                column_names = ["prompt", "query", "response", "history"]
            else:
                column_names = ["messages", "tools"]

            column_names += ["system"]
            # Once "columns" is given, any column it omits is set to None.
            for column_name in column_names:
                setattr(dataset_attr, column_name, info["columns"].get(column_name, None))

        if dataset_attr.formatting == "sharegpt" and "tags" in info:
            # NOTE(review): tags omitted from "tags" are reset to None rather than
            # keeping the DatasetAttr defaults - confirm this is intended.
            for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag"]:
                setattr(dataset_attr, tag, info["tags"].get(tag, None))

        dataset_list.append(dataset_attr)

    return dataset_list