# Page-status text captured together with this file (not part of the program):
#   Spaces: Sleeping
#   Sleeping
import numpy as np | |
from pathlib import Path | |
from scipy.special import gamma | |
from typing import Optional, Tuple, Dict, List, Union | |
import torch | |
import os | |
class GeneralizedGaussianMixture:
    r"""Dataset generator for a mixture of generalized Gaussian distributions.

    Every component k is a product of D independent one-dimensional
    generalized Gaussian densities with centre c_k, scale alpha_k and a
    shape exponent p shared by all components:

        P_{\theta_k}(x_i) = \eta_k \exp(-s_k d_k(x_i))
                          = \frac{p}{2\alpha_k \Gamma(1/p)}
                            \exp(-|\frac{x_i-c_k}{\alpha_k}|^p)

    With this parameterisation p = 2 gives a Gaussian shape and p = 1 a
    Laplace shape.
    """

    def __init__(self,
                 D: int = 2,
                 K: int = 3,
                 p: float = 2.0,
                 centers: Optional[np.ndarray] = None,
                 scales: Optional[np.ndarray] = None,
                 weights: Optional[np.ndarray] = None,
                 seed: int = 42):
        """Initialise the mixture parameters.

        Args:
            D: Dimensionality of the data.
            K: Number of mixture components.
            p: Shape exponent of the distribution; must be positive.
            centers: Component centres, shape (K, D). Drawn at random when
                omitted.
            scales: Component scale parameters, shape (K, D). Drawn at random
                when omitted.
            weights: Mixture weights, shape (K,). Normalised to sum to one;
                drawn from a flat Dirichlet when omitted.
            seed: Random seed.

        Raises:
            ValueError: If ``p`` is not positive.
        """
        if p <= 0:
            raise ValueError(f"p must be positive, got {p}")

        self.D = D
        self.K = K
        self.p = p
        self.seed = seed
        # NOTE: this seeds NumPy's *global* generator, so constructing an
        # instance also affects every later np.random call in the process.
        np.random.seed(seed)

        # The draw order (centers, scales, weights) is kept fixed so a given
        # seed always yields the same default parameters.
        if centers is None:
            self.centers = np.random.randn(K, D) * 2
        else:
            self.centers = np.asarray(centers, dtype=float)

        if scales is None:
            self.scales = np.random.uniform(0.1, 0.5, size=(K, D))
        else:
            self.scales = np.asarray(scales, dtype=float)

        if weights is None:
            self.weights = np.random.dirichlet(np.ones(K))
        else:
            # Accept lists/tuples as well as arrays, then normalise to sum 1.
            raw_weights = np.asarray(weights, dtype=float)
            self.weights = raw_weights / raw_weights.sum()

    def component_pdf(self, x: np.ndarray, k: int) -> np.ndarray:
        """Evaluate the density of component ``k``.

        Args:
            x: Input points, shape (N, D).
            k: Component index.

        Returns:
            Density values, shape (N,).
        """
        # Per-dimension normalisation constant p / (2 * alpha * Gamma(1/p)).
        norm_const = self.p / (2 * self.scales[k] * gamma(1 / self.p))
        # |x_i - c_k| / alpha_k for every point and dimension.
        z = np.abs(x - self.centers[k]) / self.scales[k]
        exp_term = np.exp(-np.sum(z ** self.p, axis=1))
        # Dimensions are independent, so the joint constant is the product.
        return np.prod(norm_const) * exp_term

    def pdf(self, x: np.ndarray) -> np.ndarray:
        """Evaluate the mixture density.

        Args:
            x: Input points, shape (N, D).

        Returns:
            Density values, shape (N,).
        """
        density = np.zeros(len(x))
        for k in range(self.K):
            density += self.weights[k] * self.component_pdf(x, k)
        return density

    def generate_component_samples(self, n: int, k: int) -> np.ndarray:
        """Draw samples from component ``k``.

        For the density proportional to exp(-|z|^p) the quantity |z|^p is
        Gamma(1/p, 1) distributed, so each coordinate is obtained as a random
        sign times a Gamma draw raised to the power 1/p, then scaled and
        shifted by the component parameters.

        Args:
            n: Number of samples.
            k: Component index.

        Returns:
            Samples, shape (n, D).
        """
        shape = (n, self.D)
        inv_p = 1.0 / self.p
        magnitude = np.random.gamma(inv_p, 1.0, size=shape) ** inv_p
        sign = np.random.choice([-1.0, 1.0], size=shape)
        return self.centers[k] + self.scales[k] * sign * magnitude

    def generate_samples(self, N: int) -> Tuple[np.ndarray, np.ndarray]:
        """Draw samples from the mixture.

        Args:
            N: Total number of samples.

        Returns:
            X: Generated points, shape (N, D).
            y: Mixture density evaluated at each point, shape (N,).
        """
        # Split the N samples across components according to the weights.
        counts = np.random.multinomial(N, self.weights)
        per_component = [
            self.generate_component_samples(counts[k], k)
            for k in range(self.K)
        ]
        # Concatenate, then shuffle so the rows are not grouped by component.
        X = np.vstack(per_component)
        X = X[np.random.permutation(N)]
        return X, self.pdf(X)

    def save_dataset(self,
                     save_dir: Union[str, Path],
                     name: str = 'gmm_dataset',
                     N: int = 1000) -> None:
        """Generate a dataset and save it with the mixture parameters.

        The file ``<save_dir>/<name>.npz`` holds the arrays ``X`` and ``y``
        plus ``centers``, ``scales``, ``weights``, ``D``, ``K`` and ``p``.

        Args:
            save_dir: Target directory; created if it does not exist.
            name: Dataset name (file stem).
            N: Number of samples to generate.
        """
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)
        X, y = self.generate_samples(N=N)
        np.savez(str(save_path / f'{name}.npz'),
                 X=X, y=y,
                 centers=self.centers,
                 scales=self.scales,
                 weights=self.weights,
                 D=self.D,
                 K=self.K,
                 p=self.p)

    @classmethod
    def load_dataset(cls, file_path: Union[str, Path]) -> "GeneralizedGaussianMixture":
        """Rebuild a mixture object from a file written by ``save_dataset``.

        Only the distribution parameters are restored; the stored ``X`` and
        ``y`` arrays are not read here.

        Args:
            file_path: Path to the ``.npz`` file.

        Returns:
            A mixture object carrying the stored parameters.
        """
        # The context manager closes the underlying archive file handle.
        with np.load(str(file_path)) as data:
            return cls(
                D=int(data['D']),
                K=int(data['K']),
                p=float(data['p']),
                centers=data['centers'],
                scales=data['scales'],
                weights=data['weights'],
            )
