|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
from ..octree import DfsOctree as Octree
|
|
|
|
|
|
class Strivec(Octree):
    """Octree specialised to use the ``"trivec"`` primitive.

    The constructor converts a per-axis ``resolution`` into an octree
    ``depth`` (``log2(resolution)``) and forwards everything else to the
    base octree constructor.
    """

    def __init__(
        self,
        resolution: int,
        aabb: list,
        sh_degree: int = 0,
        rank: int = 8,
        dim: int = 8,
        device: str = "cuda",
    ):
        """Build the octree from a power-of-two resolution.

        Args:
            resolution: Grid resolution; must be a power of two. Stored on
                the instance as ``self.resolution``.
            aabb: Bounding box, passed through unchanged to the base class
                (presumably min/max corners -- confirm against the base
                octree's documentation).
            sh_degree: Passed through unchanged to the base class.
            rank: Stored under ``"rank"`` in the primitive configuration.
            dim: Stored under ``"dim"`` in the primitive configuration.
            device: Passed through unchanged to the base class.

        Raises:
            AssertionError: If ``resolution`` is not a power of two.
        """
        # A power of two has an integral base-2 logarithm.
        log2_resolution = np.log2(resolution)
        assert log2_resolution % 1 == 0, "Resolution must be a power of 2"

        # Recorded before the base constructor runs, as in the original
        # statement order.
        self.resolution = resolution

        # Settings handed to the "trivec" primitive via the base class.
        trivec_config = {"rank": rank, "dim": dim}
        tree_depth = int(np.round(log2_resolution))

        super().__init__(
            depth=tree_depth,
            aabb=aabb,
            sh_degree=sh_degree,
            primitive="trivec",
            primitive_config=trivec_config,
            device=device,
        )
|
|
|