# Source code for gunpowder.nodes.batch_filter

import logging

from .batch_provider import BatchProvider
from gunpowder.batch_request import BatchRequest
from gunpowder.profiling import Timing

# Module-level logger, named after this module (``__name__``).
logger = logging.getLogger(__name__)


class BatchFilterError(Exception):
    """Error raised by a batch filter, tagged with the node that raised it.

    Args:

        batch_filter:

            The filter that raised the error. Its ``name()`` method is used
            when the error is converted to a string.

        msg (``str``):

            Description of what went wrong.
    """

    def __init__(self, batch_filter, msg):
        # Forward both arguments so ``self.args`` is populated even when the
        # error is constructed with keyword arguments. ``args`` is what pickle
        # uses to rebuild an exception, e.g. when it crosses a process boundary.
        super().__init__(batch_filter, msg)
        self.batch_filter = batch_filter
        self.msg = msg

    def __str__(self):
        return f"Error in {self.batch_filter.name()}: {self.msg}"


class BatchFilter(BatchProvider):
    """Convenience wrapper for :class:`BatchProviders<BatchProvider>` with
    exactly one input provider.

    By default, a node of this class will expose the same :class:`ProviderSpec`
    as the upstream provider. You can modify the provider spec by calling
    :func:`provides` and :func:`updates` in :func:`setup`.

    Subclasses need to implement at least :func:`process` to modify a passed
    batch (downstream). Optionally, the following methods can be implemented:

        :func:`setup`

            Initialize this filter. Called after setup of the DAG. All upstream
            providers will be set up already.

        :func:`teardown`

            Destruct this filter, free resources, stop worker processes.

        :func:`prepare`

            Prepare for a batch request. Always called before each
            :func:`process`. Used to communicate dependencies.
    """

    @property
    def remove_placeholders(self):
        """The value of ``self._remove_placeholders`` if it has been set,
        ``False`` otherwise."""
        return getattr(self, "_remove_placeholders", False)

    def get_upstream_provider(self):
        """Return the single upstream provider of this filter.

        Raises:

            BatchFilterError: if this filter does not have exactly one
            upstream provider.
        """
        upstream_providers = self.get_upstream_providers()
        if len(upstream_providers) != 1:
            raise BatchFilterError(
                self,
                "BatchFilters need to have exactly one upstream provider, "
                f"this one has {len(upstream_providers)}: "
                f"({[b.name() for b in upstream_providers]})",
            )
        return upstream_providers[0]

    def updates(self, key, spec):
        """Update an output provided by this :class:`BatchFilter`.

        Implementations should call this in their :func:`setup` method, which
        will be called when the pipeline is built.

        Args:

            key (:class:`ArrayKey` or :class:`GraphKey`):

                The array or point set key this filter updates.

            spec (:class:`ArraySpec` or :class:`GraphSpec`):

                The updated spec of the array or point set.

        Raises:

            BatchFilterError: if ``key`` is not part of this filter's current
            spec (``self.spec``).
        """
        if key not in self.spec:
            raise BatchFilterError(
                self,
                f"BatchFilter {self} is trying to change the spec for {key}, "
                f"but {key} is not provided upstream. Upstream offers: "
                f"{self.get_upstream_provider().spec}",
            )
        # store a copy so later changes to the caller's spec do not leak in
        self.spec[key] = spec.copy()
        self.updated_items.append(key)

        logger.debug("%s updates %s with %s", self.name(), key, spec)

    def enable_autoskip(self, skip=True):
        """Enable automatic skipping of this :class:`BatchFilter`, based on
        given :func:`updates` and :func:`provides` calls. Has to be called in
        :func:`setup`.

        By default, :class:`BatchFilters<BatchFilter>` are not skipped
        automatically, regardless of what they update or provide. If autoskip
        is enabled, :class:`BatchFilters<BatchFilter>` will only be run if the
        request contains at least one key reported earlier with
        :func:`updates` or :func:`provides`. Placeholder entries in the
        request are ignored for this check.
        """
        self._autoskip_enabled = skip

    def _init_spec(self):
        # default for BatchFilters is to provide the same as upstream
        if not hasattr(self, "_spec") or self._spec is None:
            if len(self.get_upstream_providers()) != 0:
                self._spec = self.get_upstream_provider().spec.copy()
            else:
                self._spec = None

    def internal_teardown(self):
        """Reset the spec and the list of updated items, then call
        :func:`teardown`."""
        logger.debug("Resetting spec of %s", self.name())
        self._spec = None
        self._updated_items = []
        self.teardown()

    @property
    def updated_items(self):
        """Get a list of the keys that are updated by this `BatchFilter`.

        This list is only available after the pipeline has been built. Before
        that, it is empty.
        """
        if not hasattr(self, "_updated_items"):
            self._updated_items = []
        return self._updated_items

    @property
    def autoskip_enabled(self):
        """Whether :func:`enable_autoskip` turned on automatic skipping
        (``False`` by default)."""
        if not hasattr(self, "_autoskip_enabled"):
            self._autoskip_enabled = False
        return self._autoskip_enabled

    def provide(self, request):
        """Produce a batch for ``request`` by running this filter on upstream
        data.

        Unless this node is skipped (see :func:`enable_autoskip` and
        :func:`skip_node`), :func:`prepare` is asked for the dependencies. The
        resulting upstream request (minus items this node provides) is sent
        upstream, and the returned batch (cropped to the dependencies, if any)
        is passed to :func:`process`. The processed result is merged back into
        the upstream batch and cropped to the original request. When skipped,
        the upstream batch is returned without being processed or cropped.

        Timings of the prepare and process phases are added to the returned
        batch's profiling stats.

        Raises:

            BatchFilterError: if :func:`prepare` returns something other than
            a :class:`BatchRequest` or ``None``.
        """
        skip = self.__can_skip(request) or self.skip_node(request)

        timing_prepare = Timing(self, "prepare")
        timing_prepare.start()

        downstream_request = request.copy()
        # only set by prepare() when this node is not skipped
        dependencies = None

        if not skip:
            dependencies = self.prepare(request)
            if isinstance(dependencies, BatchRequest):
                upstream_request = request.update_with(dependencies)
            elif dependencies is None:
                upstream_request = request.copy()
            else:
                raise BatchFilterError(
                    self,
                    f"This BatchFilter returned a {type(dependencies)}! "
                    "Supported return types are: `BatchRequest` containing your exact "
                    "dependencies or `None`, indicating a dependency on the full request.",
                )
        else:
            upstream_request = request.copy()

        # items provided by this node are not requested from upstream
        self.remove_provided(upstream_request)

        timing_prepare.stop()

        batch = self.get_upstream_provider().request_batch(upstream_request)

        timing_process = Timing(self, "process")
        timing_process.start()

        if not skip:
            if dependencies is not None:
                dependencies.remove_placeholders()
                node_batch = batch.crop(dependencies)
            else:
                node_batch = batch
            downstream_request.remove_placeholders()
            processed_batch = self.process(node_batch, downstream_request)
            if processed_batch is None:
                # process() modified node_batch in place
                processed_batch = node_batch
            batch = batch.merge(processed_batch, merge_profiling_stats=False).crop(
                downstream_request
            )

        timing_process.stop()

        batch.profiling_stats.add(timing_prepare)
        batch.profiling_stats.add(timing_process)

        return batch

    def __can_skip(self, request):
        """Check whether this filter can be skipped for the given request.

        Returns ``True`` only if autoskip is enabled and no non-placeholder
        key in ``request`` is provided or updated by this filter.
        """
        if not self.autoskip_enabled:
            return False

        for key, spec in request.items():
            if spec.placeholder:
                continue
            if key in self.provided_items or key in self.updated_items:
                return False

        return True

    def skip_node(self, request):
        """To be implemented in subclasses.

        Skip a node if a condition is met. Can be useful if using a probability
        to determine whether to use an augmentation, for example.

        Returns:

            ``True`` to skip this node for ``request``. The default
            implementation never skips.
        """
        return False

    def setup(self):
        """To be implemented in subclasses.

        Called during initialization of the DAG. Callees can assume that all
        upstream providers are set up already.

        In setup, call :func:`provides` or :func:`updates` to announce the
        arrays and points provided or changed by this node.
        """
        pass

    def prepare(self, request):
        """To be implemented in subclasses.

        Prepare for a batch request. Should return a :class:`BatchRequest` of
        needed dependencies. If None is returned, it will be assumed that all
        of request is needed.
        """
        return None

    def process(self, batch, request):
        """To be implemented in subclasses.

        Filter a batch, will be called after :func:`prepare`. Should return a
        :class:`Batch` containing modified Arrays and Graphs. Keys in the
        returned batch will replace the associated data in the original batch.
        If None is returned it is assumed that the batch has been modified in
        place. ``request`` is the same as passed to :func:`prepare`, provided
        for convenience.

        Args:

            batch (:class:`Batch`):

                The batch received from upstream to be modified by this node.

            request (:class:`BatchRequest`):

                The request this node received. The updated batch should meet
                this request.

        Raises:

            BatchFilterError: always, in this default implementation.
        """
        raise BatchFilterError(self, "does not implement 'process'")