def test_gmm_dataset():
    """Smoke-test the generator: sample, save, reload, compare parameters.

    Writes ``test_data/gmm_dataset.npz`` relative to the current working
    directory and prints a confirmation message on success.

    Raises:
        AssertionError: If the samples have an unexpected shape or the
            reloaded parameters differ from the original ones.
    """
    n_samples = 1000

    # A 2-D mixture with three well-separated components.
    gmm = GeneralizedGaussianMixture(
        D=2,
        K=3,
        p=2.0,
        centers=np.array([[-2, -2], [0, 0], [2, 2]]),
        scales=np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]]),
        weights=np.array([0.3, 0.4, 0.3]),
    )

    # Check the generated samples instead of discarding them.
    X, y = gmm.generate_samples(n_samples)
    assert X.shape == (n_samples, 2)
    assert y.shape == (n_samples,)
    assert np.all(np.isfinite(y)) and np.all(y >= 0)

    # Round-trip the parameters through a file on disk.
    gmm.save_dataset('test_data')
    loaded_gmm = GeneralizedGaussianMixture.load_dataset('test_data/gmm_dataset.npz')

    # The saved and reloaded parameters must agree.
    assert np.allclose(gmm.centers, loaded_gmm.centers)
    assert np.allclose(gmm.scales, loaded_gmm.scales)
    assert np.allclose(gmm.weights, loaded_gmm.weights)
    print("GMM数据集测试通过!")
# Run the smoke test only when this file is executed directly as a script.
if __name__ == '__main__':
    test_gmm_dataset()