jbilcke-hf HF Staff committed on
Commit
e2cd770
·
1 Parent(s): 183a6c1
vms/ui/project/services/training.py CHANGED
@@ -1821,17 +1821,193 @@ class TrainingService:
1821
  print(f"Failed to create checkpoint zip: {str(e)}")
1822
  raise gr.Error(f"Failed to create checkpoint zip: {str(e)}")
1823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1824
  def get_checkpoint_button_text(self) -> str:
1825
- """Get the dynamic text for the download checkpoint button based on available checkpoints"""
1826
  try:
1827
  checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
1828
  if not checkpoints:
1829
- return "πŸ“₯ Download checkpoints (not available)"
1830
 
1831
  # Get the latest checkpoint by step number
1832
  latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
1833
  step_num = int(latest_checkpoint.name.split("_")[-1])
1834
- return f"πŸ“₯ Download checkpoints (step {step_num})"
1835
  except Exception as e:
1836
  logger.warning(f"Error getting checkpoint info for button text: {e}")
1837
- return "πŸ“₯ Download checkpoints (not available)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1821
  print(f"Failed to create checkpoint zip: {str(e)}")
1822
  raise gr.Error(f"Failed to create checkpoint zip: {str(e)}")
1823
 
1824
def create_everything_zip(self) -> Optional[str]:
    """Create a ZIP file containing the latest checkpoint folder and the LoRA weights.

    The newest ``finetrainers_step_<N>`` directory under
    ``self.app.output_path`` is staged in a temporary directory together with
    ``pytorch_lora_weights.safetensors`` (when that file exists), and the
    staged tree is archived into a temporary ``.zip`` file.

    Returns:
        Path to the created ZIP file. The ``Optional`` annotation is kept for
        interface compatibility only: a missing checkpoint is reported by
        raising, not by returning ``None``.

    Raises:
        gr.Error: If no checkpoint directory exists, or if staging/archiving fails.
    """
    prefix = "finetrainers_step_"

    # Only directories named exactly finetrainers_step_<digits> are checkpoints.
    # Other entries matching the glob (for example "..._backup_<timestamp>"
    # folders) must not be picked as "latest" via their trailing number.
    checkpoints = [
        entry
        for entry in self.app.output_path.glob(f"{prefix}*")
        if entry.is_dir() and entry.name[len(prefix):].isdecimal()
    ]
    if not checkpoints:
        logger.info("No checkpoint directories found")
        raise gr.Error("No checkpoint directories found")

    latest_checkpoint = max(checkpoints, key=lambda entry: int(entry.name[len(prefix):]))
    step_num = int(latest_checkpoint.name[len(prefix):])

    # Check once so the warning and the copy below cannot disagree.
    lora_weights_path = self.app.output_path / "pytorch_lora_weights.safetensors"
    has_lora_weights = lora_weights_path.exists()
    if not has_lora_weights:
        logger.warning("LoRA weights file not found, will only include checkpoint")

    # delete=False: the archive has to outlive this call so it can be served.
    # The handle is closed right away; only the reserved path is needed.
    with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as temp_zip:
        temp_zip_path = str(temp_zip.name)

    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            staging_path = Path(temp_dir)

            shutil.copytree(latest_checkpoint, staging_path / latest_checkpoint.name)
            print(f"Copied checkpoint {latest_checkpoint.name} to temporary location")

            if has_lora_weights:
                shutil.copy2(lora_weights_path, staging_path / "pytorch_lora_weights.safetensors")
                print("Copied LoRA weights to temporary location")

            print("Creating combined zip file with checkpoint and LoRA weights...")
            # Project helper, called with (source directory, destination path)
            # exactly as before.
            make_archive(staging_path, temp_zip_path)

        print(f"Combined zip file created for step {step_num}!")
        return temp_zip_path
    except Exception as e:
        # The reserved zip file is not auto-deleted; remove it so failed
        # attempts do not pile up in the temp directory.
        leftover = Path(temp_zip_path)
        if leftover.exists():
            leftover.unlink()
        print(f"Failed to create combined zip: {str(e)}")
        raise gr.Error(f"Failed to create combined zip: {str(e)}") from e
