Commit e2cd770 · Parent: 183a6c1 · "update"

This commit adds a "download everything" export to TrainingService, bundling the latest finetrainers checkpoint together with the LoRA weights into a single ZIP, plus a restore-from-backup flow (a service method and matching Manage-tab UI) that unpacks such an archive back into the output directory.
vms/ui/project/services/training.py

@@ -1821,17 +1821,193 @@ class TrainingService:
             print(f"Failed to create checkpoint zip: {str(e)}")
             raise gr.Error(f"Failed to create checkpoint zip: {str(e)}")
 
+    def create_everything_zip(self) -> Optional[str]:
+        """Create a ZIP file containing both the latest checkpoint folder and LoRA weights
+
+        Returns:
+            Path to created ZIP file or None if no checkpoint found
+        """
+        # Find all checkpoint directories
+        checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
+        if not checkpoints:
+            logger.info("No checkpoint directories found")
+            raise gr.Error("No checkpoint directories found")
+
+        # Get the latest checkpoint by step number
+        latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
+        step_num = int(latest_checkpoint.name.split("_")[-1])
+
+        # Check if LoRA weights exist
+        lora_weights_path = self.app.output_path / "pytorch_lora_weights.safetensors"
+        if not lora_weights_path.exists():
+            logger.warning("LoRA weights file not found, will only include checkpoint")
+
+        # Create temporary directory for combining files
+        with tempfile.TemporaryDirectory() as temp_dir:
+            temp_path = Path(temp_dir)
+
+            # Copy checkpoint directory to temp location
+            checkpoint_dest = temp_path / latest_checkpoint.name
+            shutil.copytree(latest_checkpoint, checkpoint_dest)
+            print(f"Copied checkpoint {latest_checkpoint.name} to temporary location")
+
+            # Copy LoRA weights if they exist
+            if lora_weights_path.exists():
+                lora_dest = temp_path / "pytorch_lora_weights.safetensors"
+                shutil.copy2(lora_weights_path, lora_dest)
+                print(f"Copied LoRA weights to temporary location")
+
+            # Create temporary zip file
+            with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as temp_zip:
+                temp_zip_path = str(temp_zip.name)
+                print(f"Creating combined zip file with checkpoint and LoRA weights...")
+                try:
+                    # Create the archive with all contents from temp directory
+                    make_archive(temp_path, temp_zip_path)
+                    print(f"Combined zip file created for step {step_num}!")
+                    return temp_zip_path
+                except Exception as e:
+                    print(f"Failed to create combined zip: {str(e)}")
+                    raise gr.Error(f"Failed to create combined zip: {str(e)}")
+
     def get_checkpoint_button_text(self) -> str:
-        """Get the dynamic text for the download
+        """Get the dynamic text for the download everything button based on available checkpoints"""
         try:
             checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
             if not checkpoints:
-                return "📥 Download
+                return "📥 Download everything (not available)"
 
             # Get the latest checkpoint by step number
             latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
             step_num = int(latest_checkpoint.name.split("_")[-1])
-            return f"📥 Download
+            return f"📥 Download everything (step {step_num})"
         except Exception as e:
             logger.warning(f"Error getting checkpoint info for button text: {e}")
-            return "📥 Download
+            return "📥 Download everything (not available)"
+
+    def restore_checkpoint_from_zip(self, zip_file_path: str) -> Tuple[bool, str]:
+        """Restore checkpoint from a ZIP file
+
+        Args:
+            zip_file_path: Path to the ZIP file containing checkpoint data
+
+        Returns:
+            Tuple of (success, message)
+        """
+        try:
+            # Check if training is running
+            if self.is_training_running():
+                return False, "Cannot restore checkpoint while training is running. Please stop training first."
+
+            import zipfile
+
+            # Validate ZIP file
+            if not zipfile.is_zipfile(zip_file_path):
+                return False, "The provided file is not a valid ZIP archive."
+
+            # Extract the ZIP file to a temporary directory first
+            import tempfile
+            with tempfile.TemporaryDirectory() as temp_dir:
+                temp_path = Path(temp_dir)
+
+                # Extract ZIP
+                with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
+                    zip_ref.extractall(temp_path)
+                logger.info(f"Extracted backup to temporary directory: {temp_path}")
+
+                # Check what was extracted
+                extracted_items = list(temp_path.glob("*"))
+                if not extracted_items:
+                    return False, "The ZIP file appears to be empty."
+
+                logger.info(f"Extracted items: {[item.name for item in extracted_items]}")
+
+                # Look for checkpoint directory pattern (finetrainers_step_*)
+                checkpoint_dirs = [d for d in extracted_items if d.is_dir() and d.name.startswith("finetrainers_step_")]
+
+                if checkpoint_dirs:
+                    # We have a checkpoint directory
+                    checkpoint_dir = checkpoint_dirs[0]
+                    checkpoint_name = checkpoint_dir.name
+                    step_num = int(checkpoint_name.split("_")[-1])
+
+                    # Check if this checkpoint already exists
+                    target_path = self.app.output_path / checkpoint_name
+                    if target_path.exists():
+                        # Backup existing checkpoint
+                        backup_name = f"{checkpoint_name}_backup_{int(time.time())}"
+                        backup_path = self.app.output_path / backup_name
+                        shutil.move(str(target_path), str(backup_path))
+                        logger.info(f"Backed up existing checkpoint to {backup_name}")
+
+                    # Move the checkpoint to output directory
+                    shutil.move(str(checkpoint_dir), str(target_path))
+                    logger.info(f"Restored checkpoint {checkpoint_name} to {target_path}")
+
+                    # Also check for LoRA weights in the archive
+                    lora_file = None
+                    for item in extracted_items:
+                        if item.name == "pytorch_lora_weights.safetensors":
+                            lora_file = item
+                            break
+
+                    if lora_file:
+                        # Copy LoRA weights to output directory root
+                        target_lora_path = self.app.output_path / "pytorch_lora_weights.safetensors"
+                        if target_lora_path.exists():
+                            backup_lora_name = f"pytorch_lora_weights_backup_{int(time.time())}.safetensors"
+                            backup_lora_path = self.app.output_path / backup_lora_name
+                            shutil.move(str(target_lora_path), str(backup_lora_path))
+                            logger.info(f"Backed up existing LoRA weights to {backup_lora_name}")
+
+                        shutil.copy2(str(lora_file), str(target_lora_path))
+                        logger.info(f"Restored LoRA weights to {target_lora_path}")
+
+                        return True, f"Successfully restored checkpoint at step {step_num} with LoRA weights"
+                    else:
+                        return True, f"Successfully restored checkpoint at step {step_num}"
+
+                else:
+                    # No checkpoint directory found, look for direct files
+                    # This could be an output directory backup
+                    logger.info("No checkpoint directory found, treating as output directory backup")
+
+                    # Clear existing output directory (backup first if needed)
+                    if any(self.app.output_path.iterdir()):
+                        backup_dir = self.app.output_path.parent / f"output_backup_{int(time.time())}"
+                        shutil.copytree(str(self.app.output_path), str(backup_dir))
+                        logger.info(f"Backed up existing output to {backup_dir}")
+
+                        # Clear output directory
+                        for item in self.app.output_path.iterdir():
+                            if item.is_dir():
+                                shutil.rmtree(item)
+                            else:
+                                item.unlink()
+
+                    # Move all extracted items to output directory
+                    for item in extracted_items:
+                        target = self.app.output_path / item.name
+                        shutil.move(str(item), str(target))
+                        logger.info(f"Restored {item.name} to output directory")
+
+                    # Check what we restored
+                    restored_checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
+                    restored_lora = (self.app.output_path / "pytorch_lora_weights.safetensors").exists()
+
+                    if restored_checkpoints:
+                        latest = max(restored_checkpoints, key=lambda x: int(x.name.split("_")[-1]))
+                        step_num = int(latest.name.split("_")[-1])
+                        if restored_lora:
+                            return True, f"Successfully restored output directory with checkpoint at step {step_num} and LoRA weights"
+                        else:
+                            return True, f"Successfully restored output directory with checkpoint at step {step_num}"
+                    elif restored_lora:
+                        return True, "Successfully restored output directory with LoRA weights"
+                    else:
+                        return True, "Successfully restored output directory backup"
+
+        except Exception as e:
+            error_msg = f"Failed to restore checkpoint: {str(e)}"
+            logger.error(error_msg, exc_info=True)
+            return False, error_msg
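The new create_everything_zip method leans on a make_archive helper that the hunk calls but never defines; it is evidently imported from elsewhere in the VMS codebase. As a hedged sketch only, assuming the helper's contract is (source directory, destination ZIP path) and that it archives the directory's contents with relative paths, a compatible stand-in could look like:

import zipfile
from pathlib import Path

def make_archive(source_dir: Path, zip_path: str) -> str:
    # Hypothetical stand-in for the make_archive helper this diff calls;
    # the real definition is not part of the commit. Assumed contract:
    # pack everything under source_dir into zip_path, preserving paths
    # relative to source_dir, and return the ZIP's path.
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for file in sorted(Path(source_dir).rglob("*")):
            if file.is_file():
                zf.write(file, file.relative_to(source_dir))
    return zip_path

Note that create_everything_zip allocates the destination with tempfile.NamedTemporaryFile(suffix='.zip', delete=False), so the archive outlives both with blocks and its path can be handed back to Gradio for download; nothing in this commit cleans it up afterwards.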
vms/ui/project/tabs/manage_tab.py

@@ -40,12 +40,12 @@ class ManageTab(BaseTab):
         return "🧠 Download weights (.safetensors)"
 
     def get_checkpoint_button_text(self) -> str:
-        """Get the dynamic text for the download
+        """Get the dynamic text for the download everything button"""
         try:
             return self.app.training.get_checkpoint_button_text()
         except Exception as e:
             logger.warning(f"Error getting checkpoint button text: {e}")
-            return "📥 Download
+            return "📥 Download everything (not available)"
 
     def update_download_button_text(self) -> gr.update:
         """Update the download button text"""

@@ -107,6 +107,32 @@ class ManageTab(BaseTab):
                     size="lg",
                     visible=False
                 )
+
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("## 🔄 Restore from backup")
+                gr.Markdown("Upload a checkpoint ZIP file to restore your training progress. Training must be stopped before restoring.")
+
+                with gr.Row():
+                    self.components["restore_file"] = gr.File(
+                        label="Select backup ZIP file",
+                        file_types=[".zip"],
+                        type="filepath"
+                    )
+
+                with gr.Row():
+                    self.components["restore_btn"] = gr.Button(
+                        "🔄 Restore from backup",
+                        variant="primary",
+                        size="lg"
+                    )
+
+                with gr.Row():
+                    self.components["restore_status"] = gr.Textbox(
+                        label="Restore Status",
+                        interactive=False,
+                        visible=False
+                    )
         with gr.Row():
             with gr.Column():
                 gr.Markdown("## 📡 Publish your model")

@@ -258,7 +284,7 @@ class ManageTab(BaseTab):
         )
 
         self.components["download_checkpoint_btn"].click(
-            fn=self.app.training.
+            fn=self.app.training.create_everything_zip,
             outputs=[self.components["download_checkpoint_btn"]]
         )
 

@@ -273,6 +299,13 @@ class ManageTab(BaseTab):
             outputs=[]
         )
 
+        # Restore from backup button
+        self.components["restore_btn"].click(
+            fn=self.handle_restore_backup,
+            inputs=[self.components["restore_file"]],
+            outputs=[self.components["restore_status"]]
+        )
+
         # Dataset deletion with modal
         self.components["delete_dataset_btn"].click(
             fn=lambda: Modal(visible=True),

@@ -404,6 +437,35 @@ class ManageTab(BaseTab):
             gr.Error(error_msg)
             logger.error(f"LoRA cleanup failed: {e}")
 
+    def handle_restore_backup(self, file_path):
+        """Handle restoring from a backup ZIP file
+
+        Args:
+            file_path: Path to the uploaded ZIP file
+
+        Returns:
+            gr.update for the status textbox
+        """
+        if not file_path:
+            return gr.update(value="No file selected. Please select a ZIP file to restore.", visible=True)
+
+        try:
+            # Call the training service to restore the backup
+            success, message = self.app.training.restore_checkpoint_from_zip(file_path)
+
+            if success:
+                gr.Info(f"✅ {message}")
+                return gr.update(value=message, visible=True)
+            else:
+                gr.Error(f"❌ {message}")
+                return gr.update(value=f"Error: {message}", visible=True)
+
+        except Exception as e:
+            error_msg = f"Failed to restore backup: {str(e)}"
+            gr.Error(f"❌ {error_msg}")
+            logger.error(f"Restore backup failed: {e}", exc_info=True)
+            return gr.update(value=f"Error: {error_msg}", visible=True)
+
     def delete_dataset(self):
         """Delete dataset files (images, videos, captions)"""
         status_messages = {}
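Taken together, the two files implement a backup/restore round trip: create_everything_zip bundles the latest finetrainers_step_* directory plus pytorch_lora_weights.safetensors, and restore_checkpoint_from_zip unpacks such an archive back into the output directory, backing up anything it would overwrite. A minimal way to exercise the pair outside the Gradio UI, assuming `service` is an initialized TrainingService with at least one checkpoint on disk (the variable name is illustrative, not from the commit):

# Illustrative smoke test of the round trip added in this commit; `service`
# is assumed to be a live TrainingService whose app.output_path contains
# at least one finetrainers_step_* checkpoint, with training stopped.
zip_path = service.create_everything_zip()  # checkpoint + LoRA weights in one ZIP
ok, msg = service.restore_checkpoint_from_zip(zip_path)
print(ok, msg)  # e.g. True, "Successfully restored checkpoint at step 500 with LoRA weights"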
|