import torch
import gradio as gr
import threading
import logging
import sys
from urllib.parse import urlparse
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from datasets import load_dataset
# Configure logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
def parse_hf_dataset_url(url: str) -> tuple[str, str | None]:
"""Parse Hugging Face dataset URL into (dataset_name, config)"""
parsed = urlparse(url)
path_parts = parsed.path.split('/')
try:
# Find 'datasets' in path
datasets_idx = path_parts.index('datasets')
except ValueError:
raise ValueError("Invalid Hugging Face dataset URL")
dataset_parts = path_parts[datasets_idx+1:]
dataset_name = "/".join(dataset_parts[0:2])
# Try to find config (common pattern for datasets with viewer)
try:
viewer_idx = dataset_parts.index('viewer')
config = dataset_parts[viewer_idx+1] if viewer_idx+1 < len(dataset_parts) else None
except ValueError:
config = None
return dataset_name, config
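# A quick illustration of the expected parsing behavior (hypothetical URLs):
#   parse_hf_dataset_url("https://huggingface.co/datasets/user/my-data")
#   -> ("user/my-data", None)
#   parse_hf_dataset_url("https://huggingface.co/datasets/user/my-data/viewer/en")
#   -> ("user/my-data", "en")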
def train(dataset_url: str):
try:
# Parse dataset URL
dataset_name, dataset_config = parse_hf_dataset_url(dataset_url)
logging.info(f"Loading dataset: {dataset_name} (config: {dataset_config})")
# Load model and tokenizer
model_name = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", trust_remote_code=True)
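        # Note: phi-2 is a ~2.7B-parameter model; full fine-tuning on CPU is
        # very slow and memory-hungry, so treat this path as a demo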
        # phi-2's tokenizer ships without a pad token, so reuse the EOS token for padding
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Load dataset from Hugging Face Hub
dataset = load_dataset(
dataset_name,
dataset_config,
trust_remote_code=True
)
# Handle dataset splits
if "train" not in dataset:
raise ValueError("Dataset must have a 'train' split")
train_dataset = dataset["train"]
eval_dataset = dataset.get("validation", dataset.get("test", None))
# Split if no validation set
if eval_dataset is None:
split = train_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = split["train"]
eval_dataset = split["test"]
        # Tokenization function (datasets.map stores plain lists, so no
        # return_tensors here; the data collator tensorizes batches later)
        def tokenize_function(examples):
            return tokenizer(
                examples["text"],  # Adjust column name to match your dataset
                padding="max_length",
                truncation=True,
                max_length=256,
            )
# Tokenize datasets
tokenized_train = train_dataset.map(
tokenize_function,
batched=True,
remove_columns=train_dataset.column_names
)
tokenized_eval = eval_dataset.map(
tokenize_function,
batched=True,
remove_columns=eval_dataset.column_names
)
        # Data collator for causal LM (mlm=False): labels are a copy of input_ids
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False
)
        # Training arguments (evaluate on the held-out split once per epoch)
        training_args = TrainingArguments(
            output_dir="./phi2-results",
            evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers
            per_device_train_batch_size=2,
            per_device_eval_batch_size=2,
            num_train_epochs=3,
            logging_dir="./logs",
            logging_steps=10,
            fp16=False,
        )
# Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_train,
eval_dataset=tokenized_eval,
data_collator=data_collator,
)
# Start training
logging.info("Training started...")
trainer.train()
trainer.save_model("./phi2-trained-model")
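        # Note: on a Space this writes to ephemeral container storage;
        # push to the Hub (e.g. trainer.push_to_hub()) to persist the weights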
logging.info("Training completed!")
return "β
Training succeeded! Model saved."
except Exception as e:
logging.error(f"Training failed: {str(e)}")
return f"β Training failed: {str(e)}"
# Gradio interface
with gr.Blocks(title="Phi-2 Training") as demo:
gr.Markdown("# π Train Phi-2 with HF Hub Data")
with gr.Row():
dataset_url = gr.Textbox(
label="Dataset URL",
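            # Note: tokenize_function expects a "text" column; point this at a plain-text dataset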
value="https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0"
)
start_btn = gr.Button("Start Training", variant="primary")
status_output = gr.Textbox(label="Status", interactive=False)
    def start_training(url):
        # Run training in a background thread so the UI stays responsive;
        # Thread.start() returns None, so report an explicit status message here
        threading.Thread(target=train, args=(url,), daemon=True).start()
        return "Training started... check the logs for progress."
    start_btn.click(
        fn=start_training,
        inputs=[dataset_url],
        outputs=status_output
    )
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860
    )
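
# A minimal smoke test without the UI, assuming a small dataset that has a
# "text" column (the URL below is an example, not part of the app):
#   train("https://huggingface.co/datasets/roneneldan/TinyStories")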