Kevin Fink committed on
Commit
108b583
·
1 Parent(s): d8c9b4e
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -140,6 +140,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
140
 
141
  elif os.access(f'/data/{hub_id.strip()}_train_dataset3', os.R_OK):
142
  dataset = load_dataset(dataset_name.strip())
 
143
  del dataset['train']
144
  del dataset['validation']
145
  test_set = dataset.map(tokenize_function, batched=True, batch_size=50, remove_columns=column_names,)
@@ -148,6 +149,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
148
 
149
  elif os.access(f'/data/{hub_id.strip()}_validation_dataset', os.R_OK):
150
  dataset = load_dataset(dataset_name.strip())
 
151
  train_size = len(dataset['train'])
152
  third_size = train_size // 3
153
  del dataset['test']
@@ -165,6 +167,8 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
165
 
166
  if os.access(f'/data/{hub_id.strip()}_train_dataset', os.R_OK) and not os.access(f'/data/{hub_id.strip()}_train_dataset3', os.R_OK):
167
  dataset = load_dataset(dataset_name.strip())
 
 
168
  train_size = len(dataset['train'])
169
  third_size = train_size // 3
170
  second_third = dataset['train'].select(range(third_size, third_size*2))
@@ -179,6 +183,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
179
  except Exception as e:
180
  print(f"An error occurred: {str(e)}, TB: {traceback.format_exc()}")
181
  dataset = load_dataset(dataset_name.strip())
 
182
  train_size = len(dataset['train'])
183
  third_size = train_size // 3
184
  # Tokenize the dataset
 
140
 
141
  elif os.access(f'/data/{hub_id.strip()}_train_dataset3', os.R_OK):
142
  dataset = load_dataset(dataset_name.strip())
143
+ dataset['test'] = dataset['test'].select(700)
144
  del dataset['train']
145
  del dataset['validation']
146
  test_set = dataset.map(tokenize_function, batched=True, batch_size=50, remove_columns=column_names,)
 
149
 
150
  elif os.access(f'/data/{hub_id.strip()}_validation_dataset', os.R_OK):
151
  dataset = load_dataset(dataset_name.strip())
152
+ dataset['train'] = dataset['train'].select(8000)
153
  train_size = len(dataset['train'])
154
  third_size = train_size // 3
155
  del dataset['test']
 
167
 
168
  if os.access(f'/data/{hub_id.strip()}_train_dataset', os.R_OK) and not os.access(f'/data/{hub_id.strip()}_train_dataset3', os.R_OK):
169
  dataset = load_dataset(dataset_name.strip())
170
+ dataset['train'] = dataset['train'].select(8000)
171
+ dataset['validation'] = dataset['validation'].select(300)
172
  train_size = len(dataset['train'])
173
  third_size = train_size // 3
174
  second_third = dataset['train'].select(range(third_size, third_size*2))
 
183
  except Exception as e:
184
  print(f"An error occurred: {str(e)}, TB: {traceback.format_exc()}")
185
  dataset = load_dataset(dataset_name.strip())
186
+ dataset['train'] = dataset['train'].select(8000)
187
  train_size = len(dataset['train'])
188
  third_size = train_size // 3
189
  # Tokenize the dataset