Skip to content
Prev Previous commit
Next Next commit
Add comments
  • Loading branch information
Mihir Thalanki authored and Mihir Thalanki committed Sep 3, 2024
commit a4e5a2f8f30d145d588dfcad2563de2e398dd421
25 changes: 25 additions & 0 deletions unravel/utils/features/node_feature_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def add_y(self, normed: bool = False):
def add_velocity(
self, x: bool = True, y: bool = True, angle: bool = True, normed: bool = False
):
"""
Adds a function to return the x and y unit vectors of the velocity as well as the angle.
The angle can be normalized and is only calculated if both x and y components of velocity is present.
"""
if not (x or y):
print(
"Warning: No velocity component added. Please add either x or y components"
Expand Down Expand Up @@ -97,6 +101,9 @@ def add_velocity(
return self

def add_speed(self, normed: bool = True):
"""
Adds a function that calculates the speed. Can be normalized
"""
if normed:
self.node_feature_functions.append(
("normalized_speed", normalize_speed, ["speed", "max_speed"])
Expand All @@ -106,6 +113,9 @@ def add_speed(self, normed: bool = True):
return self

def add_goal_distance(self, normed: bool = True):
"""
Adds a function that calculates the distance of the ball/player to the goal. Can be normalized
"""
if normed:
self.node_feature_functions.append(
(
Expand All @@ -129,6 +139,9 @@ def add_goal_distance(self, normed: bool = True):
return self

def add_goal_angle(self, normed: bool = True):
"""
Adds a function that calculates the angle of the player to the goal. Can be normalized
"""
if normed:
self.node_feature_functions.append(
("normed_goal_angle", normalize_angles, ["goal_angle"])
Expand All @@ -141,6 +154,9 @@ def add_goal_angle(self, normed: bool = True):
return self

def add_ball_distance(self, normed: bool = True):
"""
Adds a function to calculate the distance of the player from the ball. Can be normalized
"""
if normed:
self.node_feature_functions.append(
(
Expand All @@ -164,6 +180,9 @@ def add_ball_distance(self, normed: bool = True):
return self

def add_ball_angle(self, normed: bool = True):
"""
Adds a function to calculate the angle of player to the ball. Can be normalized
"""
if normed:
self.node_feature_functions.append(
("normed_ball_angle", normalize_angles, ["ball_angle"])
Expand All @@ -176,10 +195,16 @@ def add_ball_angle(self, normed: bool = True):
return self

def add_team(self):
"""
Adds a function that returns 1 if player is on same team but not in possession, 0.1 for all other players, 0.1 if the player is 'missing'
"""
self.node_feature_functions.append(("team", lambda t: t, ["team"]))
return self

def add_potential_reciever(self):
"""
Adds a function that returns 1 if player is a potential reciever
"""
self.node_feature_functions.append(
(
"potential_reciever",
Expand Down