1872
+
1873
  def get_checkpoint_button_text(self) -> str:
1874
+ """Get the dynamic text for the download everything button based on available checkpoints"""
1875
  try:
1876
  checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
1877
  if not checkpoints:
1878
+ return "πŸ“₯ Download everything (not available)"
1879
 
1880
  # Get the latest checkpoint by step number
1881
  latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
1882
  step_num = int(latest_checkpoint.name.split("_")[-1])
1883
+ return f"πŸ“₯ Download everything (step {step_num})"
1884
  except Exception as e:
1885
  logger.warning(f"Error getting checkpoint info for button text: {e}")
1886
+ return "πŸ“₯ Download everything (not available)"
1887
+
1888
def restore_checkpoint_from_zip(self, zip_file_path: str) -> Tuple[bool, str]:
    """Restore training outputs from a backup ZIP file.

    Two archive layouts are handled:

    * The archive's top level holds ``finetrainers_step_<N>`` directories.
      Each one is moved into ``self.app.output_path``; a checkpoint of the
      same name already present there is first renamed to
      ``backup_<timestamp>_<name>``. A top-level
      ``pytorch_lora_weights.safetensors`` file is restored the same way.
    * No checkpoint-like directory is present. The archive is treated as a
      backup of the whole output directory: the current directory is copied
      to a sibling ``output_backup_<timestamp>``, emptied, and refilled with
      the archive's contents.

    Args:
        zip_file_path: Path to the ZIP file containing checkpoint data.

    Returns:
        Tuple of (success, message). Failures are logged and reported through
        the tuple; no exception is propagated to the caller.
    """
    import tempfile
    import zipfile

    prefix = "finetrainers_step_"
    lora_name = "pytorch_lora_weights.safetensors"

    def step_of(entry) -> int:
        # Step number is everything after the fixed prefix.
        return int(entry.name[len(prefix):])

    try:
        if self.is_training_running():
            return False, "Cannot restore checkpoint while training is running. Please stop training first."

        if not zipfile.is_zipfile(zip_file_path):
            return False, "The provided file is not a valid ZIP archive."

        output_path = self.app.output_path
        # One timestamp for every backup made by this call, so they can be
        # matched up afterwards.
        stamp = int(time.time())

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)

            # NOTE(review): the archive is user-supplied. zipfile strips
            # absolute paths and ".." components on extraction, but nothing
            # here bounds the uncompressed size.
            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                zip_ref.extractall(temp_path)
            logger.info(f"Extracted backup to temporary directory: {temp_path}")

            extracted_items = list(temp_path.glob("*"))
            if not extracted_items:
                return False, "The ZIP file appears to be empty."

            logger.info(f"Extracted items: {[item.name for item in extracted_items]}")

            candidates = [d for d in extracted_items if d.is_dir() and d.name.startswith(prefix)]
            # Validate names before touching the output directory, so a
            # malformed folder can never abort a half-finished restore.
            checkpoint_dirs = sorted(
                (d for d in candidates if d.name[len(prefix):].isdecimal()),
                key=step_of,
            )

            if candidates and not checkpoint_dirs:
                return False, "The ZIP file contains checkpoint folders, but none is named finetrainers_step_<number>."

            if checkpoint_dirs:
                step_num = step_of(checkpoint_dirs[-1])

                # Restore every checkpoint in the archive, not an arbitrary one.
                for checkpoint_dir in checkpoint_dirs:
                    checkpoint_name = checkpoint_dir.name
                    target_path = output_path / checkpoint_name
                    if target_path.exists():
                        # The backup name must NOT start with the checkpoint
                        # prefix: other code globs "finetrainers_step_*" and
                        # would treat the trailing timestamp as a step number.
                        backup_name = f"backup_{stamp}_{checkpoint_name}"
                        shutil.move(str(target_path), str(output_path / backup_name))
                        logger.info(f"Backed up existing checkpoint to {backup_name}")

                    shutil.move(str(checkpoint_dir), str(target_path))
                    logger.info(f"Restored checkpoint {checkpoint_name} to {target_path}")

                lora_file = next(
                    (item for item in extracted_items if item.name == lora_name and item.is_file()),
                    None,
                )
                if lora_file is None:
                    return True, f"Successfully restored checkpoint at step {step_num}"

                target_lora_path = output_path / lora_name
                if target_lora_path.exists():
                    backup_lora_name = f"pytorch_lora_weights_backup_{stamp}.safetensors"
                    shutil.move(str(target_lora_path), str(output_path / backup_lora_name))
                    logger.info(f"Backed up existing LoRA weights to {backup_lora_name}")

                shutil.copy2(str(lora_file), str(target_lora_path))
                logger.info(f"Restored LoRA weights to {target_lora_path}")
                return True, f"Successfully restored checkpoint at step {step_num} with LoRA weights"

            # No checkpoint directory: treat the archive as a full output backup.
            logger.info("No checkpoint directory found, treating as output directory backup")

            # Snapshot the listing first; the directory is modified below.
            existing_items = list(output_path.iterdir())
            if existing_items:
                backup_dir = output_path.parent / f"output_backup_{stamp}"
                shutil.copytree(str(output_path), str(backup_dir))
                logger.info(f"Backed up existing output to {backup_dir}")

                for item in existing_items:
                    # rmtree refuses symlinks, so unlink those like files.
                    if item.is_dir() and not item.is_symlink():
                        shutil.rmtree(item)
                    else:
                        item.unlink()

            for item in extracted_items:
                shutil.move(str(item), str(output_path / item.name))
                logger.info(f"Restored {item.name} to output directory")

            # Report what ended up in the output directory.
            restored_steps = [
                step_of(entry)
                for entry in output_path.glob(f"{prefix}*")
                if entry.name[len(prefix):].isdecimal()
            ]
            restored_lora = (output_path / lora_name).exists()

            if restored_steps:
                step_num = max(restored_steps)
                if restored_lora:
                    return True, f"Successfully restored output directory with checkpoint at step {step_num} and LoRA weights"
                return True, f"Successfully restored output directory with checkpoint at step {step_num}"
            if restored_lora:
                return True, "Successfully restored output directory with LoRA weights"
            return True, "Successfully restored output directory backup"

    except Exception as e:
        error_msg = f"Failed to restore checkpoint: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return False, error_msg
