Skip to content

Commit cc671cb

Browse files
gmagogsfm and claude authored
[Kernel] [Helion] [17/N] Add Helion kernel torch.compile support (vllm-project#38592)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com> Co-authored-by: Claude Sonnet 4 <noreply@anthropic.com>
1 parent 856589e commit cc671cb

2 files changed

Lines changed: 98 additions & 78 deletions

File tree

tests/kernels/helion/test_register.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@
3535
validate_helion_settings,
3636
)
3737

38+
if _HOP_AVAILABLE:
39+
from helion._compiler._dynamo.higher_order_ops import (
40+
helion_kernel_wrapper_mutation,
41+
)
42+
3843

3944
def _add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
4045
out = torch.empty_like(x)
@@ -941,3 +946,60 @@ def fake_impl(*args, **kwargs):
941946
registered = get_registered_kernels()
942947
assert "disabled_kernel" in registered
943948
assert registered["disabled_kernel"] is wrapper
949+
950+
951+
@pytest.mark.skipif(not _HOP_AVAILABLE, reason="Requires PyTorch >= 2.11 for HOP")
class TestTorchCompileHOP:
    """Checks on the FX graph produced when torch.compile traces a
    HelionKernelWrapper."""

    def test_compiled_graph_contains_helion_hop(self):
        """Compile a function that calls a registered wrapper, then check that
        the captured graph has a helion_kernel_wrapper_mutation call_function
        node and that the compiled output agrees with the eager output."""
        kernel_configs = {"default": helion.Config(block_sizes=[4, 4])}

        with dummy_kernel_registry(configs=kernel_configs) as register:
            decorate = register(
                op_name="test_torch_compile_add_kernel",
                config_picker=lambda args, keys: "default",
            )
            add_helion_kernel = decorate(_add_kernel)

            # Holds at most one GraphModule: the one handed to the backend.
            seen_graphs: list[torch.fx.GraphModule] = []

            def capturing_backend(gm, example_inputs):
                # A second invocation would mean the function was split into
                # more than one graph or was recompiled.
                assert not seen_graphs, "Backend called multiple times"
                seen_graphs.append(gm)
                return gm.forward

            def f(x, y):
                return add_helion_kernel(x, y)

            torch._dynamo.reset()
            compiled_f = torch.compile(f, backend=capturing_backend, fullgraph=True)

            x = torch.randn(4, 4, device="cuda")
            y = torch.randn(4, 4, device="cuda")

            # The first call triggers tracing, which invokes the backend.
            compiled_result = compiled_f(x, y)

            assert seen_graphs
            captured_graph = seen_graphs[0]

            hop_nodes = []
            for node in captured_graph.graph.nodes:
                if node.op != "call_function":
                    continue
                if node.target is helion_kernel_wrapper_mutation:
                    hop_nodes.append(node)
            assert hop_nodes, (
                "Expected helion_kernel_wrapper_mutation HOP node in compiled graph, "
                f"but found none. Graph nodes: "
                f"{[(n.op, n.target) for n in captured_graph.graph.nodes]}"
            )

            # Calling the undecorated Python function runs the eager path.
            eager_result = f(x, y)

            results_match = torch.allclose(
                compiled_result, eager_result, atol=1e-5, rtol=1e-5
            )
            assert results_match, (
                "Compiled execution result doesn't match eager execution. "
                f"Max difference: {torch.max(torch.abs(compiled_result - eager_result))}"
            )

vllm/kernels/helion/register.py

Lines changed: 36 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,11 @@
6363
_HOP_AVAILABLE = requires_torch_version("2.11")
6464

6565
if _HOP_AVAILABLE:
66-
import torch.utils._pytree as pytree
67-
from helion._compiler._dynamo.higher_order_ops import (
68-
helion_kernel_side_table,
69-
helion_kernel_wrapper_mutation,
70-
)
71-
from helion._compiler._dynamo.variables import infer_output_spec
72-
from torch.fx.experimental.proxy_tensor import (
73-
disable_proxy_modes_tracing,
74-
get_proxy_mode,
75-
)
66+
from helion._compiler._dynamo.higher_order_ops import helion_kernel_side_table
67+
from helion._compiler._dynamo.variables import HelionKernelVariable
68+
from torch._dynamo.guards import GuardBuilder
69+
from torch._dynamo.variables.builder import VariableBuilder
70+
7671

7772
logger = init_logger(__name__)
7873

