import torch
import gradio as gr
import threading
import logging
import sys
from urllib.parse import urlparse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import load_dataset

# Configure logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

def parse_hf_dataset_url(url: str) -> tuple[str, str | None]:
    """Parse Hugging Face dataset URL into (dataset_name, config)"""
    parsed = urlparse(url)
    path_parts = parsed.path.split('/')
    
    try:
        # Find 'datasets' in path
        datasets_idx = path_parts.index('datasets')
    except ValueError:
        raise ValueError("Invalid Hugging Face dataset URL")
    
    dataset_parts = path_parts[datasets_idx+1:]
    dataset_name = "/".join(dataset_parts[0:2])
    
    # Try to find config (common pattern for datasets with viewer)
    try:
        viewer_idx = dataset_parts.index('viewer')
        config = dataset_parts[viewer_idx+1] if viewer_idx+1 < len(dataset_parts) else None
    except ValueError:
        config = None
    
    return dataset_name, config

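# Quick illustration of the parser above (the /viewer/<config> form is assumed
# from the parsing logic; the first URL is the default used in the UI below):
#   parse_hf_dataset_url("https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0")
#       -> ("mozilla-foundation/common_voice_11_0", None)
#   parse_hf_dataset_url("https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0/viewer/en")
#       -> ("mozilla-foundation/common_voice_11_0", "en")
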
def train(dataset_url: str):
    try:
        # Parse dataset URL
        dataset_name, dataset_config = parse_hf_dataset_url(dataset_url)
        logging.info(f"Loading dataset: {dataset_name} (config: {dataset_config})")

        # Load model and tokenizer
        model_name = "microsoft/phi-2"
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", trust_remote_code=True)

        # Add padding token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load dataset from Hugging Face Hub
        dataset = load_dataset(
            dataset_name,
            dataset_config,
            trust_remote_code=True
        )

        # Handle dataset splits
        if "train" not in dataset:
            raise ValueError("Dataset must have a 'train' split")
        
        train_dataset = dataset["train"]
        eval_dataset = dataset.get("validation", dataset.get("test", None))

        # Split if no validation set
        if eval_dataset is None:
            split = train_dataset.train_test_split(test_size=0.1, seed=42)
            train_dataset = split["train"]
            eval_dataset = split["test"]

        # Tokenization function
        def tokenize_function(examples):
            return tokenizer(
                examples["text"],  # Adjust column name as needed
                padding="max_length",
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )

        # Tokenize datasets
        tokenized_train = train_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=train_dataset.column_names
        )
        tokenized_eval = eval_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=eval_dataset.column_names
        )

        # Data collator
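        # (mlm=False makes this a causal-LM collator: labels are a copy of the
        # input_ids, with padding positions ignored in the loss.)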
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False
        )

        # Training arguments
        training_args = TrainingArguments(
            output_dir="./phi2-results",
            per_device_train_batch_size=2,
            per_device_eval_batch_size=2,
            num_train_epochs=3,
            logging_dir="./logs",
            logging_steps=10,
            fp16=False,
        )
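        # Note: no evaluation schedule is configured above, so the eval split is
        # only used if trainer.evaluate() is called explicitly; add an eval
        # strategy to TrainingArguments if periodic evaluation is wanted.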

        # Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_train,
            eval_dataset=tokenized_eval,
            data_collator=data_collator,
        )

        # Start training
        logging.info("Training started...")
        trainer.train()
        trainer.save_model("./phi2-trained-model")
        logging.info("Training completed!")

        return "βœ… Training succeeded! Model saved."

    except Exception as e:
        logging.error(f"Training failed: {str(e)}")
        return f"❌ Training failed: {str(e)}"

# Gradio interface
with gr.Blocks(title="Phi-2 Training") as demo:
    gr.Markdown("# πŸš€ Train Phi-2 with HF Hub Data")
    
    with gr.Row():
        dataset_url = gr.Textbox(
            label="Dataset URL",
            value="https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0"
        )
    
    start_btn = gr.Button("Start Training", variant="primary")
    status_output = gr.Textbox(label="Status", interactive=False)
    
    def start_training(url: str) -> str:
        # Run training in a background thread so the Gradio UI stays responsive;
        # progress and errors go to the logs configured above. (The previous
        # lambda returned Thread.start()'s None, so the status box stayed empty.)
        threading.Thread(target=train, args=(url,)).start()
        return "πŸ”„ Training started in the background. Check the logs for progress."

    start_btn.click(
        fn=start_training,
        inputs=[dataset_url],
        outputs=status_output
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860
    )
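
# To try this locally (assuming the standard dependencies; the file name is
# illustrative):
#   pip install torch transformers datasets gradio
#   python app.py
# Then open http://localhost:7860 (or the Space URL) and click "Start Training".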