Skip to content
Prev Previous commit
Next Next commit
Add velocity and speed feature
  • Loading branch information
Mihir Thalanki authored and Mihir Thalanki committed Aug 31, 2024
commit 9fb02799a8415e237c700a920a5f880664994fc5
34 changes: 26 additions & 8 deletions unravel/utils/features/node_feature_set.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .utils import (
normalize_x,
coord_x,
normalize_y,
coord_y,
normalize_coords,
coord,
unit_vector,
normalize_speed,
normalize_angles
)


Expand All @@ -20,9 +21,9 @@ def add_x(self, normed: bool = False):
If 'normed=True', the function will normalize the x coordinate
"""
if normed:
self.node_feature_functions.append(normalize_x)
self.node_feature_functions.append(('normalize_x', normalize_coords, ['x', 'max_x']))
else:
self.node_feature_functions.append(coord_x)
self.node_feature_functions.append(('coord_x', coord, ['x']))

return self

Expand All @@ -32,11 +33,28 @@ def add_y(self, normed: bool = False):
If 'normed=True', the function will normalize the x coordinate
"""
if normed:
self.node_feature_functions.append(normalize_y)
self.node_feature_functions.append(('normalize_y', normalize_coords, ['y', 'max_y']))
else:
self.node_feature_functions.append(coord_y)
self.node_feature_functions.append(('coord_y', coord, ['y']))

return self

def add_velocity(self, x: bool = True, y: bool = True):
if not (x or y):
print("Warning: No velocity component added. Please add either x or y components")
return
if x:
self.node_feature_functions.append(('unit_velocity_x', unit_vector, ['velocity_x']))
if y:
self.node_feature_functions.append(('unit_velocity_y', unit_vector, ['velocity_y']))
return self

def add_speed(self, normed: bool = True):
if normed:
self.node_feature_functions.append(('normalized_speed', normalize_speed, ['speed', 'max_speed']))
else:
self.node_feature_functions.append(('speed', coord, ['speed']))
return self

def get_features(self):
return self.node_feature_functions
Expand Down
44 changes: 26 additions & 18 deletions unravel/utils/features/node_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,25 +140,33 @@ def ball_features(ball):
]
)

all_params = {'x': 10.0, 'max_x': 105.0, 'y': 5.0, 'max_y': 90}
results = {}
for func in function_list:
sig = inspect.signature(func)
func_args = {}
for param_name, _ in sig.parameters.items():
if param_name in all_params:
func_args[param_name] = all_params[param_name]
else:
func_args[param_name] = None

if func_args:
try:
result = func(**func_args)
results[func.__name__] = result
except Exception as e:
results[func.__name__] = f"Error: {str(e)}"
all_params = {'x': 10.0,
'max_x': 105.0,
'y': 5.0,
'max_y': 90.0,
'velocity_x': 20.0,
'velocity_y': 10.0,
'speed': 20.0,
'max_speed': 40.0,
}

computed_values = {}
for func_name, func, reqd_params in function_list:
try:
if all(param in all_params for param in reqd_params): #if all the required parameters exist in all_params, then compute
params = [all_params[param] for param in reqd_params]
value = func(*params)
computed_values[func_name] = value
else: #else, print out the missing parameters. Maybe you should check if there is a default value. Then it is okay if the parameter is not present
missing_params = [param for param in reqd_params if param not in all_params]
print(f"Warning: Missing parameters {missing_params} for function '{func_name}'")
computed_values[func_name] = None
except Exception as e:
print(f"Error while executing function '{func_name}': {e}")
computed_values[func_name] = None

print(computed_values)

print(results)
# compute ball features
b_features = ball_features(ball)
X = np.append(ap_features, dp_features, axis=0)
Expand Down
19 changes: 11 additions & 8 deletions unravel/utils/features/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,17 @@ def make_sparse(a):


#Functions for node features
def normalize_x(x, max_x):
return x / max_x

def coord_x(x):
return x
def coord(value):
return value
# def normalize_x(x, max_x):
# return x / max_x

def normalize_y(y, max_y):
return y / max_y
# def coord_x(x):
# return x

def coord_y(y):
return y
# def normalize_y(y, max_y):
# return y / max_y

# def coord_y(y):
# return y