Update build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags.py (#2)
Browse files- Update build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags.py (40b7a46bfb70bdc50bffbfcb008171dcfaaa58fc)
build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags.py
CHANGED
@@ -211,7 +211,8 @@ def make_default_opt_flags_nvidia(
|
|
211 |
ns = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, ep, epilogue_effective_itemsize)
|
212 |
if ns > num_stages:
|
213 |
epilogue_subtile, num_stages = ep, ns
|
214 |
-
|
|
|
215 |
if constraints.get("num_stages", None):
|
216 |
num_stages = constraints["num_stages"]
|
217 |
|
|
|
211 |
ns = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, ep, epilogue_effective_itemsize)
|
212 |
if ns > num_stages:
|
213 |
epilogue_subtile, num_stages = ep, ns
|
214 |
+
# removed due to https://huggingface.co/kernels-community/triton_kernels/discussions/1
|
215 |
+
# assert num_stages >= 1
|
216 |
if constraints.get("num_stages", None):
|
217 |
num_stages = constraints["num_stages"]
|
218 |
|