Skip to content
Open
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
bb7d854
added BasketballDataset class
not-heavychevy Apr 10, 2025
2abeeff
added BasketballPitchDimensions class
not-heavychevy Apr 10, 2025
bd59522
added graph settings
not-heavychevy Apr 10, 2025
8a83938
added optimized graph converter
not-heavychevy Apr 10, 2025
f5071c6
added ball handling
not-heavychevy Apr 10, 2025
26d6d85
added init files
not-heavychevy Apr 10, 2025
f2d164b
bugfix dataset load() bug
not-heavychevy Apr 10, 2025
d86c0af
added tests
not-heavychevy Apr 10, 2025
d1c0c73
added additional fields computation
not-heavychevy Apr 10, 2025
64f5ee3
BasketballDataset inherits from DefaultDataset
not-heavychevy Apr 12, 2025
835cd59
bugfix
not-heavychevy Apr 12, 2025
98f09ae
files read with kloppy.io
not-heavychevy Apr 19, 2025
0502aa7
added norm parameters
not-heavychevy Apr 19, 2025
d2f6b52
refactor: move get_dataframe to DefaultDataset
not-heavychevy Apr 20, 2025
53ea444
created post_init
not-heavychevy Apr 20, 2025
3482bf9
added self.settings to BasketballDataset
not-heavychevy Apr 20, 2025
51a6657
added add_dummy_labels и add_graph_ids
not-heavychevy Apr 20, 2025
1352f80
rewritten tests for dataset.py
not-heavychevy Apr 21, 2025
b0fc5c1
Refactor BasketballPitchDimensions
not-heavychevy Apr 25, 2025
1e04bfd
added tests for BasketballPitchDimensions
not-heavychevy Apr 25, 2025
627fae8
Refactor BasketballGraphSettings
not-heavychevy Apr 25, 2025
1bdd740
added tests for BasketballGraphSettings
not-heavychevy Apr 25, 2025
7c64156
Merge PitchDimensions and GraphSettings
not-heavychevy Apr 25, 2025
a70739c
graph_settings test update
not-heavychevy Apr 25, 2025
ebe0914
import bugs fix
not-heavychevy Apr 25, 2025
2dcd3fb
graph_converter refactoring
not-heavychevy Apr 26, 2025
4b96024
dataset separator bugfix
not-heavychevy Apr 26, 2025
af3a02a
added tests for graph_converter
not-heavychevy Apr 26, 2025
8a47337
moved the functionality to “features”
not-heavychevy Apr 26, 2025
633afca
tests update
not-heavychevy Apr 26, 2025
7463b1e
tests fix
not-heavychevy Apr 26, 2025
dcfa8e4
Deprecate speed/acceleration thresholds
not-heavychevy Apr 26, 2025
1b5bc3b
unify data/settings access on DefaultDataset
not-heavychevy Apr 26, 2025
7eb2081
Refactor _convert to use polars methods
not-heavychevy Apr 26, 2025
b0b9d72
Add unified graph-export API to GraphConverter
not-heavychevy Apr 26, 2025
e55d30e
added new tests for public export API
not-heavychevy Apr 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
files read with kloppy.io
  • Loading branch information
not-heavychevy committed Apr 19, 2025
commit 98f09ae813c47f542f57eee0bd9eb9acf75b161a
72 changes: 48 additions & 24 deletions unravel/basketball/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import requests
import numpy as np

from kloppy.io import open_as_file

try:
import py7zr
except ImportError:
Expand All @@ -22,8 +24,23 @@ class BasketballDataset(DefaultDataset):
Modes:
- URL: Loads from a 7zip archive (expects a JSON file inside).
- Local: Loads from a file path or game identifier.

Additional parameters:
- max_player_speed, max_ball_speed, max_player_acceleration, max_ball_acceleration:
Threshold values for normalizing player and ball speeds/accelerations.
- orient_ball_owning:
Flag indicating whether to compute oriented direction for ball ownership.
- sample_rate:
Fraction of data to sample (e.g., 0.5 to keep half of the rows).

