Skip to content

Commit 3b9cf4d

Browse files
Merge pull request #43 from UnravelSports/feat/unravel_graph
add epsilon for nans
2 parents 19a6ec0 + 2634a2e commit 3b9cf4d

1 file changed

Lines changed: 12 additions & 2 deletions

File tree

unravel/soccer/graphs/graph_converter.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,8 +481,18 @@ def __add_additional_kwargs(self, d):
481481
d["is_gk"] = np.where(
482482
d[Column.POSITION_NAME] == self.settings.goalkeeper_id, True, False
483483
)
484-
d["position"] = np.stack((d[Column.X], d[Column.Y], d[Column.Z]), axis=-1)
485-
d["velocity"] = np.stack((d[Column.VX], d[Column.VY], d[Column.VZ]), axis=-1)
484+
d["position"] = np.nan_to_num(
485+
np.stack((d[Column.X], d[Column.Y], d[Column.Z]), axis=-1),
486+
nan=1e-10,
487+
posinf=1e3,
488+
neginf=-1e3,
489+
)
490+
d["velocity"] = np.nan_to_num(
491+
np.stack((d[Column.VX], d[Column.VY], d[Column.VZ]), axis=-1),
492+
nan=1e-10,
493+
posinf=1e3,
494+
neginf=-1e3,
495+
)
486496

487497
if len(np.where(d["team_id"] == d["ball_id"])[0]) >= 1:
488498
ball_index = np.where(d["team_id"] == d["ball_id"])[0]

0 commit comments

Comments
 (0)