|
63 | 63 | _HOP_AVAILABLE = requires_torch_version("2.11") |
64 | 64 |
|
65 | 65 | if _HOP_AVAILABLE: |
66 | | - import torch.utils._pytree as pytree |
67 | | - from helion._compiler._dynamo.higher_order_ops import ( |
68 | | - helion_kernel_side_table, |
69 | | - helion_kernel_wrapper_mutation, |
70 | | - ) |
71 | | - from helion._compiler._dynamo.variables import infer_output_spec |
72 | | - from torch.fx.experimental.proxy_tensor import ( |
73 | | - disable_proxy_modes_tracing, |
74 | | - get_proxy_mode, |
75 | | - ) |
| 66 | + from helion._compiler._dynamo.higher_order_ops import helion_kernel_side_table |
| 67 | + from helion._compiler._dynamo.variables import HelionKernelVariable |
| 68 | + from torch._dynamo.guards import GuardBuilder |
| 69 | + from torch._dynamo.variables.builder import VariableBuilder |
| 70 | + |
76 | 71 |
|
77 | 72 | logger = init_logger(__name__) |
78 | 73 |
|
@@ -298,75 +293,11 @@ def __call__(self, *args, **kwargs): |
298 | 293 | f"Kernel '{self.op_name}' was not initialized. " |
299 | 294 | "Please open an issue on GitHub." |
300 | 295 | ) |
301 | | - if get_proxy_mode() is not None: |
302 | | - return self._call_via_hop(args, kwargs) |
303 | | - return self._configured_kernel(*args, **kwargs) |
304 | | - |
305 | | - def _call_via_hop( |
306 | | - self, |
307 | | - args: tuple[Any, ...], |
308 | | - kwargs: dict[str, Any], |
309 | | - ) -> Any: |
310 | | - kernel = self.get_configured_op()._decorated_kernel |
311 | | - kernel_idx = helion_kernel_side_table.add_kernel(kernel) |
312 | | - |
313 | | - constant_args, tensor_args = self._partition_args(kernel, args, kwargs) |
314 | | - |
315 | | - all_named = {**constant_args, **tensor_args} |
316 | | - full_args = tuple( |
317 | | - all_named.get(n, p.default) |
318 | | - for n, p in kernel.signature.parameters.items() # type: ignore[attr-defined] |
319 | | - if n in all_named or p.default is not p.empty |
320 | | - ) |
321 | | - |
322 | | - with disable_proxy_modes_tracing(): |
323 | | - output_spec = infer_output_spec(kernel, full_args) |
324 | | - |
325 | | - hop_result = helion_kernel_wrapper_mutation( |
326 | | - kernel_idx=kernel_idx, |
327 | | - constant_args=constant_args, |
328 | | - tensor_args=tensor_args, |
329 | | - output_spec=output_spec, |
330 | | - ) |
331 | | - |
332 | | - tree_spec_str = output_spec.get("tree_spec_str") |
333 | | - if tree_spec_str is None: |
334 | | - return None |
335 | | - tree_spec = pytree.treespec_loads(tree_spec_str) |
336 | | - |
337 | | - hop_iter = iter(hop_result) |
338 | | - reconstructed = [] |
339 | | - for spec in output_spec["leaf_specs"]: |
340 | | - is_constant_scalar = spec["type"] == "scalar" and not isinstance( |
341 | | - spec.get("scalar_value"), torch.SymInt |
342 | | - ) |
343 | | - if is_constant_scalar: |
344 | | - reconstructed.append(spec["scalar_value"]) |
345 | | - else: |
346 | | - reconstructed.append(next(hop_iter)) |
347 | | - return pytree.tree_unflatten(reconstructed, tree_spec) |
348 | 296 |
|
349 | | - @staticmethod |
350 | | - def _partition_args( |
351 | | - kernel: Any, |
352 | | - args: tuple[Any, ...], |
353 | | - kwargs: dict[str, Any], |
354 | | - ) -> tuple[dict[str, Any], dict[str, Any]]: |
355 | | - constant_args: dict[str, Any] = {} |
356 | | - tensor_args: dict[str, Any] = {} |
357 | | - params = list(kernel.signature.parameters.keys()) |
358 | | - for i, val in enumerate(args): |
359 | | - name = params[i] |
360 | | - if isinstance(val, torch.Tensor): |
361 | | - tensor_args[name] = val |
362 | | - else: |
363 | | - constant_args[name] = val |
364 | | - for name, val in kwargs.items(): |
365 | | - if isinstance(val, torch.Tensor): |
366 | | - tensor_args[name] = val |
367 | | - else: |
368 | | - constant_args[name] = val |
369 | | - return constant_args, tensor_args |
| 297 | + # During Dynamo tracing, this wrapper is intercepted by the VariableBuilder |
| 298 | + # hook registered below, which returns Helion's HelionKernelVariable for HOP emission. |
| 299 | + # During eager execution, call the kernel directly. |
| 300 | + return self._configured_kernel(*args, **kwargs) |
370 | 301 |
|
371 | 302 | def get_inputs(self) -> dict[str, tuple[Any, ...]]: |
372 | 303 | if self._input_generator is None: |
@@ -535,3 +466,30 @@ def decorator(kernel_func: Callable) -> HelionKernelWrapper: |
535 | 466 | return kernel_wrapper |
536 | 467 |
|
537 | 468 | return decorator |
| 469 | + |
| 470 | + |
| 471 | +# Register HelionKernelWrapper with Dynamo's variable tracker system |
| 472 | +if _HOP_AVAILABLE: |
| 473 | + |
| 474 | + def _register_vllm_helion_dynamo_variable(): |
| 475 | + """Register HelionKernelWrapper with Dynamo's VariableBuilder. |
| 476 | +
|
| 477 | + When Dynamo encounters a HelionKernelWrapper during tracing, this |
| 478 | + extracts the underlying Helion Kernel, registers it in the side table, |
| 479 | + and returns Helion's own HelionKernelVariable to handle HOP emission. |
| 480 | + """ |
| 481 | + |
| 482 | + def wrap_helion_kernel_wrapper( |
| 483 | + builder: VariableBuilder, value: HelionKernelWrapper |
| 484 | + ): |
| 485 | + kernel = value.get_configured_op()._decorated_kernel |
| 486 | + kernel_idx = helion_kernel_side_table.add_kernel(kernel) |
| 487 | + builder.install_guards(GuardBuilder.ID_MATCH) |
| 488 | + return HelionKernelVariable(kernel, kernel_idx, source=builder.source) |
| 489 | + |
| 490 | + # Register with Dynamo's type dispatch system |
| 491 | + dispatch = VariableBuilder._type_dispatch() |
| 492 | + dispatch[HelionKernelWrapper] = wrap_helion_kernel_wrapper |
| 493 | + |
| 494 | + # Register immediately when the module is imported |
| 495 | + _register_vllm_helion_dynamo_variable() |
0 commit comments