vms/ui/project/tabs/manage_tab.py CHANGED
@@ -40,12 +40,12 @@ class ManageTab(BaseTab):
40
  return "🧠 Download weights (.safetensors)"
41
 
42
  def get_checkpoint_button_text(self) -> str:
43
- """Get the dynamic text for the download checkpoint button"""
44
  try:
45
  return self.app.training.get_checkpoint_button_text()
46
  except Exception as e:
47
  logger.warning(f"Error getting checkpoint button text: {e}")
48
- return "πŸ“₯ Download checkpoints (not available)"
49
 
50
  def update_download_button_text(self) -> gr.update:
51
  """Update the download button text"""
@@ -107,6 +107,32 @@ class ManageTab(BaseTab):
107
  size="lg",
108
  visible=False
109
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  with gr.Row():
111
  with gr.Column():
112
  gr.Markdown("## πŸ“‘ Publish your model")
@@ -258,7 +284,7 @@ class ManageTab(BaseTab):
258
  )
259
 
260
  self.components["download_checkpoint_btn"].click(
261
- fn=self.app.training.create_checkpoint_zip,
262
  outputs=[self.components["download_checkpoint_btn"]]
263
  )
264
 
@@ -273,6 +299,13 @@ class ManageTab(BaseTab):
273
  outputs=[]
274
  )
275
 
 
 
 
 
 
 
 
276
  # Dataset deletion with modal
