diff --git a/av/filter/graph.pxd b/av/filter/graph.pxd index 2e52bd6ec..e85f67ce4 100644 --- a/av/filter/graph.pxd +++ b/av/filter/graph.pxd @@ -13,10 +13,10 @@ cdef class Graph: cdef dict _name_counts cdef str _get_unique_name(self, str name) + cdef list[FilterContext] _get_context_by_type(self, str type) cdef _register_context(self, FilterContext) cdef _auto_register(self) cdef int _nb_filters_seen - cdef dict _context_by_ptr - cdef dict _context_by_name - cdef dict _context_by_type + cdef dict[long, FilterContext] _context_by_ptr + cdef dict[str, list[FilterContext]] _context_by_type diff --git a/av/filter/graph.py b/av/filter/graph.py index 471b6ebc4..5cd3bbf06 100644 --- a/av/filter/graph.py +++ b/av/filter/graph.py @@ -20,7 +20,6 @@ def __cinit__(self): self._name_counts = {} self._nb_filters_seen = 0 self._context_by_ptr = {} - self._context_by_name = {} self._context_by_type = {} def __dealloc__(self): @@ -37,6 +36,11 @@ def _get_unique_name(self, name: str) -> str: else: return name + @cython.cfunc + @cython.inline + def _get_context_by_type(self, type_: str) -> list[FilterContext]: + return self._context_by_type.get(type_, []) + @cython.ccall def configure(self, auto_buffer: cython.bint = True, force: cython.bint = False): if self.configured and not force: @@ -87,7 +91,6 @@ def add(self, filter, args=None, **kwargs): @cython.cfunc def _register_context(self, ctx: FilterContext): self._context_by_ptr[cython.cast(cython.long, ctx.ptr)] = ctx - self._context_by_name[ctx.ptr.name] = ctx self._context_by_type.setdefault(ctx.filter.ptr.name, []).append(ctx) @cython.cfunc @@ -203,7 +206,7 @@ def set_audio_frame_size(self, frame_size): """ if not self.configured: raise ValueError("graph not configured") - sinks = self._context_by_type.get("abuffersink", []) + sinks = self._get_context_by_type("abuffersink") if not sinks: raise ValueError("missing abuffersink filter") for sink in sinks: @@ -213,13 +216,13 @@ def set_audio_frame_size(self, frame_size): def push(self, frame): if frame is None: - contexts = self._context_by_type.get( - "buffer", [] - ) + self._context_by_type.get("abuffer", []) + contexts = self._get_context_by_type("buffer") + self._get_context_by_type( + "abuffer" + ) elif isinstance(frame, VideoFrame): - contexts = self._context_by_type.get("buffer", []) + contexts = self._get_context_by_type("buffer") elif isinstance(frame, AudioFrame): - contexts = self._context_by_type.get("abuffer", []) + contexts = self._get_context_by_type("abuffer") else: raise ValueError( f"can only AudioFrame, VideoFrame or None; got {type(frame)}" @@ -230,13 +233,13 @@ def push(self, frame): def vpush(self, frame: VideoFrame | None): """Like `push`, but only for VideoFrames.""" - for ctx in self._context_by_type.get("buffer", []): + for ctx in self._get_context_by_type("buffer"): ctx.push(frame) # TODO: Test complex filter graphs, add `at: int = 0` arg to pull() and vpull(). def pull(self): - vsinks = self._context_by_type.get("buffersink", []) - asinks = self._context_by_type.get("abuffersink", []) + vsinks = self._get_context_by_type("buffersink") + asinks = self._get_context_by_type("abuffersink") nsinks = len(vsinks) + len(asinks) if nsinks != 1: @@ -246,7 +249,7 @@ def pull(self): def vpull(self): """Like `pull`, but only for VideoFrames.""" - vsinks = self._context_by_type.get("buffersink", []) + vsinks = self._get_context_by_type("buffersink") nsinks = len(vsinks) if nsinks != 1: raise ValueError(f"can only auto-pull with single sink; found {nsinks}")