import math import numpy as np import torch import logging import os import sys from colorama import Fore, Style, init from dotenv import load_dotenv load_dotenv() init(autoreset=True) def nearest_power_of_two(x: int, round_up: bool = False) -> int: return ( 1 << math.floor(math.log2(x)) if not round_up else 1 << math.ceil(math.log2(x)) ) def get_hankel(seq_len: int, use_hankel_L: bool = False, device: torch.device = None, dtype: torch.dtype = torch.float32) -> torch.Tensor: entries = torch.arange(1, seq_len + 1, dtype=dtype, device=device) i_plus_j = entries[:, None] + entries[None, :] if use_hankel_L: sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0 denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0) Z = sgn * (8.0 / denom) elif not use_hankel_L: Z = 2.0 / (i_plus_j**3 - i_plus_j) else: raise ValueError("use_hankel_L must be a boolean") return Z class ColorFormatter(logging.Formatter): """ A custom log formatter that applies color based on the log level using the Colorama library. Attributes: LOG_COLORS (dict): A dictionary mapping log levels to their corresponding color codes. """ # Colors for each log level LOG_COLORS = { logging.DEBUG: Fore.LIGHTMAGENTA_EX + Style.BRIGHT, logging.INFO: Fore.CYAN, logging.WARNING: Fore.YELLOW + Style.BRIGHT, logging.ERROR: Fore.RED + Style.BRIGHT, logging.CRITICAL: Fore.RED + Style.BRIGHT + Style.NORMAL, } # Colors for other parts of the log message TIME_COLOR = Fore.GREEN FILE_COLOR = Fore.BLUE LEVEL_COLOR = Style.BRIGHT def __init__(self, fmt=None): super().__init__(fmt or "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", "%Y-%m-%d %H:%M:%S") def format(self, record): """ Formats a log record with the appropriate color based on the log level. Args: record (logging.LogRecord): The log record to format. Returns: str: The formatted log message with colors applied. """ # Apply color based on the log level level_color = self.LOG_COLORS.get(record.levelno, Fore.WHITE) time_str = f"{self.TIME_COLOR}{self.formatTime(record)}{Style.RESET_ALL}" levelname_str = f"{level_color}{record.levelname}{Style.RESET_ALL}" file_info_str = f"{self.FILE_COLOR}{record.filename}:{record.lineno}{Style.RESET_ALL}" # Format the log message with color log_msg = f"{time_str} - {levelname_str} - {file_info_str} - {record.msg}" return log_msg def setup_logger(): """ Sets up a logger with a custom color formatter that logs to standard output (stdout). The logger is configured with the ColorFormatter to format log messages with color based on the log level. The log level is set to INFO by default, but this can be changed to show more or less detailed messages. Returns: logging.Logger: A logger instance that logs formatted messages to stdout. """ handler = logging.StreamHandler(sys.stdout) # Set custom formatter formatter = ColorFormatter() handler.setFormatter(formatter) logger = logging.getLogger(__name__) # Set to DEBUG to capture all logging levels DEBUG = os.environ.get("DEBUG", "False").lower() in ("true", "1", "t") logger.setLevel(logging.DEBUG) if DEBUG else logger.setLevel(logging.INFO) logger.addHandler(handler) logger.propagate = False # Prevents multiple logging if re-initialized return logger logger = setup_logger() # Initialize once to prevent multiple loggers