import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import HfApi, HfFolder, Repository, create_repo
import os
import pandas as pd
import gradio as gr
from PIL import Image
import numpy as np
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet
from CLIP import load as load_clip
from rich import print as rp
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
big = device.type == 'cuda'  # use the 1024px model only when a GPU is available

# Parameters
IMG_SIZE = 1024 if big else 256
BATCH_SIZE = 1  # single-image batches for both model sizes
EPOCHS = 12
LR = 0.0002
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"

# Global model variable
global_model = None

# CLIP
clip_model, clip_tokenizer = load_clip()
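# Note: only clip_tokenizer is used below; the tokenized prompts are prepared
# for each batch but not yet fed into the UNet forward pass.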
def load_model():
    """Load the model at startup."""
    global global_model
    weights_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
    try:
        checkpoint = torch.load(weights_name, map_location=device)
        model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
        global_model = model
        rp("Model loaded successfully!")
        return model
    except Exception as e:
        rp(f"Error loading model: {e}")
        # Fall back to a freshly initialized (untrained) model so the app can still start
        model = big_UNet().to(device) if big else small_UNet().to(device)
        global_model = model
        return model
class Pix2PixDataset(torch.utils.data.Dataset):
    def __init__(self, combined_data, transform, clip_tokenizer):
        self.data = combined_data
        self.transform = transform
        self.clip_tokenizer = clip_tokenizer
        self.original_folder = 'images_dataset/original/'
        self.target_folder = 'images_dataset/target/'

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        original_img_filename = os.path.basename(self.data.iloc[idx]['image_path'])
        original_img_path = os.path.join(self.original_folder, original_img_filename)
        target_img_path = os.path.join(self.target_folder, original_img_filename)
        original_img = Image.open(original_img_path).convert('RGB')
        target_img = Image.open(target_img_path).convert('RGB')

        # Transform images
        original = self.transform(original_img)
        target = self.transform(target_img)

        # Get prompts from the DataFrame
        original_prompt = self.data.iloc[idx]['original_prompt']
        enhanced_prompt = self.data.iloc[idx]['enhanced_prompt']

        # Tokenize the prompts using the CLIP tokenizer
        original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
        enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)

        return original, target, original_tokens, enhanced_tokens
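# Pairs are matched by shared filename between original/ and target/.
# Inferred from the column accesses above, combined_data.csv must provide at
# least: image_path, original_prompt, enhanced_prompt.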
class UNetWrapper:
    def __init__(self, unet_model, repo_id, epoch, loss, optimizer, scheduler=None):
        self.loss = loss
        self.epoch = epoch
        self.model = unet_model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.repo_id = repo_id
        self.token = os.getenv('NEW_TOKEN')  # Ensure the token is set in the environment
        self.api = HfApi(token=self.token)

    def save_checkpoint(self, save_path):
        """Save checkpoint with model, optimizer, and scheduler states."""
        # Kept on self because push_to_hub() reads it when building the model card
        self.save_dict = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'model_config': {
                'big': isinstance(self.model, big_UNet),
                'img_size': 1024 if isinstance(self.model, big_UNet) else 256
            },
            'epoch': self.epoch,
            'loss': self.loss
        }
        torch.save(self.save_dict, save_path)
        print(f"Checkpoint saved at epoch {self.epoch}, loss: {self.loss}")

    def load_checkpoint(self, checkpoint_path):
        """Load model, optimizer, and scheduler states from the checkpoint."""
        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.scheduler and checkpoint['scheduler_state_dict']:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.epoch = checkpoint['epoch']
        self.loss = checkpoint['loss']
        print(f"Checkpoint loaded: epoch {self.epoch}, loss: {self.loss}")
    def push_to_hub(self, pth_name):
        """Push model checkpoint and metadata to the Hugging Face Hub."""
        try:
            self.api.upload_file(
                path_or_fileobj=pth_name,
                path_in_repo=pth_name,
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )
            print(f"Model checkpoint successfully uploaded to {self.repo_id}")
        except Exception as e:
            print(f"Error uploading model: {e}")

        # Create the model card (reads self.save_dict, so save_checkpoint() must run first)
        model_card = f"""---
tags:
- unet
- pix2pix
- pytorch
library_name: pytorch
license: wtfpl
datasets:
- K00B404/pix2pix_flux_set
language:
- en
pipeline_tag: image-to-image
---
# Pix2Pix UNet Model

## Model Description
Custom UNet model for Pix2Pix image translation.
- **Image Size:** {self.save_dict['model_config']['img_size']}
- **Model Type:** {'big' if big else 'small'}_UNet ({self.save_dict['model_config']['img_size']})

## Usage
```python
import torch
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

big = True

# Load the model
name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
checkpoint = torch.load(name)
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```

## Model Architecture
{str(self.model)}
"""
        rp(model_card)
        try:
            # Upload the model card as the repo README
            with open("README.md", "w") as f:
                f.write(model_card)
            self.api.upload_file(
                path_or_fileobj="README.md",
                path_in_repo="README.md",
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )
            # Clean up local files
            os.remove(pth_name)
            os.remove("README.md")
            print(f"Model successfully uploaded to {self.repo_id}")
        except Exception as e:
            print(f"Error uploading README: {e}")
def prepare_input(image, device='cpu'):
    """Prepare an image for inference."""
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    input_tensor = transform(image).unsqueeze(0).to(device)
    return input_tensor
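# ToTensor() scales pixels to [0, 1] with no further normalization, matching
# the transform used during training.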
def run_inference(image):
    """Run inference on a single image."""
    global global_model
    if global_model is None:
        return "Error: Model not loaded"
    global_model.eval()
    input_tensor = prepare_input(image, device)
    with torch.no_grad():
        output = global_model(input_tensor)
    # Convert the output tensor to a uint8 image via min-max rescaling to [0, 255]
    output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
    output = ((output - output.min()) / (output.max() - output.min()) * 255).astype(np.uint8)
    rp(output[0])  # debug: print the first row of the output array
    return output
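# Smoke-test sketch (hypothetical input): run the loaded model on a blank frame,
# e.g. run_inference(np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)).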
def to_hub(model, epoch, loss, optimizer):
    """Save a checkpoint and push it to the Hub (used by the legacy training loop)."""
    wrapper = UNetWrapper(model, model_repo_id, epoch, loss, optimizer)
    pth_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
    wrapper.save_checkpoint(pth_name)  # also populates wrapper.save_dict for the model card
    wrapper.push_to_hub(pth_name)
def train_model(epochs, save_interval=1):
    """Training function with checkpoint saving and model uploading."""
    global global_model

    # Load combined data CSV
    data_path = 'combined_data.csv'
    combined_data = pd.read_csv(data_path)

    # Define the transformation
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])

    # Initialize dataset and dataloader
    dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    model = global_model
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # Example scheduler
    wrapper = UNetWrapper(model, model_repo_id, epoch=0, loss=0.0, optimizer=optimizer, scheduler=scheduler)

    output_text = []
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
            # Move data to device (key indexing works whether collation yields a dict or a BatchEncoding)
            original, target = original.to(device), target.to(device)
            original_prompt_tokens = original_prompt_tokens['input_ids'].to(device).float()
            enhanced_prompt_tokens = enhanced_prompt_tokens['input_ids'].to(device).float()

            optimizer.zero_grad()

            # Forward pass (the tokenized prompts are not yet used by the UNet)
            output = model(target)
            img_loss = criterion(output, original)
            total_loss = img_loss
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            if i % 10 == 0:
                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
                print(status)
                output_text.append(status)

        # Update the epoch and loss for the checkpoint
        wrapper.epoch = epoch + 1
        wrapper.loss = running_loss / len(dataloader)

        # Save a checkpoint at the specified interval
        if (epoch + 1) % save_interval == 0:
            checkpoint_path = f'big_checkpoint_epoch_{epoch+1}.pth' if big else f'small_checkpoint_epoch_{epoch+1}.pth'
            wrapper.save_checkpoint(checkpoint_path)
            wrapper.push_to_hub(checkpoint_path)

        scheduler.step()  # Update learning rate scheduler

    global_model = model  # Update global model after training
    return model, "\n".join(output_text)
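# Legacy training loop, kept for reference: no LR scheduler or running-loss
# average, and it pushes weights to the Hub after every epoch.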
def train_model_old(epochs):
    """Training function."""
    global global_model

    # Load combined data CSV
    data_path = 'combined_data.csv'  # Adjust this path
    combined_data = pd.read_csv(data_path)

    # Define the transformation
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])

    # Initialize the dataset and dataloader
    dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    model = global_model
    criterion = nn.L1Loss()  # L1 loss for image reconstruction
    optimizer = optim.Adam(model.parameters(), lr=LR)

    output_text = []
    for epoch in range(epochs):
        model.train()
        for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
            # Move images and prompt tokens to the appropriate device (CPU or GPU)
            original, target = original.to(device), target.to(device)
            original_prompt_tokens = original_prompt_tokens['input_ids'].to(device).float()
            enhanced_prompt_tokens = enhanced_prompt_tokens['input_ids'].to(device).float()

            optimizer.zero_grad()

            # Forward pass through the model
            output = model(target)

            # Compute image reconstruction loss
            img_loss = criterion(output, original)
            rp(f"Image {i} Loss: {img_loss}")

            # Combine losses
            total_loss = img_loss  # Add any other losses if necessary
            total_loss.backward()

            # Optimizer step
            optimizer.step()

            if i % 10 == 0:
                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
                rp(status)
                output_text.append(status)

        # Push the model to the Hugging Face Hub at the end of each epoch
        to_hub(model, epoch, total_loss.item(), optimizer)

    global_model = model  # Update the global model after training
    return model, "\n".join(output_text)
def gradio_train(epochs):
    """Gradio training interface function."""
    model, training_log = train_model(int(epochs))
    return f"{training_log}\n\nModel trained for {epochs} epochs and pushed to {model_repo_id}"

def gradio_inference(input_image):
    """Gradio inference interface function."""
    output_image = run_inference(input_image)  # numpy array, or an error string if no model is loaded
    rp(output_image)
    return output_image
# Create Gradio interface with tabs
with gr.Blocks() as app:
    gr.Markdown("# Pix2Pix Model Training and Inference")
    with gr.Tab("Train"):
        epochs_input = gr.Number(value=EPOCHS, label="Number of epochs")
        train_button = gr.Button("Train")
        training_output = gr.Textbox(label="Training Log", interactive=False)
        train_button.click(gradio_train, inputs=[epochs_input], outputs=[training_output])
    with gr.Tab("Inference"):
        image_input = gr.Image(type='numpy')
        prompt_input = gr.Textbox(label="Prompt")  # not yet wired into inference
        inference_button = gr.Button("Generate")
        inference_output = gr.Image(type='numpy', label="Generated Image")
        inference_button.click(gradio_inference, inputs=[image_input], outputs=[inference_output])

load_model()
app.launch()