Source code for gunpowder.nodes.normalize

import logging
import numpy as np

from gunpowder.batch_request import BatchRequest

from .batch_filter import BatchFilter

logger = logging.getLogger(__name__)


[docs] class Normalize(BatchFilter): """Normalize the values of an array to be floats between 0 and 1, based on the type of the array. Args: array (:class:`ArrayKey`): The key of the array to modify. factor (scalar, optional): The factor to use. If not given, a factor is chosen based on the ``dtype`` of the array (e.g., ``np.uint8`` would result in a factor of ``1.0/255``). dtype (data-type, optional): The datatype of the normalized array. Defaults to ``np.float32``. """ def __init__(self, array, factor=None, dtype=np.float32): self.array = array self.factor = factor self.dtype = dtype def setup(self): self.enable_autoskip() array_spec = self.spec[self.array].copy() array_spec.dtype = self.dtype self.updates(self.array, array_spec) def prepare(self, request): deps = BatchRequest() deps[self.array] = request[self.array] deps[self.array].dtype = None return deps def process(self, batch, request): if self.array not in batch.arrays: return factor = self.factor array = batch.arrays[self.array] array.spec.dtype = self.dtype if factor is None: logger.debug( "automatically normalizing %s with dtype=%s", self.array, array.data.dtype, ) if array.data.dtype == np.uint8: factor = 1.0 / 255 elif array.data.dtype == np.uint16: factor = 1.0 / (256 * 256 - 1) elif array.data.dtype == np.float32: assert array.data.min() >= 0 and array.data.max() <= 1, ( "Values are float but not in [0,1], I don't know how " "to normalize. Please provide a factor." ) factor = 1.0 else: raise RuntimeError( "Automatic normalization for " + str(array.data.dtype) + " not implemented, please " "provide a factor." ) logger.debug("scaling %s with %f", self.array, factor) array.data = array.data.astype(self.dtype) * factor