Skip to content

Commit 8f4148f

Browse files
committed
OP#9015: Migrate to async across the codebase by converting observers, methods, buffers, writers, and test cases to support AsyncIterator and AsyncSubject.
1 parent bfdc0a4 commit 8f4148f

19 files changed

Lines changed: 246 additions & 106 deletions

OTVision/abstraction/observer.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from typing import Callable, TypeVar
1+
import asyncio
2+
from typing import Awaitable, Callable, TypeVar
23

34
VALUE = TypeVar("VALUE")
45
type Observer[T] = Callable[[T], None]
6+
type AsyncObserver[T] = Callable[[T], Awaitable[None]]
57

68

79
class Subject[T]:
@@ -41,3 +43,46 @@ def __init__(self, subject: Subject[T]) -> None:
4143

4244
def register(self, observer: Observer[T]) -> None:
4345
self._subject.register(observer)
46+
47+
48+
class AsyncSubject[T]:
49+
"""Generic async subject class to handle and notify async observers.
50+
51+
This class ensures that no duplicate observers can be registered.
52+
The order that registered observers are notified is dictated by the order they have
53+
been registered. Meaning, first to be registered is first to be notified.
54+
All observers are awaited concurrently using asyncio.gather().
55+
"""
56+
57+
def __init__(self) -> None:
58+
self._observers: list[AsyncObserver[T]] = []
59+
60+
def register(self, observer: AsyncObserver[T]) -> None:
61+
"""Listen to changes of subject.
62+
63+
Args:
64+
observer (AsyncObserver[T]): the observer to be registered. This must be an
65+
async `Callable` that returns an `Awaitable`.
66+
"""
67+
new_observers = self._observers.copy()
68+
new_observers.append(observer)
69+
self._observers = list(dict.fromkeys(new_observers))
70+
71+
async def notify(self, value: T) -> None:
72+
"""Notifies observers about the value asynchronously.
73+
74+
All observers are notified concurrently using asyncio.gather().
75+
76+
Args:
77+
value (T): value to notify the observer with.
78+
"""
79+
if self._observers:
80+
await asyncio.gather(*[observer(value) for observer in self._observers])
81+
82+
83+
class AsyncObservable[T]:
84+
def __init__(self, subject: AsyncSubject[T]) -> None:
85+
self._subject = subject
86+
87+
def register(self, observer: AsyncObserver[T]) -> None:
88+
self._subject.register(observer)

OTVision/application/buffer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ def __init__(self) -> None:
1010

1111
async def filter(self, pipe: AsyncIterator[T]) -> AsyncIterator[T]:
1212
async for element in pipe:
13-
self.buffer(element)
13+
await self.buffer(element)
1414
yield element
1515

16-
def buffer(self, to_buffer: T) -> None:
16+
async def buffer(self, to_buffer: T) -> None:
1717
self._buffer.append(to_buffer)
1818

1919
def _get_buffered_elements(self) -> list[T]:
@@ -24,5 +24,5 @@ def _reset_buffer(self) -> None:
2424
self._buffer = list()
2525

2626
@abstractmethod
27-
def on_flush(self, event: OBSERVING_TYPE) -> None:
27+
async def on_flush(self, event: OBSERVING_TYPE) -> None:
2828
raise NotImplementedError

OTVision/detect/builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from argparse import ArgumentParser
33
from functools import cached_property
44

