import math
import logging
from random import random, randint, choices
import itertools
import numpy as np
from scipy.spatial import cKDTree
from skimage.transform import integral_image, integrate
from gunpowder.batch_request import BatchRequest
from gunpowder.coordinate import Coordinate
from gunpowder.roi import Roi
from gunpowder.array import Array
from gunpowder.array_spec import ArraySpec
from .batch_filter import BatchFilter
from gunpowder.profiling import Timing
logger = logging.getLogger(__name__)
[docs]
class RandomLocation(BatchFilter):
"""Choses a batch at a random location in the bounding box of the upstream
provider.
The random location is chosen such that the batch request ROI lies entirely
inside the provider's ROI.
If ``min_masked`` and ``mask`` are set, only batches are returned that have
at least the given ratio of masked-in voxels. This is in general faster
than using the :class:`Reject` node, at the expense of storing an integral
array of the complete mask.
If ``ensure_nonempty`` is set to a :class:`GraphKey`, only batches are
returned that have at least one point of this point collection within the
requested ROI.
Additional tests for randomly picked locations can be implemented by
subclassing and overwriting of :func:`accepts`. This method takes the
randomly shifted request that meets all previous criteria (like
``min_masked`` and ``ensure_nonempty``) and should return ``True`` if the
request is acceptable.
Args:
min_masked (``float``, optional):
If non-zero, require that the random sample contains at least that
ratio of masked-in voxels.
mask (:class:`ArrayKey`, optional):
The array to use for mask checks.
ensure_nonempty (:class:`GraphKey`, optional):
Ensures that when finding a random location, a request for
``ensure_nonempty`` will contain at least one point.
p_nonempty (``float``, optional):
If ``ensure_nonempty`` is set, it defines the probability that a
request for ``ensure_nonempty`` will contain at least one point.
Default value is 1.0.
ensure_centered (``bool``, optional):
if ``ensure_nonempty`` is set, ``ensure_centered`` guarantees
that the center voxel of the roi contains a point.
point_balance_radius (``int``):
if ``ensure_nonempty`` is set, ``point_balance_radius`` defines
a radius s.t. for every point `p` in ``ensure_nonempty``, the
probability of picking p is inversely related to the number of
other points within a distance of ``point_balance_radius`` to p.
This helps avoid oversampling of dense regions of the graph, and
undersampling of sparse regions.
random_shift_key (``ArrayKey`` optional):
if ``random_shift_key`` is not None, this node will populate
that key with a nonspatial array containing the random shift
used for each request. This can be useful for snapshot iterations
if you want to figure out where that snapshot came from.
"""
def __init__(
self,
min_masked=0,
mask=None,
ensure_nonempty=None,
p_nonempty=1.0,
ensure_centered=None,
point_balance_radius=1,
random_shift_key=None,
):
self.min_masked = min_masked
self.mask = mask
self.mask_spec = None
self.mask_integral = None
self.ensure_nonempty = ensure_nonempty
self.points = None
self.p_nonempty = p_nonempty
self.upstream_spec = None
self.random_shift = None
self.ensure_centered = ensure_centered
self.point_balance_radius = point_balance_radius
self.random_shift_key = random_shift_key
def setup(self):
upstream = self.get_upstream_provider()
self.upstream_spec = upstream.spec
if self.mask and self.min_masked > 0:
assert self.mask in self.upstream_spec, (
"Upstream provider does not have %s" % self.mask
)
self.mask_spec = self.upstream_spec.array_specs[self.mask]
logger.info("requesting complete mask...")
mask_request = BatchRequest({self.mask: self.mask_spec})
mask_batch = upstream.request_batch(mask_request)
logger.info("allocating mask integral array...")
mask_data = mask_batch.arrays[self.mask].data
mask_integral_dtype = np.uint64
logger.debug("mask size is %s", mask_data.size)
if mask_data.size < 2**32:
mask_integral_dtype = np.uint32
if mask_data.size < 2**16:
mask_integral_dtype = np.uint16
logger.debug("chose %s as integral array dtype", mask_integral_dtype)
self.mask_integral = np.array(mask_data > 0, dtype=mask_integral_dtype)
self.mask_integral = integral_image(self.mask_integral).astype(
mask_integral_dtype
)
if self.ensure_nonempty:
assert self.ensure_nonempty in self.upstream_spec, (
"Upstream provider does not have %s" % self.ensure_nonempty
)
graph_spec = self.upstream_spec.graph_specs[self.ensure_nonempty]
logger.info("requesting all %s points...", self.ensure_nonempty)
nonempty_request = BatchRequest({self.ensure_nonempty: graph_spec})
nonempty_batch = upstream.request_batch(nonempty_request)
self.points = cKDTree(
[p.location for p in nonempty_batch[self.ensure_nonempty].nodes]
)
point_counts = self.points.query_ball_point(
[p.location for p in nonempty_batch[self.ensure_nonempty].nodes],
r=self.point_balance_radius,
)
weights = [1 / len(point_count) for point_count in point_counts]
self.cumulative_weights = list(itertools.accumulate(weights))
logger.debug("retrieved %d points", len(self.points.data))
# clear bounding boxes of all provided arrays and points --
# RandomLocation does not have limits (offsets are ignored)
for key, spec in self.spec.items():
if spec.roi is not None:
spec.roi.shape = Coordinate((None,) * spec.roi.dims)
self.updates(key, spec)
# provide randomness if asked for
if self.random_shift_key is not None:
self.provides(self.random_shift_key, ArraySpec(nonspatial=True))
def prepare(self, request):
logger.debug("request: %s", request.array_specs)
logger.debug("my spec: %s", self.spec)
if request.array_specs.keys():
lcm_voxel_size = self.spec.get_lcm_voxel_size(request.array_specs.keys())
else:
lcm_voxel_size = Coordinate((1,) * request.get_total_roi().dims)
shift_roi = self.__get_possible_shifts(request, lcm_voxel_size)
if request.array_specs.keys():
shift_roi = shift_roi.snap_to_grid(lcm_voxel_size, mode="shrink")
lcm_shift_roi = shift_roi / lcm_voxel_size
logger.debug(
"restricting random locations to multiples of voxel size %s",
lcm_voxel_size,
)
else:
lcm_shift_roi = shift_roi
assert not lcm_shift_roi.unbounded, (
"Can not pick a random location, intersection of upstream ROIs is "
"unbounded."
)
assert not lcm_shift_roi.empty, (
"Can not satisfy batch request, no location covers all requested " "ROIs."
)
random_shift = self.__select_random_shift(
request, lcm_shift_roi, lcm_voxel_size
)
self.random_shift = random_shift
self.__shift_request(request, random_shift)
return request
def provide(self, request):
timing_prepare = Timing(self, "prepare")
timing_prepare.start()
downstream_request = request.copy()
self.prepare(request)
self.remove_provided(request)
timing_prepare.stop()
batch = self.get_upstream_provider().request_batch(request)
timing_process = Timing(self, "process")
timing_process.start()
downstream_request.remove_placeholders()
self.process(batch, downstream_request)
timing_process.stop()
batch.profiling_stats.add(timing_prepare)
batch.profiling_stats.add(timing_process)
return batch
def process(self, batch, request):
if self.random_shift_key is not None:
batch[self.random_shift_key] = Array(
np.array(self.random_shift),
ArraySpec(nonspatial=True),
)
# reset ROIs to request
for array_key, spec in request.array_specs.items():
batch.arrays[array_key].spec.roi = spec.roi
for graph_key, spec in request.graph_specs.items():
batch.graphs[graph_key].spec.roi = spec.roi
# change shift point locations to lie within roi
for graph_key in request.graph_specs.keys():
batch.graphs[graph_key].shift(-self.random_shift)
def accepts(self, request):
"""Should return True if the randomly chosen location is acceptable
(besided meeting other criteria like ``min_masked`` and/or
``ensure_nonempty``). Subclasses can overwrite this method to implement
additional tests for acceptable locations."""
return True
def __get_possible_shifts(self, request, voxel_size):
total_shift_roi = None
for key, spec in request.items():
if spec.roi is None:
continue
request_roi = spec.roi
provided_roi = self.upstream_spec[key].roi
shift_roi = provided_roi.shift(-request_roi.begin).grow(
(0,) * request_roi.dims, -(request_roi.shape - voxel_size)
)
if total_shift_roi is None:
total_shift_roi = shift_roi
else:
if shift_roi != total_shift_roi:
total_shift_roi = total_shift_roi.intersect(shift_roi)
logger.debug("valid shifts for request in " + str(total_shift_roi))
return total_shift_roi
def __select_random_shift(self, request, lcm_shift_roi, lcm_voxel_size):
ensure_points = self.ensure_nonempty is not None and random() <= self.p_nonempty
while True:
if ensure_points:
random_shift = self.__select_random_location_with_points(
request, lcm_shift_roi, lcm_voxel_size
)
else:
random_shift = self.__select_random_location(
lcm_shift_roi, lcm_voxel_size
)
logger.debug("random shift: " + str(random_shift))
if not self.__is_min_masked(random_shift, request):
logger.debug("random location does not meet 'min_masked' criterium")
continue
if not self.__accepts(random_shift, request):
logger.debug("random location does not meet user-provided criterium")
continue
return random_shift
def __is_min_masked(self, random_shift, request):
if not self.mask or self.min_masked == 0:
return True
# get randomly chosen mask ROI
request_mask_roi = request.array_specs[self.mask].roi
request_mask_roi = request_mask_roi.shift(random_shift)
# get coordinates inside mask array
mask_voxel_size = self.spec[self.mask].voxel_size
request_mask_roi_in_array = request_mask_roi / mask_voxel_size
request_mask_roi_in_array -= self.mask_spec.roi.offset / mask_voxel_size
# get number of masked-in voxels
num_masked_in = integrate(
self.mask_integral,
[request_mask_roi_in_array.begin],
[
request_mask_roi_in_array.end
- Coordinate((1,) * self.mask_integral.ndim)
],
)[0]
mask_ratio = float(num_masked_in) / request_mask_roi_in_array.size
logger.debug("mask ratio is %f", mask_ratio)
return mask_ratio >= self.min_masked
def __accepts(self, random_shift, request):
# create a shifted copy of the request
shifted_request = request.copy()
self.__shift_request(shifted_request, random_shift)
return self.accepts(shifted_request)
def __shift_request(self, request, shift):
# shift request ROIs
for specs_type in [request.array_specs, request.graph_specs]:
for key, spec in specs_type.items():
if spec.roi is None:
continue
roi = spec.roi.shift(shift)
specs_type[key].roi = roi
def __select_random_location_with_points(
self, request, lcm_shift_roi, lcm_voxel_size
):
request_points = request.graph_specs.get(self.ensure_nonempty)
if request_points is None:
total_roi = request.get_total_roi()
logger.warning(
f"Requesting non empty {self.ensure_nonempty}, however {self.ensure_nonempty} "
f"has not been requested. Falling back on using the total roi of the "
f"request {total_roi} for {self.ensure_nonempty}."
)
request_points_roi = total_roi
else:
request_points_roi = request_points.roi
while True:
# How to pick shifts that ensure that a randomly chosen point is
# contained in the request ROI:
#
#
# request point
# [---------) .
# 0 +10 17
#
# least shifted to contain point
# [---------)
# 8 +10
# ==
# point-request.begin-request.shape+1
#
# most shifted to contain point:
# [---------)
# 17 +10
# ==
# point-request.begin
#
# all possible shifts
# [---------)
# 8 +10
# ==
# point-request.begin-request.shape+1
# ==
# request.shape
# pick a random point
point = choices(self.points.data, cum_weights=self.cumulative_weights)[0]
logger.debug("select random point at %s", point)
# get the lcm voxel that contains this point
lcm_location = Coordinate(point / lcm_voxel_size)
logger.debug("belongs to lcm voxel %s", lcm_location)
# align the point request ROI with lcm voxel grid
lcm_roi = request_points_roi.snap_to_grid(lcm_voxel_size, mode="shrink")
lcm_roi = lcm_roi / lcm_voxel_size
logger.debug("Point request ROI: %s", request_points_roi)
logger.debug("Point request lcm ROI shape: %s", lcm_roi.shape)
# get all possible starting points of lcm_roi.shape that contain
# lcm_location
if self.ensure_centered:
lcm_shift_roi_begin = (
lcm_location
- lcm_roi.begin
- lcm_roi.shape / 2
+ Coordinate((1,) * len(lcm_location))
)
lcm_shift_roi_shape = Coordinate((1,) * len(lcm_location))
else:
lcm_shift_roi_begin = (
lcm_location
- lcm_roi.begin
- lcm_roi.shape
+ Coordinate((1,) * len(lcm_location))
)
lcm_shift_roi_shape = lcm_roi.shape
lcm_point_shift_roi = Roi(lcm_shift_roi_begin, lcm_shift_roi_shape)
logger.debug("lcm point shift roi: %s", lcm_point_shift_roi)
# intersect with total shift ROI
if not lcm_point_shift_roi.intersects(lcm_shift_roi):
logger.debug(
"reject random shift, random point %s shift ROI %s does "
"not intersect total shift ROI %s",
point,
lcm_point_shift_roi,
lcm_shift_roi,
)
continue
lcm_point_shift_roi = lcm_point_shift_roi.intersect(lcm_shift_roi)
# select a random shift from all possible shifts
random_shift = self.__select_random_location(
lcm_point_shift_roi, lcm_voxel_size
)
logger.debug("random shift: %s", random_shift)
# count all points inside the shifted ROI
points = self.__get_points_in_roi(request_points_roi.shift(random_shift))
assert (
point in points
), "Requested batch to contain point %s, but got points " "%s" % (
point,
points,
)
return random_shift
def __select_random_location(self, lcm_shift_roi, lcm_voxel_size):
# select a random point inside ROI
random_shift = Coordinate(
randint(begin, end - 1)
for begin, end in zip(lcm_shift_roi.begin, lcm_shift_roi.end)
)
random_shift *= lcm_voxel_size
return random_shift
def __get_points_in_roi(self, roi):
points = []
center = roi.center
radius = math.ceil(float(max(roi.shape)) / 2)
candidates = self.points.query_ball_point(center, radius, p=np.inf)
for i in candidates:
if roi.contains(self.points.data[i]):
points.append(self.points.data[i])
return np.array(points)