11from __future__ import annotations
22
33import asyncio
4- import functools
54import json
65import logging
76from abc import ABC , abstractmethod
87from collections import defaultdict
9- from typing import Any , Coroutine , Dict , List , Optional , Tuple , Type
8+ from typing import Any , Awaitable , Dict , List , Optional , Tuple , Type , cast
109
1110import websockets
1211import websockets .client
1615
1716from .g3typing import (
1817 Hostname ,
19- JsonDict ,
18+ JSONDict ,
19+ JSONObject ,
2020 MessageId ,
2121 SignalBody ,
2222 SignalId ,
@@ -74,11 +74,11 @@ def _init_signal_subscription_handling(self) -> None:
7474
7575 async def subscribe_to_signal (
7676 self , signal_uri_path : UriPath
77- ) -> Tuple [asyncio .Queue [SignalBody ], functools . partial [ Coroutine [ Any , Any , None ] ]]:
77+ ) -> Tuple [asyncio .Queue [SignalBody ], Awaitable [ None ]]:
7878 """Sets up a subscription to the signal with the specified `signal_uri_path`.
7979
80- Returns a tuple with a queue and a callable . Upon receiving signals messages, the message
81- body is added to the queue. The callable can be called to unsubscribe to the signal.
80+ Returns a tuple with a queue and an awaitable . Upon receiving signals messages, the message
81+ body is added to the queue. The awaitable can be awaited to unsubscribe to the signal.
8282 """
8383 self ._subscription_count += 1
8484 signal_id = self ._signal_id_by_path .get (signal_uri_path )
@@ -93,8 +93,7 @@ async def subscribe_to_signal(
9393 ] = signal_queue
9494 return (
9595 signal_queue ,
96- functools .partial (
97- self ._unsubscribe_to_signal ,
96+ self ._unsubscribe_to_signal (
9897 signal_uri_path ,
9998 signal_id ,
10099 SubscriptionId (self ._subscription_count ),
@@ -148,7 +147,7 @@ def __init__(
148147 """Initializes super class properties and additional properties needed for the communication."""
149148 self .g3_logger = logging .getLogger (__name__ )
150149 self ._message_count = 0
151- self ._future_messages : Dict [MessageId , asyncio .Future [JsonDict ]] = {}
150+ self ._future_messages : Dict [MessageId , asyncio .Future [JSONObject ]] = {}
152151 self ._event_loop = asyncio .get_running_loop ()
153152 if subprotocols is None :
154153 subprotocols = self .DEFAULT_SUBPROTOCOLS
@@ -184,17 +183,22 @@ def start_receiver_task(self) -> None:
184183 async def _receiver_task (self ) -> None :
185184 """Listens for and handles/delegates incoming messages."""
186185 async for message in self :
187- json_message : JsonDict = json .loads (message )
186+ json_message : JSONObject = json .loads (message )
188187 self .g3_logger .info (f"Received { json_message } " )
189188 match json_message :
190- case {"id" : message_id }:
191- self ._future_messages [message_id ].set_result (json_message )
189+ case {"id" : message_id , "body" : message_body }:
190+ del json_message ["id" ]
191+ self ._future_messages [cast (MessageId , message_id )].set_result (
192+ message_body
193+ )
192194 case {"signal" : signal_id , "body" : signal_body }:
193- self .receive_signal (signal_id , signal_body )
195+ self .receive_signal (
196+ cast (SignalId , signal_id ), cast (SignalBody , signal_body )
197+ )
194198 case _:
195199 raise InvalidResponseError
196200
197- async def require (self , request : JsonDict ) -> JsonDict :
201+ async def require (self , request : JSONDict ) -> JSONObject :
198202 """Sends a request with a unique id and returns the response."""
199203 self ._message_count += 1
200204 request ["id" ] = self ._message_count
@@ -206,44 +210,48 @@ async def require(self, request: JsonDict) -> JsonDict:
206210 return await future
207211
208212 async def require_get (
209- self , path : UriPath , params : Optional [JsonDict ] = None
210- ) -> JsonDict :
213+ self , path : UriPath , params : Optional [JSONObject ] = None
214+ ) -> JSONObject :
211215 """Sends a GET request and returns the response."""
212216 return await self .require (self .generate_get_request (path , params ))
213217
214- async def require_post (self , path : UriPath , body : Optional [str ] = None ) -> JsonDict :
218+ async def require_post (
219+ self , path : UriPath , body : Optional [JSONObject ] = None
220+ ) -> JSONObject :
215221 """Sends a POST request and returns the response."""
216222 return await self .require (self .generate_post_request (path , body ))
217223
218224 async def require_post_subscribe (self , signal_uri_path : UriPath ) -> SignalId :
219225 """Sends a subscription POST request and returns the body of the response."""
220- response = await self .require_post (signal_uri_path )
226+ response = cast ( JSONDict , await self .require_post (signal_uri_path ) )
221227 try :
222- return response ["body" ]
228+ return cast ( SignalId , response ["body" ])
223229 except (KeyError , json .JSONDecodeError ):
224230 raise InvalidResponseError
225231
226232 async def require_post_unsubscribe (
227233 self , signal_uri_path : UriPath , signal_id : SignalId
228234 ) -> bool :
229235 """Sends an unsubscription POST request and returns a boolean indicating its success."""
230- response = await self .require_post (signal_uri_path , signal_id )
236+ response = cast ( JSONDict , await self .require_post (signal_uri_path , signal_id ) )
231237 try :
232- return response ["body" ]
238+ return cast ( bool , response ["body" ])
233239 except (KeyError , json .JSONDecodeError ):
234240 raise InvalidResponseError
235241
236242 @staticmethod
237243 def generate_get_request (
238- path : UriPath , params : Optional [JsonDict ] = None
239- ) -> JsonDict :
244+ path : UriPath , params : Optional [JSONObject ] = None
245+ ) -> JSONDict :
240246 """Generates a GET request."""
241- request : JsonDict = {"path" : path , "method" : "GET" }
247+ request : JSONDict = {"path" : path , "method" : "GET" }
242248 if params is not None :
243249 request ["params" ] = params
244250 return request
245251
246252 @staticmethod
247- def generate_post_request (path : UriPath , body : Optional [str ] = None ) -> JsonDict :
253+ def generate_post_request (
254+ path : UriPath , body : Optional [JSONObject ] = None
255+ ) -> JSONDict :
248256 """Generates a POST request."""
249- return {"path" : path , "method" : "POST" , "body" : body }
257+ return {"path" : cast ( str , path ) , "method" : "POST" , "body" : body }
0 commit comments