Skip to content

Commit 993859c

Browse files
[XPU] fix all_reduce all-zero accuracy issue under torch.compile (vllm-project#39844)
Signed-off-by: Chaojun Zhang <chaojun.zhang@intel.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 48a65cc commit 993859c

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

vllm/distributed/device_communicators/xpu_communicator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ def __init__(
4747
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
4848
logger.info("Using AgRs manager on XPU device.")
4949

50-
def all_reduce(self, input_) -> torch.Tensor:
51-
dist.all_reduce(input_, group=self.device_group)
52-
return input_
50+
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
51+
output = input_.clone() if torch.compiler.is_compiling() else input_
52+
dist.all_reduce(output, group=self.device_group)
53+
return output
5354

5455
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
5556
world_size = self.world_size

0 commit comments

Comments
 (0)