"""
tracking_data: str
max_player_speed: float = 20.0
max_ball_speed: float = 30.0
max_player_acceleration: float = 10.0
max_ball_acceleration: float = 10.0
orient_ball_owning: bool = False
sample_rate: float = 1.0
data: Optional[pl.DataFrame] = field(default=None, init=False)

def load(self) -> pl.DataFrame:
Expand All @@ -33,38 +50,37 @@ def load(self) -> pl.DataFrame:
"""
# Load via URL if tracking_data starts with "http"
if self.tracking_data.startswith("http"):
with open_as_file(self.tracking_data) as tmp_file:
tmp_filename = tmp_file.name
json_file = None
if py7zr is None:
raise ImportError("py7zr is required to extract 7zip archives.")
response = requests.get(self.tracking_data)
if response.status_code != 200:
raise Exception("Failed to download data from URL.")
with tempfile.NamedTemporaryFile(delete=False, suffix=".7z") as tmp_file:
tmp_file.write(response.content)
tmp_filename = tmp_file.name
with py7zr.SevenZipFile(tmp_filename, mode='r') as archive:
extract_path = tempfile.mkdtemp()
archive.extractall(path=extract_path)
for fname in archive.getnames():
if fname.endswith('.json'):
json_file = os.path.join(tempfile.mkdtemp(), fname)
with open(json_file, 'wb') as f:
file_dict = archive.read(fname)
file_bytes = file_dict[fname].read()
f.write(file_bytes)
break
os.unlink(tmp_filename)
json_file = next(
(os.path.join(extract_path, fname) for fname in os.listdir(extract_path) if fname.endswith('.json')),
None
)
if json_file is None:
raise FileNotFoundError("JSON file not found in extracted archive.")
with open(json_file, 'r', encoding='utf-8') as jf:
json_data = json.load(jf)
else:
# Load from file if a valid file path is provided
if os.path.isfile(self.tracking_data):
with open(self.tracking_data, 'r', encoding='utf-8') as jf:
json_data = json.load(jf)
file_path = self.tracking_data
else:
# Search for a file in the default directory using the game identifier
file_path = os.path.join("data", "nba", f"{self.tracking_data}.json")
if not os.path.isfile(file_path):
raise FileNotFoundError(f"Game file '{self.tracking_data}.json' not found at: {file_path}")
with open(file_path, 'r', encoding='utf-8') as jf:
json_data = json.load(jf)

with open_as_file(file_path) as f:
json_data = json.load(f)

rows = []
# Process JSON as a dictionary
Expand Down Expand Up @@ -107,6 +123,9 @@ def load(self) -> pl.DataFrame:
raise ValueError("Unexpected JSON structure")

self.data = pl.DataFrame(rows, strict=False)
if self.sample_rate < 1.0:
self.data = self.data.sample(fraction=self.sample_rate, with_replacement=False)

return self.data

def compute_additional_fields(self) -> pl.DataFrame:
Expand All @@ -115,9 +134,11 @@ def compute_additional_fields(self) -> pl.DataFrame:
- vx, vy: velocity components,
- speed: magnitude of the velocity,
- direction: movement direction in radians,
- acceleration: change in speed over time.
- acceleration: change in speed over time,
- normalized_speed: speed normalized by max_player_speed or max_ball_speed,
- normalized_acceleration: acceleration normalized by max_player_acceleration or max_ball_acceleration,
- oriented_direction (if orient_ball_owning is True): a placeholder for ball-owning orientation.

Calculations are performed for each group defined by game_id and player.
"""
if self.data is None:
raise ValueError("Data not loaded. Call load() first.")
Expand All @@ -137,19 +158,22 @@ def compute_additional_fields(self) -> pl.DataFrame:
(pl.col("x") - pl.col("x").shift(1)).alias("dx"),
(pl.col("y") - pl.col("y").shift(1)).alias("dy")
])

df = df.with_columns([
(pl.col("dx") / pl.col("dt")).alias("vx"),
(pl.col("dy") / pl.col("dt")).alias("vy")
])

df = df.with_columns([
((pl.col("vx") ** 2 + pl.col("vy") ** 2) ** 0.5).alias("speed")
])
df = df.with_columns([
pl.struct(["vx", "vy"]).apply(
lambda row: float(np.arctan2(row["vy"], row["vx"]))
if (row["vx"] is not None and row["vy"] is not None) else None
).alias("direction")
])

df = df.with_columns(
pl.concat_list([pl.col("vx"), pl.col("vy")])
.map_elements(lambda row: float(np.arctan2(row[1], row[0])) if row[0] is not None and row[1] is not None else None)
.alias("direction")
)

if "game_clock" in df.columns:
df = df.with_columns([
(pl.col("game_clock") - pl.col("game_clock").shift(1)).abs().alias("dt_acc")
Expand Down