test_valid_scaled_grouped_mm_2d_3d was passing when this code was written. However, after we hit some cryptic cuBLAS errors that were not easy to resolve, we marked the test as skipped. At the time we were sprinting on MXFP8, so a deep dive was not a priority. Since then, the cuBLAS error has resolved itself, but we now see numerical mismatches in the output. The forward pass outputs match exactly under torch.equal. The gradients do not: most columns match exactly, but some columns only pass with atol=1 and rtol=1.