|
import logging
import math
import operator
import os
import sys

import numpy as np
import torch
from colorama import Fore, Style, init
from dotenv import load_dotenv
|
|
|
# Runs at import time, before setup_logger() (below) reads DEBUG from os.environ.
# presumably this populates os.environ from a local .env file (python-dotenv) — confirm.
load_dotenv()

# Initialise Colorama once for the whole process. ColorFormatter still appends
# Style.RESET_ALL after each colored field itself.
init(autoreset=True)
|
|
|
def nearest_power_of_two(x: int, round_up: bool = False) -> int:
    """Return the power of two adjacent to ``x`` in the requested direction.

    Integer inputs are handled with exact bit arithmetic rather than
    ``math.log2``. The float logarithm loses precision for large values
    (for example ``2**60 - 1`` is converted to the float ``2.0**60``), which
    made the logarithm-based formula pick a power of two on the wrong side
    of ``x``.

    Args:
        x: Positive value to round. Anything implementing ``__index__`` takes
            the exact integer path; other numeric values (e.g. floats) fall
            back to the logarithm-based formula.
        round_up: If True, return the smallest power of two ``>= x``;
            otherwise return the largest power of two ``<= x``.

    Returns:
        The selected power of two as a Python ``int``.

    Raises:
        ValueError: If ``x`` is not positive.
    """
    try:
        n = operator.index(x)
    except TypeError:
        # Non-integer input: keep the original logarithm-based behavior.
        log = math.log2(x)
        return 1 << (math.ceil(log) if round_up else math.floor(log))

    if n <= 0:
        raise ValueError(f"x must be a positive integer, got {x!r}")

    if round_up:
        # (n - 1).bit_length() is the exponent of the smallest power of two >= n.
        return 1 << (n - 1).bit_length()

    # n.bit_length() - 1 is the position of the highest set bit of n.
    return 1 << (n.bit_length() - 1)