@@ -298,75 +293,11 @@ def __call__(self, *args, **kwargs):
298293
f"Kernel '{self.op_name}' was not initialized. "
299294
"Please open an issue on GitHub."
300295
)
301-
if get_proxy_mode() is not None:
302-
return self._call_via_hop(args, kwargs)
303-
return self._configured_kernel(*args, **kwargs)
304-
305-
def _call_via_hop(
306-
self,
307-
args: tuple[Any, ...],
308-
kwargs: dict[str, Any],
309-
) -> Any:
310-
kernel = self.get_configured_op()._decorated_kernel
311-
kernel_idx = helion_kernel_side_table.add_kernel(kernel)
312-
313-
constant_args, tensor_args = self._partition_args(kernel, args, kwargs)
314-
315-
all_named = {**constant_args, **tensor_args}
316-
full_args = tuple(
317-
all_named.get(n, p.default)
318-
for n, p in kernel.signature.parameters.items() # type: ignore[attr-defined]
319-
if n in all_named or p.default is not p.empty
320-
)
321-
322-
with disable_proxy_modes_tracing():
323-
output_spec = infer_output_spec(kernel, full_args)
324-
325-
hop_result = helion_kernel_wrapper_mutation(
326-
kernel_idx=kernel_idx,
327-
constant_args=constant_args,
328-
tensor_args=tensor_args,
329-
output_spec=output_spec,
330-
)
331-
332-
tree_spec_str = output_spec.get("tree_spec_str")
333-
if tree_spec_str is None:
334-
return None
335-
tree_spec = pytree.treespec_loads(tree_spec_str)
336-
337-
hop_iter = iter(hop_result)
338-
reconstructed = []
339-
for spec in output_spec["leaf_specs"]:
340-
is_constant_scalar = spec["type"] == "scalar" and not isinstance(
341-
spec.get("scalar_value"), torch.SymInt
342-
)
343-
if is_constant_scalar:
344-
reconstructed.append(spec["scalar_value"])
345-
else:
346-
reconstructed.append(next(hop_iter))
347-
return pytree.tree_unflatten(reconstructed, tree_spec)
348296

349-
@staticmethod
350-
def _partition_args(
351-
kernel: Any,
352-
args: tuple[Any, ...],
353-
kwargs: dict[str, Any],
354-
) -> tuple[dict[str, Any], dict[str, Any]]:
355-
constant_args: dict[str, Any] = {}
356-
tensor_args: dict[str, Any] = {}
357-
params = list(kernel.signature.parameters.keys())
358-
for i, val in enumerate(args):
359-
name = params[i]
360-
if isinstance(val, torch.Tensor):
361-
tensor_args[name] = val
362-
else:
363-
constant_args[name] = val
364-
for name, val in kwargs.items():
365-
if isinstance(val, torch.Tensor):
366-
tensor_args[name] = val
367-
else:
368-
constant_args[name] = val
369-
return constant_args, tensor_args
297+
# During Dynamo tracing, this call will be intercepted by our custom
298+
# VariableBuilder hook, which returns Helion's HelionKernelVariable for HOP emission.
299+
# During eager execution, call the kernel directly.
300+
return self._configured_kernel(*args, **kwargs)
370301

371302
def get_inputs(self) -> dict[str, tuple[Any, ...]]:
372303
if self._input_generator is None:
@@ -535,3 +466,30 @@ def decorator(kernel_func: Callable) -> HelionKernelWrapper:
535466
return kernel_wrapper
536467

537468
return decorator
469+
470+
471+
# Register HelionKernelWrapper with Dynamo's variable tracker system
472+
if _HOP_AVAILABLE:

    def _register_vllm_helion_dynamo_variable():
        """Install a HelionKernelWrapper handler in Dynamo's VariableBuilder.

        The handler is stored in the table returned by
        ``VariableBuilder._type_dispatch()`` under the ``HelionKernelWrapper``
        key. When invoked, it takes the decorated Helion kernel from the
        wrapper's configured op, adds it to ``helion_kernel_side_table``,
        installs an ``ID_MATCH`` guard through the builder, and returns a
        ``HelionKernelVariable`` built from the kernel, its side-table index
        and the builder's source.
        """

        def wrap_helion_kernel_wrapper(
            builder: VariableBuilder, value: HelionKernelWrapper
        ):
            helion_kernel = value.get_configured_op()._decorated_kernel
            side_table_idx = helion_kernel_side_table.add_kernel(helion_kernel)
            builder.install_guards(GuardBuilder.ID_MATCH)
            return HelionKernelVariable(
                helion_kernel, side_table_idx, source=builder.source
            )

        # Hook the handler into the per-type dispatch table.
        VariableBuilder._type_dispatch()[HelionKernelWrapper] = (
            wrap_helion_kernel_wrapper
        )

    # Registration happens as a side effect of importing this module.
    _register_vllm_helion_dynamo_variable()

0 commit comments

Comments
 (0)