Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
33fdd4f
Add feature specs definition to graph converter
MihirT906 Feb 19, 2025
4598ef5
Add error handling for feature_specs
MihirT906 Feb 19, 2025
60e5255
Bug fix
MihirT906 Feb 19, 2025
82c87eb
Add default structure
MihirT906 Feb 19, 2025
71b53d0
add normalized node features
MihirT906 Feb 20, 2025
838687b
Complete node features implementation
MihirT906 Feb 20, 2025
79d96e7
Remove redundant code
MihirT906 Feb 20, 2025
9a27cb1
Add flexible edge features
MihirT906 Feb 21, 2025
e2f306e
Add comments
MihirT906 Feb 21, 2025
04eea9b
Reformatted with black
MihirT906 Feb 25, 2025
264c04e
Fix comments
MihirT906 Feb 25, 2025
c22d151
black reformatting
MihirT906 Feb 25, 2025
c387d54
Modify initialisation of feature spec to take None
MihirT906 Feb 25, 2025
e0bf15c
Add error handling at initialisation for node features
MihirT906 Mar 11, 2025
08e9c2a
Modified node features to use default feature function map
MihirT906 Mar 16, 2025
9ac758c
Add edge feature error handling
MihirT906 Mar 16, 2025
2266156
Handle edge cases
MihirT906 Mar 17, 2025
8115e38
Add tests for flexible implementation
MihirT906 Mar 17, 2025
40f330d
Reformatting
MihirT906 Mar 17, 2025
203a4b9
Add function to save configuration
MihirT906 Mar 17, 2025
471c58d
Reformatted
MihirT906 Mar 17, 2025
6658862
Add graph settings and dataset features to save functionality
MihirT906 Mar 29, 2025
4b0b535
Complete save functionality
MihirT906 Mar 31, 2025
a1cdc2a
Fix dataset feature save
MihirT906 Apr 1, 2025
07a7978
Add version check code for load
MihirT906 Apr 1, 2025
241a7b7
Ball ID Type Bug Fix
MihirT906 Apr 1, 2025
50fdadc
Reformat
MihirT906 Apr 1, 2025
842b106
Complete JSON load function
MihirT906 Apr 1, 2025
48a5c33
Add test for save and load functionality
MihirT906 Apr 2, 2025
ca6f587
Complete save and load tests
MihirT906 Apr 3, 2025
f4f508a
Add function descriptions
MihirT906 Apr 3, 2025
52924e1
clean post_init
MihirT906 Apr 10, 2025
dcaeaa5
Fix default feature_specs
MihirT906 Apr 10, 2025
e0779c6
Modify to take None argument
MihirT906 Apr 10, 2025
b460478
Remove value type check
MihirT906 Apr 10, 2025
6ef5351
Change feature spec definition to take None
MihirT906 Apr 10, 2025
59a400e
Bug Fixes
MihirT906 Apr 10, 2025
d289228
Changed to take null
MihirT906 Apr 10, 2025
a0ff803
Validate dataset features and graph columns
MihirT906 Apr 10, 2025
41286c6
Remove redundancy
MihirT906 Apr 12, 2025
5c4d879
Add from_json
MihirT906 Apr 12, 2025
4b30912
Use FileLike
MihirT906 Apr 12, 2025
4366c75
Allow boolean parameters in feature specs
MihirT906 Apr 12, 2025
a08c1af
Revert to previous test
MihirT906 Apr 12, 2025
e8b4a35
Remove print statement
MihirT906 Apr 12, 2025
521137e
Add orientation check
MihirT906 Apr 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add from_json
  • Loading branch information
