hackergeek98 commited on
Commit
6c8c083
Β·
verified Β·
1 Parent(s): 20915bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -15
app.py CHANGED
@@ -14,7 +14,7 @@ import sys
14
  # Configure logging
15
  logging.basicConfig(stream=sys.stdout, level=logging.INFO)
16
 
17
- def train():
18
  try:
19
  # Load model and tokenizer
20
  model_name = "microsoft/phi-2"
@@ -25,19 +25,22 @@ def train():
25
  if tokenizer.pad_token is None:
26
  tokenizer.pad_token = tokenizer.eos_token
27
 
28
- # Load dataset
 
29
  dataset = load_dataset(
30
- "csv",
31
- data_files={
32
- "train": "data/train/data.csv",
33
- "validation": "data/validation/data.csv"
34
- }
35
  )
36
 
37
- # Tokenization function
 
 
 
38
  def tokenize_function(examples):
39
  return tokenizer(
40
- examples["text"],
41
  padding="max_length",
42
  truncation=True,
43
  max_length=256,
@@ -47,7 +50,7 @@ def train():
47
  tokenized_dataset = dataset.map(
48
  tokenize_function,
49
  batched=True,
50
- remove_columns=["text"]
51
  )
52
 
53
  # Data collator
@@ -72,7 +75,7 @@ def train():
72
  model=model,
73
  args=training_args,
74
  train_dataset=tokenized_dataset["train"],
75
- eval_dataset=tokenized_dataset["validation"],
76
  data_collator=data_collator,
77
  )
78
 
@@ -88,16 +91,20 @@ def train():
88
  logging.error(f"Training failed: {str(e)}")
89
  return f"❌ Training failed: {str(e)}"
90
 
91
- # Gradio UI
92
  with gr.Blocks(title="Phi-2 Training") as demo:
93
- gr.Markdown("# πŸš€ Train Phi-2 on CPU")
94
 
95
  with gr.Row():
96
- start_btn = gr.Button("Start Training", variant="primary")
97
- status_output = gr.Textbox(label="Status", interactive=False)
 
 
 
98
 
99
  start_btn.click(
100
  fn=train,
 
101
  outputs=status_output
102
  )
103
 
 
14
  # Configure logging
15
  logging.basicConfig(stream=sys.stdout, level=logging.INFO)
16
 
17
+ def train(dataset_name: str, dataset_config: str = None):
18
  try:
19
  # Load model and tokenizer
20
  model_name = "microsoft/phi-2"
 
25
  if tokenizer.pad_token is None:
26
  tokenizer.pad_token = tokenizer.eos_token
27
 
28
+ # Load dataset from Hugging Face Hub
29
+ logging.info(f"Loading dataset: {eswardivi/medical_qa} (config: {dataset_config})")
30
  dataset = load_dataset(
31
+ dataset_name,
32
+ dataset_config, # Optional config (e.g., language for Common Voice)
33
+ split="train+validation", # Combine splits
34
+ trust_remote_code=True # Required for some datasets
 
35
  )
36
 
37
+ # Split into train/validation
38
+ dataset = dataset.train_test_split(test_size=0.1, seed=42)
39
+
40
+ # Tokenization function (adjust based on dataset columns)
41
  def tokenize_function(examples):
42
  return tokenizer(
43
+ examples["text"], # Replace "text" with your dataset's text column
44
  padding="max_length",
45
  truncation=True,
46
  max_length=256,
 
50
  tokenized_dataset = dataset.map(
51
  tokenize_function,
52
  batched=True,
53
+ remove_columns=dataset["train"].column_names
54
  )
55
 
56
  # Data collator
 
75
  model=model,
76
  args=training_args,
77
  train_dataset=tokenized_dataset["train"],
78
+ eval_dataset=tokenized_dataset["test"],
79
  data_collator=data_collator,
80
  )
81
 
 
91
  logging.error(f"Training failed: {str(e)}")
92
  return f"❌ Training failed: {str(e)}"
93
 
94
+ # Gradio UI with dataset input
95
  with gr.Blocks(title="Phi-2 Training") as demo:
96
+ gr.Markdown("# πŸš€ Train Phi-2 with HF Hub Data")
97
 
98
  with gr.Row():
99
+ dataset_name = gr.Textbox(label="Dataset Name", value="mozilla-foundation/common_voice_11_0")
100
+ dataset_config = gr.Textbox(label="Dataset Config (optional)", value="en")
101
+
102
+ start_btn = gr.Button("Start Training", variant="primary")
103
+ status_output = gr.Textbox(label="Status", interactive=False)
104
 
105
  start_btn.click(
106
  fn=train,
107
+ inputs=[dataset_name, dataset_config],
108
  outputs=status_output
109
  )
110