Skip to content
Prev Previous commit
Next Next commit
add export settings functionality
  • Loading branch information
Mihir Thalanki authored and Mihir Thalanki committed Oct 4, 2024
commit 4d46686c5472b086aca19bf9401c6f4a80017915
51 changes: 51 additions & 0 deletions settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
{
"__version__": "0.1.2",
"node_features": [
"normalize_x",
"normalize_y",
"unit_velocity_x",
"unit_velocity_y",
"normalized_velocity_angle",
"normalized_speed",
"normalized_goal_distance",
"normed_goal_angle",
"normalized_ball_distance",
"normed_ball_angle",
"team",
"potential_reciever"
],
"edge_features": [
"normalize_dist",
"normalize_speed_diff",
"normalise_cos_pos",
"normalise_sin_pos",
"normalise_cos_vel",
"normalise_sin_vel"
],
"graph_settings": {
"infer_ball_ownership": true,
"infer_goalkeepers": true,
"ball_carrier_treshold": 25.0,
"max_player_speed": 12.0,
"max_ball_speed": 28.0,
"boundary_correction": null,
"self_loop_ball": false,
"adjacency_matrix_connect_type": "ball",
"adjacency_matrix_type": "split_by_team",
"label_type": "binary",
"defending_team_node_value": 0.1,
"non_potential_receiver_node_value": 0.1,
"random_seed": false,
"pad": false,
"verbose": false,
"pitch_dimensions": {
"pitch_length": 105,
"pitch_width": 68,
"max_x": 52.5,
"min_x": -52.5,
"max_y": 34.0,
"min_y": -34.0
},
"pad_settings": null
}
}
15 changes: 15 additions & 0 deletions unravel/soccer/graphs/graph_converter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import sys
from copy import deepcopy
import json
#from .. import __version__

import warnings

Expand Down Expand Up @@ -323,3 +325,16 @@ def to_pickle(self, file_path: str) -> None:
with gzip.open(file_path, "wb") as file:
data = [x.graph_data for x in self.graph_frames]
pickle.dump(data, file)

def export_settings(self) -> None:
file_path = 'settings.json'
data = {
"__version__": "0.1.2",
"node_features": [func_name for func_name,_,_ in self.node_features.get_features()],
"edge_features": [func_name for func_name,_,_ in self.edge_features.get_features()],
"graph_settings": self.settings.to_dict()
}
print(data)

with open(file_path, 'w') as json_file:
json.dump(data, json_file, indent=4)
31 changes: 31 additions & 0 deletions unravel/utils/objects/graph_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,34 @@ def __pad_settings(self):
max_nodes=(n_players * 2) + n_ball,
n_players=n_players,
)

def to_dict(self):
return {
"infer_ball_ownership": self.infer_ball_ownership,
"infer_goalkeepers": self.infer_goalkeepers,
"ball_carrier_treshold": self.ball_carrier_treshold,
"max_player_speed": self.max_player_speed,
"max_ball_speed": self.max_ball_speed,
"boundary_correction": self.boundary_correction,
"self_loop_ball": self.self_loop_ball,
"adjacency_matrix_connect_type": self.adjacency_matrix_connect_type,
"adjacency_matrix_type": self.adjacency_matrix_type,
"label_type": self.label_type,
"defending_team_node_value": self.defending_team_node_value,
"non_potential_receiver_node_value": self.non_potential_receiver_node_value,
"random_seed": self.random_seed,
"pad": self.pad,
"verbose": self.verbose,
"pitch_dimensions": self._serialize_pitch_dimensions(),
"pad_settings": self.pad_settings
}

def _serialize_pitch_dimensions(self):
return {
"pitch_length": self.pitch_dimensions.pitch_length,
"pitch_width": self.pitch_dimensions.pitch_width,
"max_x": self.pitch_dimensions.x_dim.max,
"min_x": self.pitch_dimensions.x_dim.min,
"max_y": self.pitch_dimensions.y_dim.max,
"min_y": self.pitch_dimensions.y_dim.min,
}