277
  self.components["delete_dataset_btn"].click(
278
  fn=lambda: Modal(visible=True),
@@ -404,6 +437,35 @@ class ManageTab(BaseTab):
404
  gr.Error(error_msg)
405
  logger.error(f"LoRA cleanup failed: {e}")
406
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  def delete_dataset(self):
408
  """Delete dataset files (images, videos, captions)"""
409
  status_messages = {}
 
40
  return "🧠 Download weights (.safetensors)"
41
 
42
  def get_checkpoint_button_text(self) -> str:
43
+ """Get the dynamic text for the download everything button"""
44
  try:
45
  return self.app.training.get_checkpoint_button_text()
46
  except Exception as e:
47
  logger.warning(f"Error getting checkpoint button text: {e}")
48
+ return "πŸ“₯ Download everything (not available)"
49
 
50
  def update_download_button_text(self) -> gr.update:
51
  """Update the download button text"""
 
107
  size="lg",
108
  visible=False
109
  )
110
+
111
+ with gr.Row():
112
+ with gr.Column():
113
+ gr.Markdown("## πŸ”„ Restore from backup")
114
+ gr.Markdown("Upload a checkpoint ZIP file to restore your training progress. Training must be stopped before restoring.")
115
+
116
+ with gr.Row():
117
+ self.components["restore_file"] = gr.File(
118
+ label="Select backup ZIP file",
119
+ file_types=[".zip"],
120
+ type="filepath"
121
+ )
122
+
123
+ with gr.Row():
124
+ self.components["restore_btn"] = gr.Button(
125
+ "πŸ”„ Restore from backup",
126
+ variant="primary",
127
+ size="lg"
128
+ )
129
+
130
+ with gr.Row():
131
+ self.components["restore_status"] = gr.Textbox(
132
+ label="Restore Status",
133
+ interactive=False,
134
+ visible=False
135
+ )
136
  with gr.Row():
137
  with gr.Column():
138
  gr.Markdown("## πŸ“‘ Publish your model")
 
284
  )
285
 
286
  self.components["download_checkpoint_btn"].click(
287
+ fn=self.app.training.create_everything_zip,
288
  outputs=[self.components["download_checkpoint_btn"]]
289
  )
290
 
 
299
  outputs=[]
300
  )
301
 
302
+ # Restore from backup button
303
+ self.components["restore_btn"].click(
304
+ fn=self.handle_restore_backup,
305
+ inputs=[self.components["restore_file"]],
306
+ outputs=[self.components["restore_status"]]
307
+ )
308
+
309
  # Dataset deletion with modal
310
  self.components["delete_dataset_btn"].click(
311
  fn=lambda: Modal(visible=True),
 
437
  gr.Error(error_msg)
438
  logger.error(f"LoRA cleanup failed: {e}")
439
 
440
def handle_restore_backup(self, file_path):
    """Handle restoring from a backup ZIP file.

    Calls ``self.app.training.restore_checkpoint_from_zip`` and reports the
    outcome twice: as a toast notification and as the value of the status
    textbox.

    Args:
        file_path: Path to the uploaded ZIP file; falsy when nothing was selected.

    Returns:
        gr.update for the status textbox (always made visible).
    """
    if not file_path:
        return gr.update(value="No file selected. Please select a ZIP file to restore.", visible=True)

    try:
        success, message = self.app.training.restore_checkpoint_from_zip(file_path)
    except Exception as e:
        error_msg = f"Failed to restore backup: {str(e)}"
        logger.error(f"Restore backup failed: {e}", exc_info=True)
        # gr.Error is an exception class: building it without `raise` shows
        # nothing. gr.Warning displays a toast and still lets this handler
        # return the status update.
        gr.Warning(f"❌ {error_msg}")
        return gr.update(value=f"Error: {error_msg}", visible=True)

    if success:
        gr.Info(f"✅ {message}")
        return gr.update(value=message, visible=True)

    # Same reasoning as above: a toast that does not abort the handler.
    gr.Warning(f"❌ {message}")
    return gr.update(value=f"Error: {message}", visible=True)
468
+
469
  def delete_dataset(self):
470
  """Delete dataset files (images, videos, captions)"""
471
  status_messages = {}