MihirT906 committed Apr 12, 2025
commit 5c4d87923b492a5b3ea98a56967de10e6d55e6e0
15 changes: 9 additions & 6 deletions tests/test_polar_flex.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@ def default_loaded_converter(
SoccerGraphConverter with feature specs loaded from a json file. The default_converter is saved to a json file and then loaded to create a new converter.
"""
default_converter.save(feature_specs_file)
converter = SoccerGraphConverterPolars(dataset=kloppy_polars_dataset)
converter.load_from_json(feature_specs_file)
converter = SoccerGraphConverterPolars(
dataset=kloppy_polars_dataset, from_json=feature_specs_file
)
return converter

@pytest.fixture()
Expand Down Expand Up @@ -425,8 +426,9 @@ def test_overriden_load_feature_specs(
Tests if the default overriden converter is saved and loaded correctly.
"""
default_overriden_converter.save(feature_specs_file)
converter = SoccerGraphConverterPolars(dataset=kloppy_polars_dataset)
converter.load_from_json(feature_specs_file)
converter = SoccerGraphConverterPolars(
dataset=kloppy_polars_dataset, from_json=feature_specs_file
)
converter.save(new_feature_specs_file)

with open(feature_specs_file, "r") as f1, open(
Expand All @@ -447,8 +449,9 @@ def test_valid_load_feature_specs(
Tests if the valid feature converter is saved and loaded correctly.
"""
valid_feature_converter.save(feature_specs_file)
converter = SoccerGraphConverterPolars(dataset=kloppy_polars_dataset)
converter.load_from_json(feature_specs_file)
converter = SoccerGraphConverterPolars(
dataset=kloppy_polars_dataset, from_json=feature_specs_file
)
converter.save(new_feature_specs_file)

with open(feature_specs_file, "r") as f1, open(
Expand Down
53 changes: 28 additions & 25 deletions unravel/soccer/graphs/graph_converter_pl.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@
"is_ball": None,
"goal_sin_normed": None,
"goal_cos_normed": None,
"ball_sin_normed":None,
"ball_cos_normed":None,
"ball_sin_normed": None,
"ball_cos_normed": None,
"ball_carrier": None,
},
"edge_features": {
Expand All @@ -62,6 +62,7 @@
},
}


@dataclass(repr=True)
class SoccerGraphConverterPolars(DefaultGraphConverter):
"""
Expand All @@ -83,6 +84,7 @@ class SoccerGraphConverterPolars(DefaultGraphConverter):
chunk_size: int = 2_0000
non_potential_receiver_node_value: float = 0.1
graph_feature_cols: Optional[List[str]] = None
from_json: Optional[str] = None

def __post_init__(self):
if not isinstance(self.dataset, KloppyPolarsDataset):
Expand All @@ -91,6 +93,10 @@ def __post_init__(self):
self.pitch_dimensions: MetricPitchDimensions = (
self.dataset.settings.pitch_dimensions
)

if self.from_json is not None:
self._load_from_json(self.from_json)

self.label_column: str = (
self.label_col if self.label_col is not None else self.dataset._label_column
)
Expand All @@ -114,7 +120,7 @@ def __post_init__(self):
self._validate_feature_specs_general()
self._shuffle()

def _validate_feature_specs_general(self):
def _validate_feature_specs_general(self):
# Override the feature specs to the default version if they are not provided
if self.feature_specs == None or self.feature_specs == {}:
self.feature_specs = DEFAULT_SOCCER_FEATURE_SPECS
Expand Down Expand Up @@ -165,7 +171,7 @@ def _populate_feature_specs(self, feature_func, feature_tag):

self.feature_specs[feature_tag][feature] = params

def load_from_json(self, file_path: str) -> None:
def _load_from_json(self, file_path: str) -> None:
"""
Load the configuration from a JSON file.
Args:
Expand All @@ -190,22 +196,26 @@ def load_from_json(self, file_path: str) -> None:
configuration["graph_converter_attributes"].pop("label_column", None)
configuration["graph_converter_attributes"].pop("graph_id_column", None)

#validate data cols
# validate data cols
if "dataset_cols" in configuration:
#check if all columns in the dataset specified in the JSON file are in the dataset
for col in self.dataset.columns:
# check if all columns in the dataset specified in the JSON file are in the dataset
for col in self.dataset.data.columns:
if col not in configuration["dataset_cols"]:
raise ValueError(
f"Column '{col}' is missing in dataset_cols."
)

#validate graph converter attributes
raise ValueError(f"Column '{col}' is missing in dataset_cols.")

# validate graph converter attributes
for key, value in configuration["graph_converter_attributes"].items():
if key == "dataset":
print("Dataset is not settable from JSON file.")
if key == "graph_feature_cols" and configuration["graph_converter_attributes"]["graph_feature_cols"] is not None:
#check if graph feature columns exist in the dataset
for col in configuration["graph_converter_attributes"]["graph_feature_cols"]:
if (
key == "graph_feature_cols"
and configuration["graph_converter_attributes"]["graph_feature_cols"]
is not None
):
# check if graph feature columns exist in the dataset
for col in configuration["graph_converter_attributes"][
"graph_feature_cols"
]:
if col not in self.dataset.columns:
raise ValueError(
f"Graph feature column '{col}' not found in dataset columns."
Expand All @@ -221,24 +231,18 @@ def load_from_json(self, file_path: str) -> None:
filtered_settings = {
k: v for k, v in graph_settings_dict.items() if k in valid_keys
}
print(valid_keys)
print(filtered_settings)
self.settings = DefaultGraphSettings(**filtered_settings)

self.dataset = self.dataset_checkpoint
self.__post_init__()
if "dataset_features" in configuration:
for key, value in configuration["dataset_features"].items():
dataset_features = self.dataset_checkpoint.get_features()
dataset_features = self.dataset.get_features()
if key in dataset_features:
if value != dataset_features[key]:
raise ValueError(
f"Feature '{key}' in dataset does not match the value in the configuration file."
)
else:
raise ValueError(
f"Feature '{key}' not found in dataset features."
)
raise ValueError(f"Feature '{key}' not found in dataset features.")

def _validate_feature_specs(
self, feature_specs: dict, feature_func, feature_defaults, feature_tag
Expand All @@ -254,7 +258,7 @@ def _validate_feature_specs(
raise ValueError(
f"feature {feature} is not a valid {feature_tag[:4]} feature. Valid features are {list(feature_map.keys())}"
)
#check if feature_specs[feature_tag][feature] is a dictionary
# check if feature_specs[feature_tag][feature] is a dictionary
if isinstance(feature_specs[feature_tag][feature], dict):
for key, value in feature_specs[feature_tag][feature].items():
if key not in feature_map[feature]["defaults"]:
Expand Down Expand Up @@ -705,7 +709,6 @@ def save(self, file_path: str) -> None:
"dataset_features": self.dataset_checkpoint.get_features(),
"dataset_cols": self.dataset_checkpoint.data.columns,
# "graph_feature_cols": self.graph_feature_cols or [],

}
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "w") as f:
Expand Down