import logging
import os
import numpy as np
from gunpowder.array import ArrayKey, Array
from gunpowder.ext import tensorflow as tf
from gunpowder.nodes.generic_train import GenericTrain
from gunpowder.tensorflow.local_server import LocalServer

logger = logging.getLogger(__name__)


class Train(GenericTrain):
"""Tensorflow implementation of :class:`gunpowder.nodes.Train`.
Args:
        graph (``string``):

            Filename of a tensorflow meta-graph storing the tensorflow
            graph containing an optimizer. A meta-graph file can be
            created by running::

                # create tensorflow graph
                ...

                # store it
                tf.train.export_meta_graph(filename='...')

        optimizer (``string`` or function):

            Either the name of the tensorflow operator performing a
            training iteration, or a function that, given the graph of the
            meta-graph file, adds a custom loss and optimizer.

            If a function is given, it should return a tuple ``(loss,
            optimizer)`` of a tensor and an operator representing the loss
            and the optimizer, respectively. In this case, parameter
            ``loss`` should be ``None``.

            Example::

                def add_custom_optimizer(graph):

                    # get the output of your graph
                    output = graph.get_tensor_by_name('...')

                    # create your custom loss
                    loss = custom_loss(output)

                    # add an optimizer of your choice
                    optimizer = tf.train.AdamOptimizer().minimize(loss)

                    return (loss, optimizer)

        loss (``string`` or ``None``):

            The name of the tensorflow tensor containing the loss, or
            ``None`` if ``optimizer`` is a function.

        inputs (``dict``, ``string`` -> :class:`ArrayKey`):

            Dictionary from the names of input tensors in the network to
            array keys.

        outputs (``dict``, ``string`` -> :class:`ArrayKey`):

            Dictionary from the names of output tensors in the network to
            array keys. New arrays will be generated by this node for each
            entry (if requested downstream).

        gradients (``dict``, ``string`` -> :class:`ArrayKey`):

            Dictionary from the names of output tensors in the network to
            array keys. New arrays containing the gradient of an output
            with respect to the loss will be generated by this node for
            each entry (if requested downstream).

        summary (``string`` or ``dict``, ``string`` -> (``string`` (tensor
        name), ``int`` (frequency)), optional):

            The name of the tensorflow tensor containing the tensorboard
            summaries, or a dictionary mapping subcategories of summaries
            to tuples of a tensor/op name and the frequency, in
            iterations, at which that summary is evaluated and written.

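            A minimal sketch of such a dictionary (the tensor names are
            illustrative)::

                summary = {
                    'scalars': ('merged_scalar_summaries:0', 1),
                    'images': ('merged_image_summaries:0', 100),
                }

            With this, ``scalars`` would be written every iteration,
            ``images`` only every 100th iteration.
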
        array_specs (``dict``, :class:`ArrayKey` -> :class:`ArraySpec`,
        optional):

            Used to set the specs of generated arrays (``outputs``). This
            is useful to set the ``voxel_size``, for example, if it
            differs from the voxel size of the input arrays. Only fields
            that are not ``None`` in the given :class:`ArraySpec` will be
            used.

        save_every (``int``, optional):

            After how many iterations to create a checkpoint to store the
            learnt weights.

        log_dir (``string``, optional):

            Directory for saving tensorboard summaries.

        log_every (``int``, optional):

            After how many iterations to write out tensorboard summaries.
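
    Example usage (the graph filename, tensor names, and array keys below
    are illustrative, not part of the API)::

        raw = ArrayKey('RAW')
        prediction = ArrayKey('PREDICTION')

        train = Train(
            'unet',           # reads unet.meta
            optimizer='opt',  # name of the training operator in the graph
            loss='loss:0',    # name of the loss tensor in the graph
            inputs={'raw:0': raw},
            outputs={'prediction:0': prediction},
            gradients={})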
"""
def __init__(
self,
graph,
optimizer,
loss,
inputs,
outputs,
gradients,
summary=None,
array_specs=None,
save_every=2000,
log_dir="./",
log_every=1,
):
super(Train, self).__init__(
inputs, outputs, gradients, array_specs, spawn_subprocess=False
)
self.meta_graph_filename = graph
self.optimizer_func = None
self.optimizer_loss_names = None
self.optimizer = None
self.loss = None
self.summary = summary
self.session = None
self.tf_gradient = {}
self.graph = None
self.basic_saver = None
self.full_saver = None
self.save_every = save_every
self.iteration = None
self.iteration_increment = None
self.summary_saver = None
self.log_dir = log_dir
self.log_every = log_every
if isinstance(optimizer, str):
self.optimizer_loss_names = (optimizer, loss)
else:
self.optimizer_func = optimizer
# at least for some versions of tensorflow, the checkpoint name has to
# start with a . if it is a relative path
if not os.path.isabs(self.meta_graph_filename):
self.meta_graph_filename = os.path.join(".", self.meta_graph_filename)

    def start(self):
target = LocalServer.get_target()
logger.info("Initializing tf session, connecting to %s...", target)
self.graph = tf.Graph()
self.session = tf.Session(target=target, graph=self.graph)
with self.graph.as_default():
self.__read_meta_graph()
if self.summary is not None:
self.summary_saver = tf.summary.FileWriter(self.log_dir, self.graph)
if self.optimizer_func is None:
# get actual operations/tensors from names
self.optimizer = self.graph.get_operation_by_name(
self.optimizer_loss_names[0]
)
self.loss = self.graph.get_tensor_by_name(self.optimizer_loss_names[1])
# add symbolic gradients
for tensor_name in self.gradients:
tensor = self.graph.get_tensor_by_name(tensor_name)
self.tf_gradient[tensor_name] = tf.gradients(self.loss, [tensor])[0]

    def train_step(self, batch, request):
array_outputs = self.__collect_requested_outputs(request)
inputs = self.__collect_provided_inputs(batch)
to_compute = {
"optimizer": self.optimizer,
"loss": self.loss,
"iteration": self.iteration_increment,
}
to_compute.update(array_outputs)
        # add requested summaries to the fetches; summaries given as a
        # dict are only fetched at their configured frequency
        if isinstance(self.summary, str):
            to_compute["summaries"] = self.summary
        elif isinstance(self.summary, dict):
            for k, (v, f) in self.summary.items():
                if int(self.current_step + 1) % f == 0:
                    to_compute[k] = v

        # compute outputs, gradients, and update variables
        outputs = self.session.run(to_compute, feed_dict=inputs)
for array_key in array_outputs:
spec = self.spec[array_key].copy()
spec.roi = request[array_key].roi
batch.arrays[array_key] = Array(outputs[array_key], spec)
batch.loss = outputs["loss"]
batch.iteration = outputs["iteration"][0]
self.current_step = batch.iteration
        if self.summary is not None:
            if isinstance(self.summary, str):
                if batch.iteration % self.log_every == 0 or batch.iteration == 1:
                    self.summary_saver.add_summary(
                        outputs["summaries"], batch.iteration
                    )
            else:
                for k, (_, f) in self.summary.items():
                    if int(self.current_step) % f == 0:
                        self.summary_saver.add_summary(outputs[k], batch.iteration)
if batch.iteration % self.save_every == 0:
checkpoint_name = (
self.meta_graph_filename + "_checkpoint_%i" % batch.iteration
)
logger.info("Creating checkpoint %s", checkpoint_name)
self.full_saver.save(self.session, checkpoint_name)

    def stop(self):
if self.session is not None:
self.optimizer = None
self.loss = None
if self.summary is not None:
self.summary_saver.close()
self.session.close()
self.graph = None
self.session = None

    def __read_meta_graph(self):
logger.info("Reading meta-graph...")
# read the original meta-graph
tf.train.import_meta_graph(
self.meta_graph_filename + ".meta", clear_devices=True
)
# add custom gunpowder variables
with tf.variable_scope("gunpowder"):
self.iteration = tf.get_variable(
"iteration", shape=1, initializer=tf.zeros_initializer, trainable=False
)
self.iteration_increment = tf.assign(self.iteration, self.iteration + 1)
# Until now, only variables have been added to the graph that are part
# of every checkpoint. We create a 'basic_saver' for only those
# variables.
self.basic_saver = tf.train.Saver(max_to_keep=None)
# Add custom optimizer and loss, if requested. This potentially adds
# more variables, not covered by the basic_saver.
if self.optimizer_func is not None:
loss, optimizer = self.optimizer_func(self.graph)
self.loss = loss
self.optimizer = optimizer
# We create a 'full_saver' including those variables.
self.full_saver = tf.train.Saver(max_to_keep=None)
# find most recent checkpoint
checkpoint_dir = os.path.dirname(self.meta_graph_filename)
checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if checkpoint:
try:
# Try to restore the graph, including the custom optimizer
# state (if a custom optimizer was used).
self.__restore_graph(checkpoint, restore_full=True)
except tf.errors.NotFoundError:
# If that failed, we just transitioned from an earlier training
# without the custom optimizer. In this case, restore only the
# variables of the original meta-graph and 'gunpowder'
# variables. Custom optimizer variables will be default
# initialized.
logger.info("Checkpoint did not contain custom optimizer " "variables")
self.__restore_graph(checkpoint, restore_full=False)
        else:
            logger.info("No checkpoint found")
            # initialize all variables
            self.session.run(tf.global_variables_initializer())
            # no checkpoint was restored, start counting iterations at 0
            self.current_step = 0

    def __restore_graph(self, checkpoint, restore_full):
logger.info("Restoring model from %s", checkpoint)
if restore_full:
logger.info("...using a saver for all variables")
self.full_saver.restore(self.session, checkpoint)
else:
# initialize all variables, such that non-basic variables are
# initialized
self.session.run(tf.global_variables_initializer())
logger.info("...using a saver for basic variables only")
self.basic_saver.restore(self.session, checkpoint)
self.current_step = self.session.run(self.iteration)

    def __collect_requested_outputs(self, request):
array_outputs = {}
for output_name, array_key in self.outputs.items():
if array_key in request:
array_outputs[array_key] = output_name
for output_name, array_key in self.gradients.items():
if array_key in request:
array_outputs[array_key] = self.tf_gradient[output_name]
return array_outputs

    def __collect_provided_inputs(self, batch):
inputs = {}
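        # input keys may be ArrayKeys (fed from arrays in the batch),
        # numpy arrays (fed as constants), or strings (fed from the batch
        # attribute of that name)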
for input_name, input_key in self.inputs.items():
if isinstance(input_key, ArrayKey):
if input_key in batch.arrays:
inputs[input_name] = batch.arrays[input_key].data
else:
                    logger.warning(
                        "batch does not contain %s, input %s will not be set",
                        input_key,
                        input_name,
                    )
elif isinstance(input_key, np.ndarray):
inputs[input_name] = input_key
elif isinstance(input_key, str):
inputs[input_name] = getattr(batch, input_key)
else:
raise Exception(
"Unknown network input key {}, can't be given to "
"network".format(input_key)
)
return inputs