5-
from OTVision.abstraction.observer import Subject
5+
from OTVision.abstraction.observer import AsyncSubject
66
from OTVision.application.config import Config, DetectConfig
77
from OTVision.application.config_parser import ConfigParser
88
from OTVision.application.configure_logger import ConfigureLogger
@@ -127,7 +127,7 @@ def frame_count_provider(self) -> FrameCountProvider:
127127
@cached_property
128128
def otdet_file_writer(self) -> OtdetFileWriter:
129129
return OtdetFileWriter(
130-
subject=Subject[OtdetFileWrittenEvent](),
130+
subject=AsyncSubject[OtdetFileWrittenEvent](),
131131
builder=self.otdet_builder,
132132
get_current_config=self.get_current_config,
133133
current_object_detector_metadata=self.current_object_detector_metadata,
@@ -147,7 +147,7 @@ def current_object_detector(self) -> CurrentObjectDetector:
147147

148148
@cached_property
149149
def detected_frame_buffer(self) -> DetectedFrameBuffer:
150-
return DetectedFrameBuffer(subject=Subject[DetectedFrameBufferEvent]())
150+
return DetectedFrameBuffer(subject=AsyncSubject[DetectedFrameBufferEvent]())
151151

152152
@cached_property
153153
def detected_frame_producer(self) -> DetectedFrameProducer:

OTVision/detect/detected_frame_buffer.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass
22
from datetime import datetime, timedelta
33

4-
from OTVision.abstraction.observer import Observable, Subject
4+
from OTVision.abstraction.observer import AsyncObservable, AsyncSubject
55
from OTVision.application.buffer import Buffer
66
from OTVision.domain.frame import DetectedFrame
77

@@ -51,25 +51,25 @@ class DetectedFrameBufferEvent:
5151

5252

5353
class DetectedFrameBuffer(
54-
Buffer[DetectedFrame, FlushEvent], Observable[DetectedFrameBufferEvent]
54+
Buffer[DetectedFrame, FlushEvent], AsyncObservable[DetectedFrameBufferEvent]
5555
):
56-
def __init__(self, subject: Subject[DetectedFrameBufferEvent]) -> None:
56+
def __init__(self, subject: AsyncSubject[DetectedFrameBufferEvent]) -> None:
5757
Buffer.__init__(self)
58-
Observable.__init__(self, subject)
58+
AsyncObservable.__init__(self, subject)
5959

60-
def on_flush(self, event: FlushEvent) -> None:
60+
async def on_flush(self, event: FlushEvent) -> None:
6161
buffered_elements = self._get_buffered_elements()
62-
self._notify_observers(buffered_elements, event)
62+
await self._notify_observers(buffered_elements, event)
6363
self._reset_buffer()
6464

65-
def _notify_observers(
65+
async def _notify_observers(
6666
self, elements: list[DetectedFrame], event: FlushEvent
6767
) -> None:
68-
self._subject.notify(
68+
await self._subject.notify(
6969
DetectedFrameBufferEvent(
7070
source_metadata=event.source_metadata, frames=elements
7171
)
7272
)
7373

74-
def buffer(self, to_buffer: DetectedFrame) -> None:
74+
async def buffer(self, to_buffer: DetectedFrame) -> None:
7575
self._buffer.append(to_buffer.without_image())

OTVision/detect/file_based_detect_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from functools import cached_property
22

3-
from OTVision.abstraction.observer import Subject
3+
from OTVision.abstraction.observer import AsyncSubject, Subject
44
from OTVision.application.event.new_video_start import NewVideoStartEvent
55
from OTVision.detect.builder import DetectBuilder
66
from OTVision.detect.detected_frame_buffer import FlushEvent
@@ -19,7 +19,7 @@ class FileBasedDetectBuilder(DetectBuilder):
1919
@cached_property
2020
def input_source(self) -> VideoSource:
2121
return VideoSource(
22-
subject_flush=Subject[FlushEvent](),
22+
subject_flush=AsyncSubject[FlushEvent](),
2323
subject_new_video_start=Subject[NewVideoStartEvent](),
2424
get_current_config=self.get_current_config,
2525
frame_rotator=self.frame_rotator,

OTVision/detect/otdet_file_writer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses import dataclass
33
from pathlib import Path
44

5-
from OTVision.abstraction.observer import Observer, Subject
5+
from OTVision.abstraction.observer import AsyncObserver, AsyncSubject
66
from OTVision.application.detect.current_object_detector_metadata import (
77
CurrentObjectDetectorMetadata,
88
)
@@ -45,7 +45,7 @@ class OtdetFileWriter:
4545

4646
def __init__(
4747
self,
48-
subject: Subject[OtdetFileWrittenEvent],
48+
subject: AsyncSubject[OtdetFileWrittenEvent],
4949
builder: OtdetBuilder,
5050
get_current_config: GetCurrentConfig,
5151
current_object_detector_metadata: CurrentObjectDetectorMetadata,
@@ -57,7 +57,7 @@ def __init__(
5757
self._current_object_detector_metadata = current_object_detector_metadata
5858
self._save_path_provider = save_path_provider
5959

60-
def write(self, event: DetectedFrameBufferEvent) -> None:
60+
async def write(self, event: DetectedFrameBufferEvent) -> None:
6161
"""Writes detection results to a file in OTDET format.
6262
6363
Processes the detected frames and associated metadata, builds the OTDET
@@ -118,23 +118,23 @@ def write(self, event: DetectedFrameBufferEvent) -> None:
118118

119119
finished_msg = "Finished detection"
120120
log.info(finished_msg)
121-
self.__notify(
121+
await self.__notify(
122122
num_frames=actual_frames,
123123
builder_config=builder_config,
124124
save_location=detections_file,
125125
)
126126

127-
def __notify(
127+
async def __notify(
128128
self, num_frames: int, builder_config: OtdetBuilderConfig, save_location: Path
129129
) -> None:
130-
self._subject.notify(
130+
await self._subject.notify(
131131
OtdetFileWrittenEvent(
132132
number_of_frames=num_frames,
133133
otdet_builder_config=builder_config,
134134
save_location=save_location,
135135
)
136136
)
137137

138-
def register_observer(self, observer: Observer[OtdetFileWrittenEvent]) -> None:
138+
def register_observer(self, observer: AsyncObserver[OtdetFileWrittenEvent]) -> None:
139139
"""Register an observer to receive notifications about otdet file writes.."""
140140
self._subject.register(observer)

OTVision/detect/rtsp_based_detect_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from functools import cached_property
22

3-
from OTVision.abstraction.observer import Subject
3+
from OTVision.abstraction.observer import AsyncSubject, Subject
44
from OTVision.application.config import StreamConfig
55
from OTVision.application.event.new_video_start import NewVideoStartEvent
66
from OTVision.detect.builder import DetectBuilder
@@ -33,7 +33,7 @@ def stream_config(self) -> StreamConfig:
3333
@cached_property
3434
def input_source(self) -> RtspInputSource:
3535
return RtspInputSource(
36-
subject_flush=Subject[FlushEvent](),
36+
subject_flush=AsyncSubject[FlushEvent](),
3737
subject_new_video_start=Subject[NewVideoStartEvent](),
3838
datetime_provider=self.datetime_provider,
3939
frame_counter=Counter(),

OTVision/detect/rtsp_input_source.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import socket
23
from datetime import datetime, timedelta
34
from time import sleep
@@ -13,7 +14,7 @@
1314
)
1415
from numpy import ndarray
1516

16-
from OTVision.abstraction.observer import Subject
17+
from OTVision.abstraction.observer import AsyncSubject, Subject
1718
from OTVision.application.config import (
1819
DATETIME_FORMAT,
1920
Config,
@@ -87,7 +88,7 @@ def fps(self) -> float:
8788

8889
def __init__(
8990
self,
90-
subject_flush: Subject[FlushEvent],
91+
subject_flush: AsyncSubject[FlushEvent],
9192
subject_new_video_start: Subject[NewVideoStartEvent],
9293
datetime_provider: DatetimeProvider,
9394
frame_counter: Counter,
@@ -153,10 +154,10 @@ async def produce(self) -> AsyncIterator[Frame]:
153154
occurrence=occurrence,
154155
)
155156
if self.flush_condition_met():
156-
self._notify_flush_observers()
157+
await self._notify_flush_observers()
157158
self._outdated = True
158159
self._frame_counter.reset()
159-
self._notify_flush_observers()
160+
await self._notify_flush_observers()
160161
except InvalidRtspUrlError as cause:
161162
logger().error(cause)
162163

@@ -209,7 +210,7 @@ def start(self) -> None:
209210
def flush_condition_met(self) -> bool:
210211
return self.current_frame_number % self.flush_buffer_size == 0
211212

212-
def _notify_flush_observers(self) -> None:
213+
async def _notify_flush_observers(self) -> None:
213214
frame_width = self._get_width()
214215
frame_height = self._get_height()
215216
frames = (
@@ -219,7 +220,7 @@ def _notify_flush_observers(self) -> None:
219220
)
220221
duration = timedelta(seconds=round(frames / self.fps))
221222
output = self.create_output()
222-
self.subject_flush.notify(
223+
await self.subject_flush.notify(
223224
FlushEvent.create(
224225
source=self.rtsp_url,
225226
output=output,
@@ -256,7 +257,9 @@ def create_output(self) -> str:
256257
def notify_new_config(self, config: NewOtvisionConfigEvent) -> None:
257258
try:
258259
logger().debug("New OTVision config detected. Flushing buffers...")
259-
self._notify_flush_observers()
260+
261+
# Create task to handle async flush notification
262+
asyncio.create_task(self._notify_flush_observers())
260263
except NoConfigurationFoundError:
261264
logger().info("No configuration found for RTSP stream. Skipping flushing.")
262265

OTVision/detect/video_input_source.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from av.container.input import InputContainer
88
from tqdm.asyncio import tqdm
99

10-
from OTVision.abstraction.observer import Subject
10+
from OTVision.abstraction.observer import AsyncSubject, Subject
1111
from OTVision.application.config import DATETIME_FORMAT, Config
1212
from OTVision.application.detect.timestamper import Timestamper
1313
from OTVision.application.event.new_video_start import NewVideoStartEvent
@@ -59,7 +59,7 @@ def _start_time(self) -> datetime | None:
5959

6060
def __init__(
6161
self,
62-
subject_flush: Subject[FlushEvent],
62+
subject_flush: AsyncSubject[FlushEvent],
6363
subject_new_video_start: Subject[NewVideoStartEvent],
6464
get_current_config: GetCurrentConfig,
6565
frame_rotator: AvVideoFrameRotator,
@@ -137,7 +137,7 @@ async def produce(self) -> AsyncIterator[Frame]:
137137
}
138138
)
139139
counter += 1
140-
self.notify_flush_event_observers(video_file, video_fps)
140+
await self.notify_flush_event_observers(video_file, video_fps)
141141
self._on_video_finished(video_file)
142142
except Exception as e:
143143
log.error(f"Error processing {video_file}", exc_info=e)
@@ -192,7 +192,7 @@ def __overwrite_existing_detection_file(self, detections_file: Path) -> bool:
192192
return False
193193
return True
194194

195-
def notify_flush_event_observers(
195+
async def notify_flush_event_observers(
196196
self, current_video_file: Path, video_fps: float
197197
) -> None:
198198
if expected_duration := self._current_config.detect.expected_duration:
@@ -205,7 +205,7 @@ def notify_flush_event_observers(
205205
current_video_file, start_time=self._start_time
206206
)
207207

208-
self.subject_flush.notify(
208+
await self.subject_flush.notify(
209209
FlushEvent.create(
210210
source=str(current_video_file),
211211
output=str(current_video_file),

OTVision/domain/video_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def close(self) -> None:
2222
raise NotImplementedError
2323

2424
@abstractmethod
25-
def notify_on_flush_event(self, event: FlushEvent) -> None:
25+
async def notify_on_flush_event(self, event: FlushEvent) -> None:
2626
raise NotImplementedError
2727

2828
@abstractmethod

0 commit comments

Comments
 (0)