File size: 3,696 Bytes
5a1cdf2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import math
import numpy as np
import torch
import logging
import os
import sys
from colorama import Fore, Style, init
from dotenv import load_dotenv
# Load variables from a .env file into the process environment (python-dotenv);
# setup_logger() below reads its DEBUG flag from os.environ.
load_dotenv()
# Initialise colorama; autoreset=True asks colorama to reset styling after each write.
init(autoreset=True)
def nearest_power_of_two(x: int, round_up: bool = False) -> int:
    """
    Return the power of two nearest to ``x`` in the requested direction.

    Args:
        x (int): Positive value to round. Integer-like inputs (anything supporting
            ``__index__``, including Python ints and NumPy integers) are handled
            exactly with integer bit arithmetic. Other numbers (e.g. floats) fall
            back to a ``math.log2``-based computation.
        round_up (bool): If False (default), return the largest power of two
            ``<= x``. If True, return the smallest power of two ``>= x``.

    Returns:
        int: The selected power of two.

    Raises:
        ValueError: If an integer ``x`` is not positive, or, on the float path, if
            ``math.log2`` rejects ``x`` or the resulting exponent is negative.
    """
    import operator

    try:
        n = operator.index(x)
    except TypeError:
        # Non-integral input (e.g. a float): keep the log2-based behaviour.
        exponent = math.log2(x)
        return 1 << (math.ceil(exponent) if round_up else math.floor(exponent))

    if n <= 0:
        raise ValueError(f"nearest_power_of_two requires a positive integer, got {x!r}")
    # bit_length() is exact for arbitrarily large ints. math.log2 rounds through a
    # double and misplaces values such as 2**60 - 1 or 2**60 + 1.
    if round_up:
        return 1 << (n - 1).bit_length()
    return 1 << (n.bit_length() - 1)