|
|
|
def get_hankel(seq_len: int, use_hankel_L: bool = False, device: torch.device = None, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Build a ``(seq_len, seq_len)`` matrix whose entries depend only on ``i + j``.

    With 1-based row/column indices ``i`` and ``j`` and ``s = i + j``:

    * default: ``Z[i, j] = 2 / (s**3 - s)``
    * ``use_hankel_L``: ``Z[i, j] = ((-1)**(s - 2) + 1) * 8 / ((s + 3) * (s - 1) * (s + 1))``,
      so entries with odd ``s`` are zero.

    Args:
        seq_len: Number of rows and columns of the result.
        use_hankel_L: Selects the second formula above when truthy. The flag
            is evaluated by truthiness; the former ``else: raise ValueError``
            branch could never run because ``if flag`` / ``elif not flag`` is
            already exhaustive, so it has been removed.
        device: Device passed to ``torch.arange`` for the index vector.
        dtype: Dtype passed to ``torch.arange``; all arithmetic is done in it.

    Returns:
        The ``(seq_len, seq_len)`` tensor ``Z``.
    """
    entries = torch.arange(1, seq_len + 1, dtype=dtype, device=device)
    # Broadcasting a column against a row gives s = i + j for every (i, j).
    i_plus_j = entries[:, None] + entries[None, :]

    if use_hankel_L:
        # sgn is 2 where i + j is even and 0 where it is odd.
        sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
        denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
        return sgn * (8.0 / denom)

    return 2.0 / (i_plus_j**3 - i_plus_j)
|
|
|
|
|
class ColorFormatter(logging.Formatter):
    """Log formatter that colors each field of a record with Colorama codes.

    Every record is rendered as
    ``<time> - <LEVEL> - <file>:<line> - <message>``, where the time, level
    name and file location are each wrapped in a color code followed by
    ``Style.RESET_ALL``.

    Attributes:
        LOG_COLORS (dict): Maps logging level numbers to the color prefix
            applied to the level name. Levels not listed use ``Fore.WHITE``.
        TIME_COLOR: Color prefix for the timestamp.
        FILE_COLOR: Color prefix for the ``filename:lineno`` field.
        LEVEL_COLOR: Kept as part of the class's public attributes; it is not
            referenced by any method of this class.
    """

    LOG_COLORS = {
        logging.DEBUG: Fore.LIGHTMAGENTA_EX + Style.BRIGHT,
        logging.INFO: Fore.CYAN,
        logging.WARNING: Fore.YELLOW + Style.BRIGHT,
        logging.ERROR: Fore.RED + Style.BRIGHT,
        # NOTE(review): Style.NORMAL follows Style.BRIGHT here — confirm this
        # combination is intended for CRITICAL.
        logging.CRITICAL: Fore.RED + Style.BRIGHT + Style.NORMAL,
    }

    TIME_COLOR = Fore.GREEN
    FILE_COLOR = Fore.BLUE
    LEVEL_COLOR = Style.BRIGHT

    def __init__(self, fmt=None):
        """Initialise the base formatter with a default layout and date format.

        Args:
            fmt (str, optional): Format string handed to ``logging.Formatter``.
                Note that ``format`` below assembles its own fixed layout, so
                this value does not change the rendered output.
        """
        super().__init__(fmt or "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", "%Y-%m-%d %H:%M:%S")

    def format(self, record):
        """Render a log record as a colored single-line message.

        The message text comes from ``record.getMessage()`` so that
        ``%``-style arguments (``logger.info("x=%s", x)``) are interpolated,
        and any exception or stack information attached to the record is
        appended on following lines, as ``logging.Formatter.format`` does.

        Args:
            record (logging.LogRecord): The log record to format.

        Returns:
            str: The formatted log message with colors applied.
        """
        level_color = self.LOG_COLORS.get(record.levelno, Fore.WHITE)

        # Pass the date format configured in __init__; without it formatTime
        # falls back to the logging default and ignores that setting.
        time_str = f"{self.TIME_COLOR}{self.formatTime(record, self.datefmt)}{Style.RESET_ALL}"
        levelname_str = f"{level_color}{record.levelname}{Style.RESET_ALL}"
        file_info_str = f"{self.FILE_COLOR}{record.filename}:{record.lineno}{Style.RESET_ALL}"

        log_msg = f"{time_str} - {levelname_str} - {file_info_str} - {record.getMessage()}"

        # Cache the formatted traceback on the record, as the base class does,
        # so that other handlers formatting the same record can reuse it.
        if record.exc_info and not record.exc_text:
            record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            log_msg = f"{log_msg}\n{record.exc_text}"
        if record.stack_info:
            log_msg = f"{log_msg}\n{self.formatStack(record.stack_info)}"

        return log_msg
|
|
|
def setup_logger():
    """Return this module's logger, configured to write colored output to stdout.

    The logger is looked up with ``logging.getLogger(__name__)``, given a
    ``StreamHandler`` on ``sys.stdout`` that uses ``ColorFormatter``, and has
    propagation to ancestor loggers disabled.

    The level is ``DEBUG`` when the ``DEBUG`` environment variable is one of
    ``true``, ``1`` or ``t`` (case-insensitive) and ``INFO`` otherwise.

    Calling this function more than once is safe: the handler is attached
    only if the logger does not already have one using ``ColorFormatter``, so
    repeated calls do not produce duplicated log lines.

    Returns:
        logging.Logger: The configured logger instance.
    """
    log = logging.getLogger(__name__)

    debug_enabled = os.environ.get("DEBUG", "False").lower() in ("true", "1", "t")
    log.setLevel(logging.DEBUG if debug_enabled else logging.INFO)
    log.propagate = False

    # getLogger returns the same object for the same name on every call, so
    # attach the stdout handler only the first time.
    already_configured = any(
        isinstance(existing.formatter, ColorFormatter) for existing in log.handlers
    )
    if not already_configured:
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(ColorFormatter())
        log.addHandler(handler)

    return log
|
|
|
# Module-level logger, created once at import time via setup_logger().
logger = setup_logger()
|
|