API Reference
Data Containers
Batch
- class gunpowder.Batch[source]
Contains the requested batch as a collection of Arrays and Graphs that is passed through the pipeline from sources to sinks.

This collection mimics a dictionary. Items can be added with:

    batch = Batch()
    batch[array_key] = Array(...)
    batch[graph_key] = Graph(...)

Here, array_key and graph_key are an ArrayKey and a GraphKey. The items can be queried with:

    array = batch[array_key]
    graph = batch[graph_key]

Furthermore, pairs of keys/values can be iterated over using batch.items().

To access only arrays or graphs, use the dictionaries batch.arrays or batch.graphs, respectively.

Attributes:
- merge(batch, merge_profiling_stats=True)[source]
Merge this batch (a) with another batch (b).

This creates a new batch c containing arrays and graphs from both batches a and b:

- Arrays or Graphs that exist in either a or b will be referenced in c (not copied).
- Arrays or Graphs that exist in both batches will keep only a reference to the version in b in c.

All other cases will lead to an exception.
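As a usage illustration, the following minimal sketch builds two batches and merges them (the keys and the use of nonspatial specs are illustrative assumptions, not taken from this reference):

    import numpy as np
    import gunpowder as gp

    raw = gp.ArrayKey("RAW")            # hypothetical key
    labels = gp.ArrayKey("GT_LABELS")   # hypothetical key

    a = gp.Batch()
    a[raw] = gp.Array(np.zeros((10, 10)), gp.ArraySpec(nonspatial=True))

    b = gp.Batch()
    b[labels] = gp.Array(np.ones((10, 10)), gp.ArraySpec(nonspatial=True))

    c = a.merge(b)  # c references the arrays of both a and b
    assert raw in c.arrays and labels in c.arrays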
Array
- class gunpowder.Array(data, spec=None, attrs=None)[source]
A numpy array with a specification describing the data.
Args:
data (array-like):
The data to be stored in the array. Will be converted to a numpy array, if necessary.
spec (ArraySpec, optional):
A spec describing the data.
attrs (dict, optional):
Optional attributes to describe this array.
- crop(roi, copy=True)[source]
Create a cropped copy of this Array.
Args:
roi (Roi):
ROI in world units to crop to.
copy (bool):
Make a copy of the data.
- merge(array, copy_from_self=False, copy=False)[source]
Merge this array with another one. The resulting array will have the size of the larger one, with values replaced from array.

This only works if one of the two arrays is contained in the other. In this case, array will overwrite values in self (unless copy_from_self is set to True).

A copy will only be made if necessary or if copy is set to True.
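As a usage illustration, here is a minimal sketch of constructing and cropping an Array (the 1D ROI and unit voxel size are illustrative assumptions, so world units equal voxel indices):

    import numpy as np
    import gunpowder as gp

    spec = gp.ArraySpec(roi=gp.Roi((0,), (10,)), voxel_size=gp.Coordinate((1,)))
    array = gp.Array(np.arange(10), spec)

    # crop keeps the world coordinate interval [2, 6)
    cropped = array.crop(gp.Roi((2,), (4,)))
    print(cropped.data)  # [2 3 4 5]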
Graph
- class gunpowder.Graph(nodes: Iterator[Node], edges: Iterator[Edge], spec: GraphSpec)[source]
A structure containing a list of Nodes, a list of Edges, and a specification describing the data.
Args:
- add_edge(edge: Edge)[source]
Adds an edge to the graph. If an edge exists with the same u and v, its attributes will be overwritten.
- add_node(node: Node)[source]
Adds a node to the graph. If a node exists with the same id as the node you are adding, its attributes will be overwritten.
- crop(roi: Roi)[source]
Will remove all nodes from self that are not contained in roi except for “dangling” nodes. This means that if there are nodes A, B s.t. there is an edge (A, B) and A is contained in roi but B is not, the edge (A, B) is considered contained in the roi and thus node B will be kept as a “dangling” node.
Note there is a helper function trim that will remove B and replace it with a node at the intersection of the edge (A, B) and the bounding box of roi.
Args:
roi (Roi):
ROI in world units to crop to.
- classmethod from_nx_graph(graph, spec)[source]
Create a gunpowder graph from a networkx graph. The networkx graph is expected to have a “location” attribute for each node. If it is a subclass of a networkx graph with extra functionality, this may not work.
- merge(other, copy_from_self=False, copy=False)[source]
Merge this graph with another. The resulting graph will have the Roi of the larger one.
This only works if one of the two graphs contains the other. In this case, other will overwrite edges and nodes with the same ID in self (unless copy_from_self is set to True). Vertices and edges in self that are contained in the Roi of other will be removed (vice versa for copy_from_self).

A copy will only be made if necessary or if copy is set to True.
- relabel_connected_components()[source]
Create a new attribute “component” for each node in this Graph.
- remove_node(node: Node, retain_connectivity=False)[source]
Remove a node.
retain_connectivity: preserve the removed node's neighboring edges. Given the graph a->b->c, removing b without retain_connectivity would leave two connected components, {a} and {c}. Removing b with retain_connectivity set to True would leave the graph a->c, with only one connected component {a, c}, thus preserving the connectivity of a and c.
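As a usage illustration, the following minimal sketch builds a small Graph and crops it (the node locations, edge, and ROI are illustrative assumptions):

    import numpy as np
    import gunpowder as gp

    nodes = [
        gp.Node(id=1, location=np.array([1.0, 1.0, 1.0])),
        gp.Node(id=2, location=np.array([5.0, 5.0, 5.0])),
    ]
    edges = [gp.Edge(1, 2)]
    spec = gp.GraphSpec(roi=gp.Roi((0, 0, 0), (10, 10, 10)))

    graph = gp.Graph(nodes, edges, spec)

    # node 2 lies outside the crop ROI but is kept as a "dangling" node,
    # since it is connected to node 1 by an edge
    cropped = graph.crop(gp.Roi((0, 0, 0), (3, 3, 3)))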
Node
- class gunpowder.Node(id: int, location: ndarray, temporary: bool = False, attrs: Dict[str, Any] | None = None)[source]
A structure representing each node in a Graph.
Args:
id (int):
A unique identifier for this Node.
location (np.ndarray):
A numpy array containing the node's location.
attrs (dict, str -> Any, optional):
A dictionary mapping attribute names to values. Used to store any extra attributes associated with the Node, such as color, size, etc.
temporary (bool, optional):
A tag to mark a node as temporary. Some operations, such as trim, might create new nodes that are just byproducts of viewing the data with a limited scope. These nodes are only guaranteed to have an id different from those in the same Graph, but may have conflicts if you request multiple graphs from the same source with different ROIs.
Edge
ArrayKey
- class gunpowder.ArrayKey(identifier)[source]
A key to identify arrays in requests, batches, and across nodes.
Used as a key in BatchRequest and Batch to retrieve array specs or arrays.
Args:
identifier (string):
A unique, human-readable identifier for this array key. Will be used in log messages and to look up arrays in requests and batches. Should be upper case (like RAW, GT_LABELS). The identifier is unique: two array keys with the same identifier will refer to the same array.
GraphKey
- class gunpowder.GraphKey(identifier)[source]
A key to identify graphs in requests, batches, and across nodes.
Used as a key in BatchRequest and Batch to retrieve specs or graphs.
Args:
identifier (string):
A unique, human-readable identifier for this graph key. Will be used in log messages and to look up graphs in requests and batches. Should be upper case (like CENTER_GRAPH). The identifier is unique: two graph keys with the same identifier will refer to the same graph.
Requests and Specifications
ProviderSpec
- class gunpowder.ProviderSpec(array_specs=None, graph_specs=None)[source]
A collection of (possibly partial) ArraySpecs and GraphSpecs describing a BatchProvider's offered arrays and graphs.

This collection mimics a dictionary. Specs can be added with:

    provider_spec = ProviderSpec()
    provider_spec[array_key] = ArraySpec(...)
    provider_spec[graph_key] = GraphSpec(...)

Here, array_key and graph_key are an ArrayKey and a GraphKey. The specs can be queried with:

    array_spec = provider_spec[array_key]
    graph_spec = provider_spec[graph_key]

Furthermore, pairs of keys/values can be iterated over using provider_spec.items().

To access only array or graph specs, use the dictionaries provider_spec.array_specs or provider_spec.graph_specs, respectively.

Args:
Attributes:
BatchRequest
- class gunpowder.BatchRequest(*args, random_seed=None, **kwargs)[source]
A collection of (possibly partial) ArraySpecs and GraphSpecs forming a request.

Inherits from ProviderSpec.

Additional Kwargs:
random_seed (int):
The random seed that will be associated with this batch to guarantee deterministic and repeatable batch requests.
- add(key, shape, voxel_size=None, directed=None, placeholder=False)[source]
Convenience method to add an array or graph spec by providing only the shape of a ROI (in world units).
A ROI with zero-offset will be generated. If more than one request is added, the ROIs with smaller shapes will be shifted to be centered in the largest one.
Args:
key (ArrayKey or GraphKey):
The key for which to add a spec.
shape (Coordinate):
A tuple containing the shape of the desired ROI.
voxel_size (Coordinate):
A tuple containing the voxel sizes for each corresponding dimension.
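As a usage illustration, a minimal sketch of assembling a request with add() (the keys and shapes are illustrative assumptions):

    import gunpowder as gp

    raw = gp.ArrayKey("RAW")
    labels = gp.ArrayKey("GT_LABELS")

    request = gp.BatchRequest()
    request.add(raw, gp.Coordinate((132, 132, 132)))
    # the smaller ROI is shifted to be centered inside the larger one
    request.add(labels, gp.Coordinate((44, 44, 44)))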
ArraySpec
- class gunpowder.ArraySpec(roi=None, voxel_size=None, interpolatable=None, nonspatial=False, dtype=None, placeholder=False)[source]
Contains meta-information about an array. This is used by BatchProviders to communicate the arrays they offer, as well as by Arrays to describe the data they contain.
Attributes:
roi (Roi):
The region of interest represented by this array spec. Can be None for nonspatial arrays or to indicate the true value is unknown.
voxel_size (Coordinate):
The size of the spatial axes in world units. Can be None for nonspatial arrays or to indicate the true value is unknown.
interpolatable (bool):
Whether the values of this array can be interpolated.
nonspatial (bool, optional):
If set, this array does not represent spatial data (e.g., a list of labels for samples in a batch). roi and voxel_size have to be None. No consistency checks will be performed.
dtype (np.dtype):
The data type of the array.
GraphSpec
- class gunpowder.GraphSpec(roi=None, directed=None, dtype=<class 'numpy.float32'>, placeholder=False)[source]
Contains meta-information about a graph. This is used by BatchProviders to communicate the graphs they offer, as well as by Graphs to describe the data they contain.
Attributes:
roi (Roi):
The region of interest represented by this graph.
directed (bool, optional):
Whether the graph is directed or not.
dtype (dtype, optional):
The data type of the “location” attribute. Currently only supports np.float32.
Geometry
Coordinate
- class gunpowder.Coordinate(*array_like)[source]
A tuple of integers.

Allows the following element-wise operators: addition, subtraction, multiplication, division, absolute value, and negation. All operations are applied element-wise and support both Coordinates and Numbers. This allows simple arithmetic with coordinates, e.g.:

    shape = Coordinate(2, 3, 4)
    voxel_size = Coordinate(10, 5, 1)
    size = shape * voxel_size  # == Coordinate(20, 15, 4)
    size * 2 + 1               # == Coordinate(41, 31, 9)

Coordinates can be initialized with any iterable of ints, e.g.:

    Coordinate((1, 2, 3))
    Coordinate([1, 2, 3])
    Coordinate(np.array([1, 2, 3]))

Coordinates can also pack multiple args into an iterable, e.g.:

    Coordinate(1, 2, 3)
- is_multiple_of(coordinate: Coordinate) bool [source]
Test if this coordinate is a multiple of the given coordinate.
- round_division(other: Coordinate) Coordinate [source]
Will always round down if self % other == other / 2.
Node Base Classes
BatchProvider
- class gunpowder.BatchProvider[source]
Superclass for all nodes in a gunpowder graph.
A BatchProvider provides Batches containing Arrays and/or Graphs. The available data is specified in a ProviderSpec instance, accessible via spec.

To create a new node, subclass this class and implement (at least) setup() and provide().

A BatchProvider can be linked to any number of other BatchProviders upstream. If your node accepts exactly one upstream provider, consider subclassing BatchFilter instead.
- provide(request)[source]
To be implemented in subclasses.
This function takes a BatchRequest and should return the corresponding Batch.
Args:
request (BatchRequest):
The request to process.
- provides(key, spec)[source]
Introduce a new output provided by this BatchProvider.

Implementations should call this in their setup() method, which will be called when the pipeline is built.
Args:
- request_batch(request)[source]
Request a batch from this provider.
Args:
request (BatchRequest):
A request containing (possibly partial) ArraySpecs and GraphSpecs.
- setup()[source]
To be implemented in subclasses.
Called during initialization of the DAG. Callees can assume that all upstream providers are set up already.
In setup, call provides() to announce the arrays and points provided by this node.
- property spec
Get the ProviderSpec of this BatchProvider.

Note that the spec is only available after the pipeline has been built. Before that, it is None.
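As a usage illustration, here is a minimal sketch of a custom source node subclassing BatchProvider (the key, ROI, voxel size, and zero-filled data are illustrative assumptions):

    import numpy as np
    import gunpowder as gp

    class ZeroSource(gp.BatchProvider):

        def __init__(self, key):
            self.key = key

        def setup(self):
            # announce what this node offers
            self.provides(
                self.key,
                gp.ArraySpec(
                    roi=gp.Roi((0, 0, 0), (100, 100, 100)),
                    voxel_size=gp.Coordinate((1, 1, 1)),
                    interpolatable=True,
                    dtype=np.float32,
                ),
            )

        def provide(self, request):
            batch = gp.Batch()
            roi = request[self.key].roi
            spec = self.spec[self.key].copy()
            spec.roi = roi
            # fill the requested ROI with zeros
            shape = roi.shape // spec.voxel_size
            batch[self.key] = gp.Array(np.zeros(shape, dtype=np.float32), spec)
            return batch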
BatchFilter
- class gunpowder.BatchFilter[source]
Convenience wrapper for BatchProviders with exactly one input provider.

By default, a node of this class will expose the same ProviderSpec as the upstream provider. You can modify the provider spec by calling provides() and updates() in setup().

Subclasses need to implement at least process() to modify a passed batch (downstream). Optionally, the following methods can be implemented:

setup():
Initialize this filter. Called after setup of the DAG. All upstream providers will be set up already.

teardown():
Destruct this filter, free resources, stop worker processes.

prepare():
Prepare for a batch request. Always called before each process(). Used to communicate dependencies.
- enable_autoskip(skip=True)[source]
Enable automatic skipping of this BatchFilter, based on given updates() and provides() calls. Has to be called in setup().

By default, BatchFilters are not skipped automatically, regardless of what they update or provide. If autoskip is enabled, BatchFilters will only be run if the request contains at least one key reported earlier with updates() or provides().
- prepare(request)[source]
To be implemented in subclasses.
Prepare for a batch request. Should return a BatchRequest of needed dependencies. If None is returned, it will be assumed that all of the request is needed.
- process(batch, request)[source]
To be implemented in subclasses.
Filter a batch. Will be called after prepare(). Should return a Batch containing modified Arrays and Graphs. Keys in the returned batch will replace the associated data in the original batch. If None is returned, it is assumed that the batch has been modified in place. request is the same as passed to prepare(), provided for convenience.
Args:
batch (Batch):
The batch received from upstream, to be modified by this node.
request (BatchRequest):
The request this node received. The updated batch should meet this request.
- provides(key, spec)
Introduce a new output provided by this BatchProvider.

Implementations should call this in their setup() method, which will be called when the pipeline is built.
Args:
- request_batch(request)
Request a batch from this provider.
Args:
request (BatchRequest):
A request containing (possibly partial) ArraySpecs and GraphSpecs.
- setup()[source]
To be implemented in subclasses.
Called during initialization of the DAG. Callees can assume that all upstream providers are set up already.
In setup, call provides() or updates() to announce the arrays and points provided or changed by this node.
- property spec
Get the ProviderSpec of this BatchProvider.

Note that the spec is only available after the pipeline has been built. Before that, it is None.
- teardown()
To be implemented in subclasses.
Called during destruction of the DAG. Subclasses should use this to stop worker processes, if they used some.
- updates(key, spec)[source]
Update an output provided by this BatchFilter.

Implementations should call this in their setup() method, which will be called when the pipeline is built.
Args:
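As a usage illustration, a minimal sketch of a custom BatchFilter that inverts intensities in place (the key name and the assumed [0, 1] value range are illustrative):

    import gunpowder as gp

    class Invert(gp.BatchFilter):

        def __init__(self, array):
            self.array = array

        def process(self, batch, request):
            # modify the batch in place; returning None signals that the
            # batch has been modified in place
            batch[self.array].data = 1.0 - batch[self.array].data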
Source Nodes
ArraySource
- class gunpowder.ArraySource(key: ArrayKey, array: Array, interpolatable: bool | None = None)[source]
An array source.

Provides a source for any array that can fit into the funkelab funlib.persistence.Array format. This class comes with assumptions about the available metadata and convenient methods for indexing the data with a Roi in world units.
Args:
key (ArrayKey):
The ArrayKey for accessing this array.
array (Array):
A funlib.persistence.Array object.
interpolatable (bool, optional):
Whether the array is interpolatable. If not given, it is guessed based on dtype.
ZarrSource
- class gunpowder.ZarrSource(store: BaseStore | MutableMapping | str = None, datasets=None, array_specs=None, channels_first=True, filename=None)[source]
A zarr data source.
Provides arrays from zarr datasets. If the attribute resolution is set in a zarr dataset, it will be used as the array's voxel_size. If the attribute offset is set in a dataset, it will be used as the offset of the Roi for this array. It is assumed that the offset is given in world units.
Args:
store (string, zarr.BaseStore):
A zarr store or path to a zarr directory or zip file.
datasets (dict, ArrayKey -> string):
Dictionary of array keys to dataset names that this source offers.
array_specs (dict, ArrayKey -> ArraySpec, optional):
An optional dictionary of array keys to array specs to overwrite the array specs automatically determined from the data file. This is useful to set a missing voxel_size, for example. Only fields that are not None in the given ArraySpec will be used.
channels_first (bool, optional):
Specifies the ordering of the dimensions of the data source. If channels_first is set (default), the input shape is expected to be (channels, spatial dimensions). This is recommended because of better performance. If channels_first is set to false, the input data is read in a channels-last manner and converted to channels-first.
Hdf5Source
- class gunpowder.Hdf5Source(filename, datasets, array_specs=None, channels_first=True)[source]
An HDF5 data source.
Provides arrays from HDF5 datasets. If the attribute resolution is set in an HDF5 dataset, it will be used as the array's voxel_size. If the attribute offset is set in a dataset, it will be used as the offset of the Roi for this array. It is assumed that the offset is given in world units.
Args:
filename (string):
The HDF5 file.
datasets (dict, ArrayKey -> string):
Dictionary of array keys to dataset names that this source offers.
array_specs (dict, ArrayKey -> ArraySpec, optional):
An optional dictionary of array keys to array specs to overwrite the array specs automatically determined from the data file. This is useful to set a missing voxel_size, for example. Only fields that are not None in the given ArraySpec will be used.
channels_first (bool, optional):
Specifies the ordering of the dimensions of the HDF5-like data source. If channels_first is set (default), the input shape is expected to be (channels, spatial dimensions). This is recommended because of better performance. If channels_first is set to false, the input data is read in a channels-last manner and converted to channels-first.
KlbSource
- class gunpowder.KlbSource(filename, array, array_spec=None, num_threads=1)[source]
A KLB data source.
Provides a single array from the given KLB dataset.
Args:
filename (string):
The name of the KLB file. This string can be a glob expression (e.g., frame_*.klb), in which case all files that match are sorted and stacked together to form an additional dimension (like time). The additional dimension will start at 0 and have a default voxel size of 1 (which can be overwritten using the array_spec argument).
array (ArrayKey):
ArrayKey that this source offers.
array_spec (ArraySpec, optional):
An optional ArraySpec to overwrite the spec automatically determined from the file.
num_threads (int):
An optional integer to pass to the pyklb reader, indicating the number of threads to use when reading KLB files. Passing None uses the pyklb default, which is based on the number of cores in the machine. This default is bad for cluster jobs that are limited to the number of cores requested, so 1 is recommended.
DvidSource
- class gunpowder.DvidSource(hostname, port, uuid, datasets, masks=None, array_specs=None)[source]
A DVID array source.
Provides arrays from DVID servers for each array key given.
Args:
hostname (string):
The name of the DVID server.
port (int):
The port of the DVID server.
uuid (string):
The UUID of the DVID node to use.
datasets (dict, ArrayKey -> string):
Dictionary mapping array keys to DVID data instance names that this source offers.
masks (dict, ArrayKey -> string, optional):
Dictionary of array keys to DVID ROI instance names. This will create binary masks from DVID ROIs.
array_specs (dict, ArrayKey -> ArraySpec, optional):
An optional dictionary of array keys to specs to overwrite the array specs automatically determined from the DVID server. This is useful to set voxel_size, for example. Only fields that are not None in the given ArraySpec will be used.
CsvPointsSource
- class gunpowder.CsvPointsSource(filename: str, points: GraphKey, spatial_cols: list[int], points_spec: GraphSpec | None = None, scale: int | float | tuple | list | ndarray | None = None, id_col: int | None = None, delimiter: str = ',')[source]
Read a set of points from a comma-separated-values text file. Each line in the file represents one point, e.g. z y x (id). Note: this reads all points into memory and finds the ones in the given roi by iterating over all the points. For large datasets, this may be too slow.
Args:
filename (string):
The file to read from.
points (GraphKey):
The key of the points set to create.
spatial_cols (list[int]):
The columns of the CSV that hold the coordinates of the points (in the order in which you want them to be used in training).
points_spec (GraphSpec, optional):
An optional GraphSpec to overwrite the points specs automatically determined from the CSV file. This is useful to set the Roi manually.
scale (scalar or array-like):
An optional scaling to apply to the coordinates of the points read from the CSV file. This is useful if the points refer to voxel positions, to convert them to world units.
id_col (int, optional):
The column of the CSV that holds an id for each point. If not provided, the indices of the rows are used as the ids. When read from file, ids are left as strings and not cast to anything.
delimiter (str, optional):
Delimiter to pass to the CSV reader. Defaults to “,”.
GraphSource
- class gunpowder.GraphSource(graph_provider, graph, graph_spec=None)[source]
Creates a gunpowder graph source from a daisy graph provider. Queries for graphs from a given Roi will only return edges completely contained within the Roi - edges that cross the boundary will not be included.
Arguments:
graph_provider (daisy.SharedGraphProvider):
A daisy graph provider to read the graph from. Can be backed by MongoDB or any other implemented backend.
graph (GraphKey):
The key of the graph to create.
graph_spec (GraphSpec, optional):
An optional GraphSpec containing a roi and optionally whether the graph is directed. The default is to have an unbounded roi and to detect directedness from the graph_provider.
Augmentation Nodes
DefectAugment
- class gunpowder.DefectAugment(intensities, prob_missing=0.05, prob_low_contrast=0.05, prob_artifact=0.0, prob_deform=0.0, contrast_scale=0.1, artifact_source=None, artifacts=None, artifacts_mask=None, deformation_strength=20, axis=0, p=1.0)[source]
Augment intensity arrays section-wise with artifacts like missing sections or low-contrast sections, by blending in artifacts drawn from a separate source, or by deforming a section.
Args:
intensities (ArrayKey):
The key of the array of intensities to modify.
prob_missing (float), prob_low_contrast (float), prob_artifact (float), prob_deform (float):
Probabilities of having a missing section, a low-contrast section, an artifact (see parameter artifact_source), or a deformed slice. The sum should not exceed 1. Values in missing sections will be set to 0.
contrast_scale (float, optional):
By how much to scale the intensities for a low-contrast section, used if prob_low_contrast > 0.
artifact_source (BatchProvider, optional):
The source to query for artifacts.
artifacts (ArrayKey, optional):
The key to query artifact_source for to get the intensities of the artifacts.
artifacts_mask (ArrayKey, optional):
The key to query artifact_source for to get the alpha mask of the artifacts to blend them with intensities.
deformation_strength (int, optional):
Strength of the slice deformation in voxels, used if prob_deform > 0. The deformation models a fold by shifting the section contents towards a randomly oriented line in the section. The line itself will be drawn with a value of 0.
axis (int, optional):
Along which axis sections are cut.
p (float, optional):
Probability of applying the augmentation. Default is 1.0 (always apply). Should be a float value between 0 and 1. Lowering this value can be useful for computational efficiency and for increasing augmentation space.
DeformAugment
- class gunpowder.DeformAugment(control_point_spacing: Coordinate, jitter_sigma: Coordinate, scale_interval=(1.0, 1.0), rotate: bool = True, subsample=1, spatial_dims=3, use_fast_points_transform=False, recompute_missing_points=True, transform_key: ArrayKey | None = None, graph_raster_voxel_size: Coordinate | None = None, p: float = 1.0)[source]
Elastically deform a batch. Requests larger batches upstream to avoid data loss due to rotation and jitter.
Args:
control_point_spacing (tuple of int):
Distance between control points for the elastic deformation, in physical units per dimension.
jitter_sigma (tuple of float):
Standard deviation of the control point jitter distribution, in physical units per dimension.
scale_interval (tuple of two floats):
Interval to randomly sample scale factors from.
subsample (int):
Instead of creating an elastic transformation at the full resolution, create one subsampled by the given factor, and linearly interpolate to obtain the full-resolution transformation. This can significantly speed up this node, at the expense of having visible piecewise linear deformations for large factors. Usually, a factor of 4 can safely be used without noticeable changes. However, the default is 1 (i.e., no subsampling).
spatial_dims (int):
The number of spatial dimensions in arrays. Spatial dimensions are assumed to be the last ones and cannot be more than 3 (default). Set this value here to avoid treating channels as spatial dimensions. If, for example, your array is indexed as (c,y,x) (2D plus channels), you would want to set spatial_dims=2 to perform the elastic deformation only on x and y.
use_fast_points_transform (bool):
By solving for all of your points simultaneously with the following 3-step procedure: 1) rasterize nodes into a numpy array, 2) apply the elastic transform to the array, 3) read out nodes via the center of mass of the transformed points, you can gain a substantial speed-up as opposed to calculating the elastic transform for each point individually. However, this may lead to nodes being lost during the transform.
recompute_missing_points (bool):
Whether or not to compute the elastic transform node-wise for nodes that were lost during the fast elastic transform process.
p (float, optional):
Probability of applying the augmentation. Default is 1.0 (always apply). Should be a float value between 0 and 1. Lowering this value can be useful for computational efficiency and for increasing augmentation space.
IntensityAugment
- class gunpowder.IntensityAugment(array, scale_min, scale_max, shift_min, shift_max, z_section_wise=False, clip=True, p=1.0)[source]
Randomly scale and shift the values of an intensity array.
Args:
array (ArrayKey):
The intensity array to modify.
scale_min (float), scale_max (float), shift_min (float), shift_max (float):
The min and max of the uniformly randomly drawn scaling and shifting values for the intensity augmentation. Intensities are changed as:

    a = a.mean() + (a - a.mean())*scale + shift

z_section_wise (bool):
Perform the augmentation z-section-wise. Requires 3D arrays and assumes that z is the first dimension.
clip (bool):
Set to False if modified values should not be clipped to [0, 1]. Disables the range check!
p (float, optional):
Probability of applying the augmentation. Default is 1.0 (always apply). Should be a float value between 0 and 1. Lowering this value can be useful for computational efficiency and for increasing augmentation space.
NoiseAugment
- class gunpowder.NoiseAugment(array, mode='gaussian', clip=True, p=1.0, **kwargs)[source]
Add random noise to an array. Uses the scikit-image function skimage.util.random_noise. See scikit-image documentation for more information on arguments and additional kwargs.
Args:
array (ArrayKey):
The intensity array to modify. Should be of type float and within range [-1, 1] or [0, 1].
mode (string):
Type of noise to add, see the scikit-image documentation.
clip (bool):
Whether to preserve the image range (either [-1, 1] or [0, 1]) by clipping values in the end, see the scikit-image documentation.
p (float, optional):
Probability of applying the augmentation. Default is 1.0 (always apply). Should be a float value between 0 and 1. Lowering this value can be useful for computational efficiency and for increasing augmentation space.
SimpleAugment
- class gunpowder.SimpleAugment(mirror_only=None, transpose_only=None, mirror_probs=None, transpose_probs=None, p=1.0)[source]
Randomly mirror and transpose all Arrays and Graphs in a batch.
Args:
mirror_only (list of int, optional):
If set, only mirror between the given axes. This is useful to exclude channels that have a set direction, like time.
transpose_only (list of int, optional):
If set, only transpose between the given axes. This is useful to limit the transpose to axes with the same resolution or to exclude non-spatial dimensions.
mirror_probs (list of float, optional):
If set, provides the probability for mirroring given axes. Default is 0.5 per axis. If given, must be given for every axis, i.e., [0, 1, 0] for a 100% chance of mirroring axis 1 and no others.
transpose_probs (dict of tuple -> float or list of float, optional):
The probability of transposing. If None, each transpose is equally likely. Can also be a dictionary of the form tuple -> float, for example {(0, 1, 2): 0.5, (1, 0, 2): 0.5} to define a 50% chance of transposing axes 0 and 1. Note that if a provided option violates the transpose_only arg, it will be dropped and the remaining options will be reweighted. Can also be provided as a list of float, i.e., [0.3, 0.5, 0.7]. This will automatically generate a list of possible permutations and attempt to weight them appropriately. A weight of 0 means this axis will never be transposed; a weight of 1 means this axis will always be transposed.
p (float, optional):
Probability of applying the augmentation. Default is 1.0 (always apply). Should be a float value between 0 and 1. Lowering this value can be useful for computational efficiency and for increasing augmentation space.
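As a usage illustration, a minimal sketch of chaining augmentation nodes into a pipeline (the container path, dataset name, key, and parameter values are illustrative assumptions):

    import gunpowder as gp

    raw = gp.ArrayKey("RAW")

    pipeline = (
        gp.ZarrSource("data.zarr", datasets={raw: "volumes/raw"})
        + gp.Normalize(raw)
        + gp.SimpleAugment()
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1)
    )

    request = gp.BatchRequest()
    request.add(raw, gp.Coordinate((64, 64, 64)))

    with gp.build(pipeline):
        batch = pipeline.request_batch(request)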
Location Manipulation Nodes
Crop
- class gunpowder.Crop(key, roi=None, fraction_negative=None, fraction_positive=None)[source]
Limits provided ROIs by either giving a new Roi or by cropping fractions from either face of the provided ROI.
Args:
key (ArrayKey or GraphKey):
The key of the array or points set to modify.
roi (Roi or None):
The ROI to crop to.
fraction_negative (tuple of float):
Relative crop starting from the negative end of the provided ROI.
fraction_positive (tuple of float):
Relative crop starting from the positive end of the provided ROI.
Pad
- class gunpowder.Pad(key, size, mode='constant', value=None)[source]
Add a constant intensity padding around arrays of another batch provider. This is useful if your requested batches can be larger than what your source provides.
Args:
key (ArrayKey or GraphKey):
The array or points set to pad.
size (Coordinate or None):
The padding to be added. If None, an infinite padding is added. If a coordinate, this amount will be added to the ROI in the positive and negative directions.
mode (string):
One of ‘constant’ or ‘reflect’. Default is ‘constant’.
value (scalar or None):
The value to report inside the padding. If not given, 0 is used. Only used in ‘constant’ mode, and only for Arrays.
RandomLocation
- class gunpowder.RandomLocation(min_masked=0, mask=None, ensure_nonempty=None, p_nonempty=1.0, ensure_centered=None, point_balance_radius=1, random_shift_key=None)[source]
Chooses a batch at a random location in the bounding box of the upstream provider.

The random location is chosen such that the batch request ROI lies entirely inside the provider's ROI.

If min_masked and mask are set, only batches are returned that have at least the given ratio of masked-in voxels. This is in general faster than using the Reject node, at the expense of storing an integral array of the complete mask.

If ensure_nonempty is set to a GraphKey, only batches are returned that have at least one point of this point collection within the requested ROI.

Additional tests for randomly picked locations can be implemented by subclassing and overwriting accepts(). This method takes the randomly shifted request that meets all previous criteria (like min_masked and ensure_nonempty) and should return True if the request is acceptable.
Args:
min_masked (float, optional):
If non-zero, require that the random sample contains at least that ratio of masked-in voxels.
mask (ArrayKey, optional):
The array to use for mask checks.
ensure_nonempty (GraphKey, optional):
Ensures that when finding a random location, a request for ensure_nonempty will contain at least one point.
p_nonempty (float, optional):
If ensure_nonempty is set, it defines the probability that a request for ensure_nonempty will contain at least one point. Default value is 1.0.
ensure_centered (bool, optional):
If ensure_nonempty is set, ensure_centered guarantees that the center voxel of the ROI contains a point.
point_balance_radius (int):
If ensure_nonempty is set, point_balance_radius defines a radius such that for every point p in ensure_nonempty, the probability of picking p is inversely related to the number of other points within a distance of point_balance_radius to p. This helps avoid oversampling of dense regions of the graph, and undersampling of sparse regions.
random_shift_key (ArrayKey, optional):
If random_shift_key is not None, this node will populate that key with a nonspatial array containing the random shift used for each request. This can be useful for snapshot iterations if you want to figure out where that snapshot came from.
Reject
- class gunpowder.Reject(mask=None, min_masked=0.5, ensure_nonempty=None, reject_probability=1.0)[source]
Reject batches based on the masked-in vs. masked-out ratio.
If a pipeline also contains a RandomLocation node, Reject needs to be placed downstream of it.
Args:
mask (ArrayKey, optional):
The mask to use, if any.
min_masked (float, optional):
The minimal required ratio of masked-in vs. masked-out voxels. Defaults to 0.5.
ensure_nonempty (GraphKey, optional):
Ensures there is at least one point in the batch.
reject_probability (float, optional):
The probability by which a batch that is not valid (less than min_masked) is actually rejected. Defaults to 1., i.e., strict rejection.
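As a usage illustration, a minimal sketch of rejecting poorly masked batches downstream of RandomLocation (the container path, dataset names, and keys are illustrative assumptions):

    import gunpowder as gp

    raw = gp.ArrayKey("RAW")
    mask = gp.ArrayKey("MASK")

    pipeline = (
        gp.ZarrSource(
            "data.zarr",
            datasets={raw: "volumes/raw", mask: "volumes/mask"},
        )
        + gp.RandomLocation()
        + gp.Reject(mask=mask, min_masked=0.5)  # downstream of RandomLocation
    )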
SpecifiedLocation
- class gunpowder.SpecifiedLocation(locations, choose_randomly=False, extra_data=None, jitter=None, attempt_factor: int = 5)[source]
Chooses a batch at a location from the list provided at init, making sure it is in the bounding box of the upstream provider.

Locations should be given in world units.

Locations will be chosen in order or at random from the list, depending on the choose_randomly parameter.

If a location requires a shift outside the bounding box of any upstream provider, the module will skip that location with a warning.
Args:
locations (list of locations):
Locations to center batches around.
choose_randomly (bool):
Defines whether locations should be picked in order or at random from the list.
extra_data (list of array-like):
A list of data that will be passed along with the arrays provided by this node. This data will be appended as an attribute to the dataset, so it must be a data format compatible with HDF5.
jitter (tuple of int):
How far to allow the point to shift in each direction. Default is None, which places the point in the center. Chooses uniformly from [loc - jitter, loc + jitter] in each direction.
attempt_factor (int):
If choosing randomly, then given n points, sample attempt_factor * n points at most before giving up and throwing an error.
IterateLocations
- class gunpowder.IterateLocations(graph, roi=None, node_id=None, choose_randomly=False)[source]
Iterates over the nodes in a graph and centers batches at their locations. The iteration is thread safe.
Args:
graph (GraphKey):
Key of the graph to read nodes from.
roi (Roi, optional):
Roi within which to read and iterate over nodes. Defaults to None, which queries the whole Roi of the upstream graph source.
node_id (ArrayKey, optional):
Nonspatial array key in which to store the id of the “current” node in graph. Default is None, in which case no attribute is stored and there is no way to tell which node is being considered.
choose_randomly (bool):
If true, choose nodes randomly with replacement. Default is false, which loops over the list.
Array Manipulation Nodes
Squeeze
Unsqueeze
Image Processing Nodes
DownSample
UpSample
IntensityScaleShift
Normalize
- class gunpowder.Normalize(array, factor=None, dtype=<class 'numpy.float32'>)[source]
Normalize the values of an array to be floats between 0 and 1, based on the type of the array.
Args:
array (ArrayKey):
The key of the array to modify.
factor (scalar, optional):
The factor to use. If not given, a factor is chosen based on the dtype of the array (e.g., np.uint8 would result in a factor of 1.0/255).
dtype (data-type, optional):
The datatype of the normalized array. Defaults to np.float32.
Label Manipulation Nodes
AddAffinities
- class gunpowder.AddAffinities(affinity_neighborhood, labels, affinities, labels_mask=None, unlabelled=None, affinities_mask=None, dtype=<class 'numpy.uint8'>)[source]
Add an array with affinities for a given label array and neighborhood to the batch. Affinity values are created one for each voxel and entry in the neighborhood list, i.e., for each voxel and each neighbor of this voxel. Values are 1 iff both labels (of the voxel and the neighbor) are equal and non-zero.
Args:
affinity_neighborhood (list of array-like):
List of offsets for the affinities to consider for each voxel.
labels (ArrayKey):
The array to read the labels from.
affinities (ArrayKey):
The array to generate containing the affinities.
labels_mask (ArrayKey, optional):
The array to use as a mask for labels. Affinities connecting at least one masked-out label will be masked out in affinities_mask. If not given, affinities_mask will contain ones everywhere (if requested).
unlabelled (ArrayKey, optional):
A binary array to indicate unlabelled areas with 0. Affinities from labelled to unlabelled voxels are set to 0, affinities between unlabelled voxels are masked out (they will not be used for training).
affinities_mask (ArrayKey, optional):
The array to generate containing the affinity mask, as derived from parameter labels_mask.
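As a usage illustration, a minimal sketch of adding 3D nearest-neighbor affinities to a pipeline (the keys and neighborhood are illustrative assumptions):

    import gunpowder as gp

    labels = gp.ArrayKey("GT_LABELS")
    affs = gp.ArrayKey("GT_AFFINITIES")

    add_affs = gp.AddAffinities(
        affinity_neighborhood=[(-1, 0, 0), (0, -1, 0), (0, 0, -1)],
        labels=labels,
        affinities=affs,
    )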
BalanceLabels
- class gunpowder.BalanceLabels(labels, scales, mask=None, slab=None, num_classes=2, clipmin=0.05, clipmax=0.95)[source]
Creates a scale array to balance the loss between class labels.
Note that this only balances loss weights per-batch and does not accumulate statistics about class balance across batches.
Args:
labels (ArrayKey):
An array containing binary or integer labels.
scales (ArrayKey):
An array with scales to be created. This new array will have the same ROI and resolution as labels.
mask (ArrayKey, optional):
An optional mask (or list of masks) to consider for balancing. Every voxel marked with a 0 will not contribute to the scaling and will have a scale of 0 in scales.
slab (tuple of int, optional):
A shape specification to perform the balancing in slabs of this size. -1 can be used to refer to the actual size of the label array. For example, a slab of:

    (2, -1, -1, -1)

will perform the balancing for each slice [0:2,:], [2:4,:], ... individually.
num_classes (int, optional):
The number of classes. Labels will be expected to be in the interval [0, num_classes). Defaults to 2 for binary classification.
clipmin (float, optional):
Clip class fraction to clipmin when calculating class weights. Defaults to 0.05. Set to None if you do not want to clip min values.
clipmax (float, optional):
Clip class fraction to clipmax when calculating class weights. Defaults to 0.95. Set to None if you do not want to clip max values.
ExcludeLabels
- class gunpowder.ExcludeLabels(labels, exclude, ignore_mask=None, ignore_mask_erode=0, background_value=0)[source]
Excludes several labels from the ground-truth.
The labels will be replaced by background_value. An optional ignore mask will be created and set to 0 for the excluded locations that are further than a threshold away from not excluded locations.
Args:
labels (ArrayKey):
The array containing the labels.
exclude (list of int):
The labels to exclude from labels.
ignore_mask (ArrayKey, optional):
The ignore mask to create.
ignore_mask_erode (float, optional):
By how much (in world units) to erode the ignore mask.
background_value (int, optional):
Value to replace excluded IDs with, defaults to 0.
GrowBoundary
- class gunpowder.GrowBoundary(labels, mask=None, steps=1, background=0, only_xy=False)[source]
Grow a boundary between regions in a label array. Does not grow at the border of the batch or an optionally provided mask.
Args:
labels (ArrayKey):
The array containing labels.
mask (ArrayKey, optional):
A mask indicating unknown regions. This is to avoid boundaries growing between labelled and unknown regions.
steps (int, optional):
Number of voxels (not world units!) to grow.
background (int, optional):
The label to assign to the boundary voxels.
only_xy (bool, optional):
Do not grow a boundary in the z direction.
RenumberConnectedComponents
Graph Processing Nodes
RasterizeGraph
- class gunpowder.RasterizeGraph(graph, array, array_spec=None, settings=None)[source]
Draw graphs into a binary array as balls/tubes of a given radius.
Args:
graph (GraphKey):
The key of the graph to rasterize.
array (ArrayKey):
The key of the binary array to create.
array_spec (ArraySpec, optional):
The spec of the array to create. Use this to set the datatype and voxel size.
settings (RasterizationSettings, optional):
Which settings to use to rasterize the graph.
- class gunpowder.RasterizationSettings(radius, mode='ball', mask=None, inner_radius_fraction=None, fg_value=1, bg_value=0, edges=True, color_attr=None)[source]
Data structure to store parameters for rasterization of graph.
Args:
radius (float or tuple of float):
The radius (for balls or tubes) or sigma (for peaks) in world units.
mode (string):
One of ball or peak. If ball (the default), a ball with the given radius will be drawn. If peak, the point will be rasterized as a peak with values \(\exp(-|x-p|^2/\sigma)\) with sigma set by radius.
mask (ArrayKey, optional):
Used to mask the rasterization of points. The array is assumed to contain discrete labels. The object id at the specific point being rasterized is used to intersect the rasterization to keep it inside the specific object.
inner_radius_fraction (float, optional):
Only for mode ball. If set, instead of a ball, a hollow sphere is rasterized. The radius of the whole sphere corresponds to the radius specified with radius. This parameter sets the radius of the hollow area, as a fraction of radius.
fg_value (int, optional):
Only for mode ball. The value to use to rasterize points, defaults to 1.
bg_value (int, optional):
Only for mode ball. The value to use for the background in the output array, defaults to 0.
edges (bool, optional):
Whether to rasterize edges by linearly interpolating between Nodes. Default is True.
color_attr (str, optional):
Which graph attribute to use for coloring nodes and edges. One useful example might be component, which would color your graph based on the component labels. Notes: only available in “ball” mode; nodes and edges missing the attribute will be skipped; color_attr must be populated for nodes and edges upstream of this node.
Provider Combination Nodes
MergeProvider
RandomProvider
- class gunpowder.RandomProvider(probabilities=None, random_provider_key=None)[source]
Randomly selects one of the upstream providers:

    (a, b, c) + RandomProvider()

will create a provider that randomly relays requests to providers a, b, or c. Array and point keys of a, b, and c should be the same.
Args:
probabilities (1-D array-like, optional):
An optional list of probabilities for choosing upstream providers, given in the same order. Probabilities do not need to be normalized. Default is None, corresponding to equal probabilities.
random_provider_key (ArrayKey):
If provided, this node will store the index of the chosen random provider in a nonspatial array.
Training and Prediction Nodes
Stack
- class gunpowder.Stack(num_repetitions)[source]
Request several batches and stack them together, introducing a new dimension for each array. This is useful to create batches with several samples and only makes sense if there is a source of randomness upstream.
This node stacks only arrays, not points. The resulting batch will have the same point sets as found in the first batch requested upstream.
Args:
num_repetitions (int):
How many upstream batches to stack.
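As a usage illustration, a minimal sketch of stacking five random samples into one batch (the source and key are illustrative assumptions, reusing the pattern from the examples above):

    import gunpowder as gp

    raw = gp.ArrayKey("RAW")

    pipeline = (
        gp.ZarrSource("data.zarr", datasets={raw: "volumes/raw"})
        + gp.RandomLocation()
        + gp.Stack(5)  # adds a leading sample dimension to each array
    )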
torch.Train
- class gunpowder.torch.Train(model, loss, optimizer, inputs: Dict[str | int, ArrayKey], outputs: Dict[str | int, ArrayKey], loss_inputs: Dict[str | int, ArrayKey], gradients: Dict[str | int, ArrayKey] = {}, array_specs: Dict[ArrayKey, ArraySpec] | None = None, checkpoint_basename: str = 'model', save_every: int = 2000, log_dir: str | None = None, log_every: int = 1, spawn_subprocess: bool = False, device: str = 'cuda')[source]
Torch implementation of gunpowder.nodes.GenericTrain.
Args:
model (subclass of torch.nn.Module):
The model to train.
loss:
The torch loss to use.
optimizer:
The torch optimizer to use.
inputs (dict, string or int -> ArrayKey):
Dictionary from the names of input tensors (argument names of the forward method) in the model to array keys.
loss_inputs (dict, string or int -> ArrayKey):
Dictionary with the names of input variables to the loss function as keys, and ArrayKeys containing the desired data as values. Keys can be either strings or integers. If the key is an integer, it will be treated as a positional argument to the loss function; a string will be used as a named argument.
outputs (dict, string or int -> ArrayKey):
Dictionary from the names of tensors in the network to array keys. If the key is a string, the tensor will be retrieved by checking the model for an attribute with the key as its name. If the key is an integer, it is interpreted as a tuple index of the outputs of the network. New arrays will be generated by this node for each entry (if requested downstream).
gradients (dict, string or int -> ArrayKey, optional):
Dictionary from the names of tensors in the network to array keys. If the key is a string, the tensor will be retrieved by checking the model for an attribute with the key as its name. If the key is an integer, it is interpreted as a tuple index of the outputs of the network. Instead of the actual array, the gradient of the array with respect to the loss will be generated and saved.
array_specs (dict, ArrayKey -> ArraySpec, optional):
Used to set the specs of generated arrays (at the moment only output). This is useful to set the voxel_size, for example, if it differs from the voxel size of the input arrays. Only fields that are not None in the given ArraySpec will be used.
checkpoint_basename (string, optional):
The basename used for checkpoint files. Defaults to model.
save_every (int, optional):
After how many iterations to create a checkpoint to store the learnt weights.
log_dir (string, optional):
Directory for saving tensorboard summaries.
log_every (int, optional):
After how many iterations to write out tensorboard summaries.
spawn_subprocess (bool, optional):
Whether to run train_step in a separate process. Default is false.
device (str, optional):
Accepts a specific CUDA GPU to train on (e.g., cuda:1, cuda:2), which helps in multi-card systems. Defaults to cuda.
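As a usage illustration, a minimal sketch of wiring a torch model into a training node (the model, keys, loss, and optimizer are illustrative assumptions; "input" is the argument name of Conv3d.forward):

    import torch
    import gunpowder as gp
    from gunpowder.torch import Train

    raw = gp.ArrayKey("RAW")
    target = gp.ArrayKey("TARGET")
    prediction = gp.ArrayKey("PREDICTION")

    model = torch.nn.Conv3d(1, 1, kernel_size=3, padding=1)
    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters())

    train = Train(
        model,
        loss,
        optimizer,
        inputs={"input": raw},                   # forward() argument name
        outputs={0: prediction},                 # tuple index of the output
        loss_inputs={0: prediction, 1: target},  # positional args to loss
        save_every=1000,
    )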
torch.Predict
- class gunpowder.torch.Predict(model, inputs: Dict[str | int, ArrayKey], outputs: Dict[str | int, ArrayKey], array_specs: Dict[ArrayKey, ArraySpec] | None = None, checkpoint: str | None = None, device='cuda', spawn_subprocess=False)[source]
Torch implementation of gunpowder.nodes.Predict.
Args:
model (subclass of torch.nn.Module):
The model to use for prediction.
inputs (dict, string or int -> ArrayKey):
Dictionary from the positions (for args) and names (for kwargs) of input tensors (argument names of the forward method) in the model to array keys.
outputs (dict, string or int -> ArrayKey):
Dictionary from the names of tensors in the network to array keys. If the key is a string, the tensor will be retrieved by checking the model for an attribute with the key as its name. If the key is an integer, it is interpreted as a tuple index of the outputs of the network. New arrays will be generated by this node for each entry (if requested downstream).
array_specs (dict, ArrayKey -> ArraySpec, optional):
Used to set the specs of generated arrays (outputs). This is useful to set the voxel_size, for example, if it differs from the voxel size of the input arrays. Only fields that are not None in the given ArraySpec will be used.
checkpoint (string, optional):
An optional path to the saved parameters for your torch module. These will be loaded and used for prediction if provided.
device (string, optional):
Which device to use for prediction ("cpu" or "cuda"). Default is "cuda", which falls back to CPU if CUDA is not available.
spawn_subprocess (bool, optional):
Whether to run predict in a separate process. Default is false.
tensorflow.Train
- class gunpowder.tensorflow.Train(graph, optimizer, loss, inputs, outputs, gradients, summary=None, array_specs=None, save_every=2000, log_dir='./', log_every=1)[source]
Tensorflow implementation of gunpowder.nodes.Train.
Args:
graph (string):
Filename of a tensorflow meta-graph storing the tensorflow graph containing an optimizer. A meta-graph file can be created by running:

    # create tensorflow graph
    ...

    # store it
    tf.train.export_meta_graph(filename='...')

optimizer (string or function):
Either the name of the tensorflow operator performing a training iteration, or a function that, given the graph of the meta-graph file, adds a custom loss and optimizer.

If a function is given, it should return a tuple (loss, optimizer) of a tensor and an operator representing the loss and the optimizer, respectively. In this case, parameter loss should be None.

Example:

    def add_custom_optimizer(graph):

        # get the output of your graph
        output = graph.get_tensor_by_name('...')

        # create your custom loss
        loss = custom_loss(output)

        # add an optimizer of your choice
        optimizer = tf.train.AdamOptimizer().minimize(loss)

        return (loss, optimizer)

loss (string or None):
The name of the tensorflow tensor containing the loss, or None if optimizer is a function.
inputs (dict, string -> ArrayKey):
Dictionary from the names of input tensors in the network to array keys.
outputs (dict, string -> ArrayKey):
Dictionary from the names of output tensors in the network to array keys. New arrays will be generated by this node for each entry (if requested downstream).
gradients (dict, string -> ArrayKey):
Dictionary from the names of output tensors in the network to array keys. New arrays containing the gradient of an output with respect to the loss will be generated by this node for each entry (if requested downstream).
summary (string or dict, string -> (string (tensor name), freq), optional):
The name of the tensorflow tensor containing the tensorboard summaries, or a dictionary for different subcategories of summaries (key: string, value: tuple with tensor/op name and frequency of evaluation).
array_specs (dict, ArrayKey -> ArraySpec, optional):
Used to set the specs of generated arrays (outputs). This is useful to set the voxel_size, for example, if it differs from the voxel size of the input arrays. Only fields that are not None in the given ArraySpec will be used.
save_every (int, optional):
After how many iterations to create a checkpoint to store the learnt weights.
log_dir (string, optional):
Directory for saving tensorboard summaries.
log_every (int, optional):
After how many iterations to write out tensorboard summaries.
tensorflow.Predict
- class gunpowder.tensorflow.Predict(checkpoint, inputs, outputs, array_specs=None, graph=None, skip_empty=False, max_shared_memory=1073741824)[source]
Tensorflow implementation of gunpowder.nodes.Predict.
Args:
checkpoint (string):
Basename of a tensorflow checkpoint storing the tensorflow graph and associated tensor values and metadata, as created by gunpowder.nodes.Train, for example.
inputs (dict, string -> ArrayKey):
Dictionary from the names of input tensors in the network to array keys.
outputs (dict, string -> ArrayKey):
Dictionary from the names of output tensors in the network to array keys. New arrays will be generated by this node for each entry (if requested downstream).
array_specs (dict, ArrayKey -> ArraySpec, optional):
Used to set the specs of generated arrays (outputs). This is useful to set the voxel_size, for example, if it differs from the voxel size of the input arrays. Only fields that are not None in the given ArraySpec will be used.
graph (string, optional):
An optional path to a tensorflow computation graph that should be used for prediction. The checkpoint is used to restore the values of matching variable names in the graph. Note that the graph specified here can differ from the one associated with the checkpoint.
skip_empty (bool, optional):
Skip prediction if all inputs are empty (contain only 0). In this case, outputs are simply set to 0.
max_shared_memory (int, optional):
The maximal amount of shared memory in bytes to allocate to send batches to the GPU processes. Defaults to 1GB.
jax.Train
- class gunpowder.jax.Train(model: GenericJaxModel, inputs: Dict[str, ndarray | ArrayKey], outputs: Dict[str | int, ArrayKey], gradients: Dict[str | int, ArrayKey] = {}, array_specs: Dict[ArrayKey, ArraySpec] | None = None, checkpoint_basename: str = 'model', save_every: int = 2000, keep_n_checkpoints: int | None = None, log_dir: str | None = None, log_every: int = 1, spawn_subprocess: bool = False, n_devices: int | None = None, validate_fn=None, validate_every=None)[source]
JAX implementation of gunpowder.nodes.GenericTrain.
Args:
model (subclass of gunpowder.jax.GenericJaxModel):
The model to train. This model encapsulates the forward model, loss, and optimizer.
inputs (dict, string -> Union[np.ndarray, ArrayKey]):
Dictionary from the names of input tensors expected by the train_step method to array keys or ndarrays.
outputs (dict, string -> ArrayKey):
Dictionary from the names of tensors in the network to array keys. If the key is a string, the tensor will be retrieved by checking the model for an attribute with the key as its name. If the key is an integer, it is interpreted as a tuple index of the outputs of the network. New arrays will be generated by this node for each entry (if requested downstream).
array_specs (dict, ArrayKey -> ArraySpec, optional):
Used to set the specs of generated arrays (at the moment only output). This is useful to set the voxel_size, for example, if it differs from the voxel size of the input arrays. Only fields that are not None in the given ArraySpec will be used.
checkpoint_basename (string, optional):
The basename used for checkpoint files. Defaults to model.
save_every (int, optional):
After how many iterations to create a checkpoint to store the learnt weights.
keep_n_checkpoints (int, optional):
Number of checkpoints to keep. The node will attempt to delete older checkpoints. Default is None (no deletion).
log_dir (string, optional):
Directory for saving tensorboard summaries.
log_every (int, optional):
After how many iterations to write out tensorboard summaries.
spawn_subprocess (bool, optional):
Whether to run train_step in a separate process. Default is false.
n_devices (int, optional):
Number of GPU devices to train on concurrently using jax.pmap. If None, the number of available GPUs will be automatically detected and used.
validate_fn (function -> Union[float, (dict, string -> float)], optional):
Function to run validation on, which should have the form

    def validate_fn(model, params)

where model is the same provided GenericJaxModel model and params are the parameters of this model, and which returns either a float (one loss) or a dictionary of losses to record in tensorboard.
validate_every (int, optional):
After how many iterations to run validate_fn.
jax.Predict
- class gunpowder.jax.Predict(model: GenericJaxModel, inputs: Dict[str, ArrayKey], outputs: Dict[str | int, ArrayKey], array_specs: Dict[ArrayKey, ArraySpec] | None = None, checkpoint: str | None = None, spawn_subprocess=False)[source]
JAX implementation of
gunpowder.nodes.Predict
.Args:
model (subclass of
gunpowder.jax.GenericJaxModel
):The model to use for prediction.
inputs (
dict
,string
->ArrayKey
):Dictionary from the names of input tensors in the network to array keys.
outputs (
dict
,string
->ArrayKey
):Dictionary from the names of output tensors in the network to array keys. New arrays will be generated by this node for each entry (if requested downstream).
array_specs (dict, ArrayKey -> ArraySpec, optional):

Used to set the specs of generated arrays (outputs). This is useful to set the voxel_size, for example, if it differs from the voxel size of the input arrays. Only fields that are not None in the given ArraySpec will be used.

checkpoint (string, optional):

An optional path to the saved parameters of your jax module. These will be loaded and used for prediction if provided.

spawn_subprocess (bool, optional):

Whether to run predict in a separate process. Default is False.
Output Nodes
Hdf5Write
- class gunpowder.Hdf5Write(dataset_names, output_dir='.', output_filename='output.hdf', compression_type=None, dataset_dtypes=None)[source]
Assemble arrays of passing batches in one HDF5 file. This is useful to store chunks produced by Scan on disk without keeping the larger array in memory. The ROIs of the passing arrays will be used to determine the position where to store the data in the dataset.

Args:
dataset_names (dict, ArrayKey -> string):

A dictionary from array keys to names of the datasets to store them in.

output_dir (string):

The directory to save the HDF5 file in. Will be created, if it does not exist.

output_filename (string):

The output filename of the container. Will be created, if it does not exist, otherwise data is overwritten in the existing container.

compression_type (string or int):

Compression strategy. Legal values are gzip, szip, lzf. If an integer between 0 and 9, this indicates the gzip compression level.

dataset_dtypes (dict, ArrayKey -> data type):

A dictionary from array keys to datatype (e.g. np.int8). If given, arrays are stored using this type. The original arrays within the pipeline remain unchanged.
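A sketch of the typical chunk-wise export pattern (the input container data.zarr is hypothetical): the write node sits upstream of Scan, so every chunk batch passes through it and is assembled into one file.

    import gunpowder as gp

    raw = gp.ArrayKey("RAW")

    source = gp.ZarrSource("data.zarr", {raw: "volumes/raw"})  # hypothetical input

    # reference request for one chunk, sizes in world units
    chunk = gp.BatchRequest()
    chunk.add(raw, gp.Coordinate((8, 64, 64)))

    pipeline = (
        source
        + gp.Hdf5Write(
            dataset_names={raw: "volumes/raw"},
            output_dir=".",
            output_filename="assembled.hdf",
        )
        + gp.Scan(reference=chunk)
    )

    # an empty request lets Scan cover the complete upstream ROI
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())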
ZarrWrite
- class gunpowder.ZarrWrite(dataset_names, output_dir='.', output_filename='output.hdf', compression_type=None, dataset_dtypes=None, store: BaseStore | MutableMapping | str = None)[source]
Assemble arrays of passing batches in one zarr container. This is useful to store chunks produced by Scan on disk without keeping the larger array in memory. The ROIs of the passing arrays will be used to determine the position where to store the data in the dataset.

Args:
dataset_names (dict, ArrayKey -> string):

A dictionary from array keys to names of the datasets to store them in.

store (string or BaseStore):

The directory to save the zarr container in. Will be created, if it does not exist.

compression_type (string or int):

Compression strategy. Legal values are gzip, szip, lzf. If an integer between 0 and 9, this indicates the gzip compression level.

dataset_dtypes (dict, ArrayKey -> data type):

A dictionary from array keys to datatype (e.g. np.int8). If given, arrays are stored using this type. The original arrays within the pipeline remain unchanged.
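Construction mirrors Hdf5Write, except that the target is given via store; a plain path string works as a directory store. A minimal sketch (the key and dataset name are illustrative):

    import gunpowder as gp

    prediction = gp.ArrayKey("PREDICTION")

    # drop-in replacement for Hdf5Write in the Scan pattern above
    write = gp.ZarrWrite(
        dataset_names={prediction: "volumes/prediction"},
        store="output.zarr",  # created if it does not exist
    )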
Snapshot
- class gunpowder.Snapshot(dataset_names, output_dir='snapshots', output_filename='{iteration}.zarr', every=1, additional_request=None, compression_type=None, dataset_dtypes=None, store_value_range=False)[source]
Save a passing batch in an HDF or Zarr file.
The default behaviour is to periodically save a snapshot after every iterations.

Data-dependent criteria for saving can be implemented by subclassing and overwriting write_if(). This method is applied as an additional filter to the batches picked for periodic saving. It should return True if a batch meets the criteria for saving; see the sketch after this list.

Args:
dataset_names (dict, ArrayKey -> string):

A dictionary from array keys to names of the datasets to store them in.

output_dir (string):

The directory to save the snapshots in. Will be created, if it does not exist.

output_filename (string):

Template for output filenames. {id} in the string will be replaced with the ID of the batch, {iteration} with the training iteration (if training was performed on this batch). The snapshot will be saved as a zarr file if output_filename ends in .zarr and as HDF otherwise.

every (int):

How often to save a batch. every=1 indicates that every batch will be stored, every=2 every second batch, and so on. By default, every batch will be stored.

additional_request (BatchRequest):

An additional batch request to merge with the passing request, if a snapshot is to be made. If not given, only the arrays that are in the batch anyway are recorded. This is useful to request additional arrays like loss gradients for visualization that are otherwise not needed.

compression_type (string or int):

Compression strategy. Legal values are gzip, szip, lzf. If an integer between 0 and 9, this indicates the gzip compression level.

dataset_dtypes (dict, ArrayKey -> data type):

A dictionary from array keys to datatype (e.g. np.int8). If given, arrays are stored using this type. The original arrays within the pipeline remain unchanged.

store_value_range (bool):

If set to True, store the range of values in the dataset's attributes.
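A sketch combining periodic saving, an additional request, and a data-dependent filter; the LOSS_GRADIENT key and the loss threshold are illustrative assumptions:

    import gunpowder as gp

    raw = gp.ArrayKey("RAW")
    gradient = gp.ArrayKey("LOSS_GRADIENT")  # hypothetical extra array

    # request the gradient only when a snapshot is actually made
    extra = gp.BatchRequest()
    extra.add(gradient, gp.Coordinate((8, 64, 64)))

    class SnapshotOnHighLoss(gp.Snapshot):
        """Of the periodically picked batches, keep only high-loss ones."""

        def write_if(self, batch):
            return batch.loss is not None and batch.loss > 0.5  # illustrative threshold

    snapshot = SnapshotOnHighLoss(
        dataset_names={raw: "volumes/raw", gradient: "volumes/gradient"},
        output_dir="snapshots",
        output_filename="{iteration}.zarr",
        every=100,
        additional_request=extra,
    )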
Performance Nodes
PreCache
- class gunpowder.PreCache(cache_size=50, num_workers=20)[source]
Pre-cache repeated equal batch requests. For the first of a series of equal batch requests, a set of workers is spawned to pre-cache the batches in parallel processes. This way, subsequent requests can be served quickly.

A note on changing the requests sent to PreCache. Given requests A and B, if requests are sent in the sequence A, …, A, B, A, …, A, B, A, …, PreCache will build a queue of batches that satisfy A and handle the B requests on demand. This prevents PreCache from discarding the queue on every snapshot request. However, if B replaces A as the most common request, i.e. A, A, A, …, A, B, B, B, …, PreCache will discard the A queue and build a B queue once it has seen more B than A requests among the last 5 requests.
This node only makes sense if:
Incoming batch requests are repeatedly the same.
There is a source of randomness in upstream nodes.
Args:
cache_size (int):

How many batches to hold at most in the cache.

num_workers (int):

How many processes to spawn to fill the cache.
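A minimal sketch of the intended use, with a source of randomness upstream and the same request repeated downstream (the container path is hypothetical):

    import gunpowder as gp

    raw = gp.ArrayKey("RAW")
    source = gp.ZarrSource("data.zarr", {raw: "volumes/raw"})  # hypothetical

    pipeline = (
        source
        + gp.RandomLocation()                          # randomness upstream
        + gp.PreCache(cache_size=50, num_workers=20)
    )

    request = gp.BatchRequest()
    request.add(raw, gp.Coordinate((8, 64, 64)))

    with gp.build(pipeline):
        for _ in range(1000):
            batch = pipeline.request_batch(request)    # served from the cache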
PrintProfilingStats
- class gunpowder.PrintProfilingStats(every=1)[source]
Print profiling information about nodes upstream of this node in the DAG.
The output also includes a TOTAL section, which shows the wall-time spent in the upstream and downstream passes. For the downstream pass, this information is not available in the first iteration, since the request-batch cycle is not yet complete.

Args:

every (int):

Collect statistics over that many batch requests and show min, max, mean, and median runtimes.
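Typically appended as the last node of a pipeline, e.g. (source is a hypothetical placeholder as in the sketches above):

    import gunpowder as gp

    raw = gp.ArrayKey("RAW")
    source = gp.ZarrSource("data.zarr", {raw: "volumes/raw"})  # hypothetical

    # print aggregated runtimes of all upstream nodes every 100 batches
    pipeline = source + gp.RandomLocation() + gp.PrintProfilingStats(every=100)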
Iterative Processing Nodes
Scan
- class gunpowder.Scan(reference, num_workers=1, cache_size=50, progress_callback=None)[source]
Iteratively requests batches of size reference from upstream providers in a scanning fashion, until all requested ROIs are covered. If the batch request to this node is empty, it will scan the complete upstream ROIs (and return nothing). Otherwise, it scans only the requested ROIs and returns a batch assembled from the smaller requests. In either case, the upstream requests will be contained in the downstream requested ROI or upstream ROIs.

See also Hdf5Write.

Args:
reference (BatchRequest):

A reference BatchRequest. This request will be shifted in a scanning fashion over the upstream ROIs of the requested arrays or points.

num_workers (int, optional):

If set to >1, upstream requests are made in parallel with that number of workers.

cache_size (int, optional):

If multiple workers are used, how many batches to hold at most.

progress_callback (ScanCallback, optional):

A callback instance to get updates from this node while processing chunks. See ScanCallback for details. The default is a callback that shows a tqdm progress bar.
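A sketch of the non-empty-request case: a ROI spanning several chunks is assembled from shifted copies of the reference request (the input container is hypothetical):

    import gunpowder as gp

    raw = gp.ArrayKey("RAW")
    source = gp.ZarrSource("data.zarr", {raw: "volumes/raw"})  # hypothetical

    chunk = gp.BatchRequest()
    chunk.add(raw, gp.Coordinate((8, 64, 64)))

    pipeline = source + gp.Scan(reference=chunk, num_workers=4)

    # request a ROI larger than one chunk, in world units
    request = gp.BatchRequest()
    request[raw] = gp.ArraySpec(roi=gp.Roi((0, 0, 0), (16, 256, 256)))

    with gp.build(pipeline):
        batch = pipeline.request_batch(request)  # batch[raw] covers the full ROI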
DaisyRequestBlocks
- class gunpowder.DaisyRequestBlocks(reference, roi_map, num_workers=1, block_done_callback=None)[source]
Iteratively requests batches similar to reference from upstream providers, with their ROIs set to blocks distributed by daisy.

The ROIs of the array or point specs in the reference can be set to either the block’s read_roi or write_roi, see parameter roi_map.

The batch request to this node has to be empty, as there is no guarantee that this node will get to process all chunks required to fulfill a particular batch request.
Args:
reference (BatchRequest):

A reference BatchRequest. This request will be shifted according to blocks distributed by daisy.

roi_map (dict from ArrayKey or GraphKey to string):

A map indicating which daisy block ROI (read_roi or write_roi) to use for which item in the reference request.

num_workers (int, optional):

If set to >1, upstream requests are made in parallel with that number of workers.

block_done_callback (function, optional):

If given, will be called with arguments (block, start, duration) for each block that was processed. start and duration are given in seconds, as in start = time.time() and duration = time.time() - start, taken right before and after a block gets processed.

This callback can be used to log blocks that have successfully finished processing, which can be used in check_function of daisy.run_blockwise to skip already processed blocks in repeated runs.
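A sketch of a worker-side pipeline. The upstream prediction_pipeline and the block_id attribute used in the callback are assumptions, and the script is meant to run inside a daisy worker that receives blocks from a daisy scheduler:

    import gunpowder as gp

    prediction = gp.ArrayKey("PREDICTION")

    # reference request; its ROIs get replaced by daisy block ROIs
    reference = gp.BatchRequest()
    reference.add(prediction, gp.Coordinate((8, 64, 64)))

    def block_done(block, start, duration):
        # hypothetical logging; assumes daisy blocks carry a block_id attribute
        print(f"block {block.block_id} done in {duration:.2f}s")

    pipeline = (
        prediction_pipeline  # hypothetical upstream pipeline producing `prediction`
        + gp.DaisyRequestBlocks(
            reference=reference,
            roi_map={prediction: "write_roi"},
            num_workers=4,
            block_done_callback=block_done,
        )
    )

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())  # request must be empty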