def get_hankel(seq_len: int, use_hankel_L: bool = False, device: torch.device = None, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """
    Build a ``seq_len x seq_len`` Hankel matrix whose entries depend only on ``i + j``.

    With 1-based indices ``i, j`` in ``[1, seq_len]`` and ``s = i + j``:

    * ``use_hankel_L`` falsy (default): ``Z[i, j] = 2 / (s**3 - s)``.
    * ``use_hankel_L`` truthy:
      ``Z[i, j] = ((-1)**(s - 2) + 1) * 8 / ((s + 3) * (s - 1) * (s + 1))``.
      This equals ``16 / ((s + 3)(s - 1)(s + 1))`` for even ``s`` and ``0`` for odd ``s``.

    Args:
        seq_len (int): Number of rows/columns of the result.
        use_hankel_L (bool): Selects between the two variants above. The value is
            evaluated by truthiness, so e.g. ``1``/``0`` behave like ``True``/``False``.
        device (torch.device, optional): Device on which the tensor is created.
        dtype (torch.dtype): dtype of the index grid the entries are computed from
            (default ``torch.float32``).

    Returns:
        torch.Tensor: Tensor of shape ``(seq_len, seq_len)``.
    """
    entries = torch.arange(1, seq_len + 1, dtype=dtype, device=device)
    # Broadcast a column against a row: i_plus_j[a, b] = (a + 1) + (b + 1).
    i_plus_j = entries[:, None] + entries[None, :]
    if use_hankel_L:
        # (-1)**(s - 2) + 1 is 2 when s is even and 0 when s is odd.
        sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
        denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
        return sgn * (8.0 / denom)
    # s >= 2, so s**3 - s = s(s-1)(s+1) >= 6 and the denominator is never zero.
    return 2.0 / (i_plus_j**3 - i_plus_j)
class ColorFormatter(logging.Formatter):
    """
    A custom log formatter that applies color based on the log level using the Colorama library.

    Each record is rendered as ``<time> - <LEVEL> - <file>:<line> - <message>``. The time,
    level name and file location are wrapped in Colorama color codes. Any traceback
    (``exc_info``) or stack info is appended on the following lines.

    Attributes:
        LOG_COLORS (dict): Maps log levels to the color codes used for the level name.
            Levels not listed fall back to ``Fore.WHITE``.
        TIME_COLOR (str): Color code for the timestamp.
        FILE_COLOR (str): Color code for the ``filename:lineno`` part.
        LEVEL_COLOR (str): Defined for customization; not currently used by ``format``.
    """

    # Colors for each log level
    LOG_COLORS = {
        logging.DEBUG: Fore.LIGHTMAGENTA_EX + Style.BRIGHT,
        logging.INFO: Fore.CYAN,
        logging.WARNING: Fore.YELLOW + Style.BRIGHT,
        logging.ERROR: Fore.RED + Style.BRIGHT,
        # NOTE(review): Style.NORMAL after Style.BRIGHT resets the intensity, so CRITICAL
        # renders as plain red (less emphasized than ERROR). Confirm this is intended.
        logging.CRITICAL: Fore.RED + Style.BRIGHT + Style.NORMAL,
    }
    # Colors for other parts of the log message
    TIME_COLOR = Fore.GREEN
    FILE_COLOR = Fore.BLUE
    LEVEL_COLOR = Style.BRIGHT

    def __init__(self, fmt=None):
        """
        Args:
            fmt (str, optional): Format string handed to ``logging.Formatter``.
                NOTE: ``format`` builds its own layout, so this value is stored but does
                not affect the output. The date format is fixed to ``%Y-%m-%d %H:%M:%S``.
        """
        super().__init__(fmt or "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", "%Y-%m-%d %H:%M:%S")

    def format(self, record):
        """
        Formats a log record with the appropriate color based on the log level.

        Args:
            record (logging.LogRecord): The log record to format.

        Returns:
            str: The formatted log message with colors applied, followed by any
            traceback / stack information attached to the record.
        """
        # Apply color based on the log level
        level_color = self.LOG_COLORS.get(record.levelno, Fore.WHITE)
        # Interpolate %-style args (logger.info("x=%s", x)) as the stdlib Formatter does.
        # Reading record.msg directly would print the raw template and drop the args.
        record.message = record.getMessage()
        # Pass datefmt explicitly: formatTime() ignores the one given to __init__ otherwise.
        time_str = f"{self.TIME_COLOR}{self.formatTime(record, self.datefmt)}{Style.RESET_ALL}"
        levelname_str = f"{level_color}{record.levelname}{Style.RESET_ALL}"
        file_info_str = f"{self.FILE_COLOR}{record.filename}:{record.lineno}{Style.RESET_ALL}"
        # Format the log message with color
        log_msg = f"{time_str} - {levelname_str} - {file_info_str} - {record.message}"

        # Append traceback / stack info the same way logging.Formatter.format does, so
        # logger.exception(...) and stack_info=True are not silently discarded.
        if record.exc_info and not record.exc_text:
            record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            if not log_msg.endswith("\n"):
                log_msg += "\n"
            log_msg += record.exc_text
        if record.stack_info:
            if not log_msg.endswith("\n"):
                log_msg += "\n"
            log_msg += self.formatStack(record.stack_info)
        return log_msg
def setup_logger():
    """
    Set up and return this module's logger with colored output to standard output (stdout).

    The logger comes from ``logging.getLogger(__name__)``. It gets a ``StreamHandler`` on
    ``sys.stdout`` that formats records with ``ColorFormatter``. The level is DEBUG when
    the ``DEBUG`` environment variable is ``"true"``, ``"1"`` or ``"t"``
    (case-insensitive), and INFO otherwise.

    Calling this function more than once is safe. The colored stdout handler is attached
    only the first time, so messages are not emitted twice. The level is re-evaluated on
    every call.

    Returns:
        logging.Logger: A logger instance that logs formatted messages to stdout.
    """
    handler_name = "color_stdout"
    logger = logging.getLogger(__name__)

    # DEBUG=true/1/t (any case) enables debug output; anything else keeps INFO.
    debug = os.environ.get("DEBUG", "False").lower() in ("true", "1", "t")
    logger.setLevel(logging.DEBUG if debug else logging.INFO)

    # getLogger(__name__) returns the same object on every call, so only attach our
    # handler once. Stacking a second one would print every message twice.
    if not any(h.get_name() == handler_name for h in logger.handlers):
        handler = logging.StreamHandler(sys.stdout)
        handler.set_name(handler_name)
        handler.setFormatter(ColorFormatter())
        logger.addHandler(handler)

    # Don't also hand records to ancestor (e.g. root) handlers, which would duplicate output.
    logger.propagate = False
    return logger
# Module-wide logger, configured once at import time; import and use `logger`
# rather than calling setup_logger() again.
logger: logging.Logger = setup_logger()
|