import logging
import numpy as np
from gunpowder.array import ArrayKey, Array
from gunpowder.array_spec import ArraySpec
from gunpowder.ext import torch, tensorboardX, NoSuchModule
from gunpowder.nodes.generic_train import GenericTrain
from typing import Dict, Union, Optional, Any
import itertools
logger = logging.getLogger(__name__)
class Train(GenericTrain):
"""Torch implementation of :class:`gunpowder.nodes.GenericTrain`.
Args:
model (subclass of ``torch.nn.Module``):
The model to train.
loss:
The torch loss to use.
optimizer:
The torch optimizer to use.
inputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`):
Dictionary from the names of input tensors (argument names of the
``forward`` method) in the model to array keys.
loss_inputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`):
Dictionary with the names of input variables to the loss function as
keys, and ArrayKeys containing the desired data as values. Keys can
            be either strings or integers. If the key is an integer, it will
            be treated as a positional argument to the loss function; if it is
            a string, it will be used as a named argument.
outputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`):
Dictionary from the names of tensors in the network to array
keys. If the key is a string, the tensor will be retrieved
            by checking the model for an attribute with the key as its name.
If the key is an integer, it is interpreted as a tuple index of
the outputs of the network.
New arrays will be generated by this node for each entry (if
requested downstream).
gradients (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`, optional):
Dictionary from the names of tensors in the network to array
keys. If the key is a string, the tensor will be retrieved
            by checking the model for an attribute with the key as its name.
If the key is an integer, it is interpreted as a tuple index of
the outputs of the network.
Instead of the actual array, the gradient of the array with respect
to the loss will be generated and saved.
array_specs (``dict``, :class:`ArrayKey` -> :class:`ArraySpec`, optional):
Used to set the specs of generated arrays (at the moment only
``output``). This is useful to set the ``voxel_size``, for example,
if they differ from the voxel size of the input arrays. Only fields
that are not ``None`` in the given :class:`ArraySpec` will be used.
checkpoint_basename (``string``, optional):
The basename used for checkpoint files. Defaults to ``model``.
save_every (``int``, optional):
After how many iterations to create a checkpoint to store the
learnt weights.
log_dir (``string``, optional):
Directory for saving tensorboard summaries.
log_every (``int``, optional):
After how many iterations to write out tensorboard summaries.
spawn_subprocess (``bool``, optional):
            Whether to run ``train_step`` in a separate process. Defaults to ``False``.
device (``str``, optional):
            The CUDA device to train on (e.g., ``cuda:1``, ``cuda:2``); useful
            on multi-GPU systems. Defaults to ``cuda``.
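
        Example:

            A minimal, illustrative sketch of constructing this node; ``raw``,
            ``gt`` and ``prediction`` are placeholder array keys and
            ``ExampleModel`` stands in for your own ``torch.nn.Module``::

                raw = ArrayKey("RAW")
                gt = ArrayKey("GT")
                prediction = ArrayKey("PREDICTION")

                model = ExampleModel()
                train = Train(
                    model=model,
                    loss=torch.nn.MSELoss(),
                    optimizer=torch.optim.Adam(model.parameters()),
                    inputs={"x": raw},                   # "x" matches model.forward(x)
                    outputs={0: prediction},             # forward() return value
                    loss_inputs={0: prediction, 1: gt},  # positional args to the loss
                )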
"""
def __init__(
self,
model,
loss,
optimizer,
inputs: Dict[Union[str, int], ArrayKey],
outputs: Dict[Union[int, str], ArrayKey],
loss_inputs: Dict[Union[int, str], ArrayKey],
gradients: Dict[Union[int, str], ArrayKey] = {},
array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None,
checkpoint_basename: str = "model",
save_every: int = 2000,
log_dir: Optional[str] = None,
log_every: int = 1,
spawn_subprocess: bool = False,
device: str = "cuda",
):
if not model.training:
logger.warning(
"Model is in evaluation mode during training. "
"Consider using model.train()"
)
# not yet implemented
gradients = gradients
loss_inputs = {f"loss_{k}": v for k, v in loss_inputs.items()}
all_inputs: dict[str | int, Any] = {
f"{k}": v for k, v in inputs.items() if v not in outputs.values()
}
all_inputs.update(
{k: v for k, v in loss_inputs.items() if v not in outputs.values()}
)
super(Train, self).__init__(
all_inputs,
outputs,
gradients,
array_specs,
spawn_subprocess=spawn_subprocess,
)
self.model = model
self.loss = loss
self.optimizer = optimizer
self.loss_inputs = loss_inputs
self.checkpoint_basename = checkpoint_basename
self.save_every = save_every
self.dev = device
self.iteration = 0
if not isinstance(tensorboardX, NoSuchModule) and log_dir is not None:
self.summary_writer = tensorboardX.SummaryWriter(log_dir)
self.log_every = log_every
else:
self.summary_writer = None
if log_dir is not None:
logger.warning("log_dir given, but tensorboardX is not installed")
        self.intermediate_layers: dict[str, Any] = {}
self.register_hooks()
def register_hooks(self):
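        # attach a forward hook to every model attribute named by a string key
        # in ``outputs`` so that its activation is captured in
        # ``self.intermediate_layers`` during the forward pass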
for key in self.outputs:
if isinstance(key, str):
layer = getattr(self.model, key)
layer.register_forward_hook(self.create_hook(key))
def create_hook(self, key):
def save_layer(module, input, output):
self.intermediate_layers[key] = output
return save_layer
def retain_gradients(self, request, outputs):
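        # mark requested gradient tensors with ``retain_grad()`` so that their
        # ``.grad`` attribute is populated when ``loss.backward()`` is called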
for array_name, array_key in self.gradients.items():
if array_key not in request:
continue
if isinstance(array_name, int):
tensor = outputs[array_name]
elif isinstance(array_name, str):
tensor = getattr(self.model, array_name)
else:
raise RuntimeError(
"only ints and strings are supported as gradients keys"
)
tensor.retain_grad()
def start(self):
self.use_cuda = torch.cuda.is_available()
# Issue: #188
self.device = torch.device(self.dev if self.use_cuda else "cpu")
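        # fall back to the CPU when CUDA is unavailable, regardless of the
        # requested device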
try:
self.model = self.model.to(self.device)
except RuntimeError as e:
raise RuntimeError(
"Failed to move model to device. If you are using a child process "
"to run your model, maybe you already initialized CUDA by sending "
"your model to device in the main process."
) from e
if isinstance(self.loss, torch.nn.Module):
self.loss = self.loss.to(self.device)
checkpoint, self.iteration = self._get_latest_checkpoint(
self.checkpoint_basename
)
if checkpoint is not None:
logger.info("Resuming training from iteration %d", self.iteration)
logger.info("Loading %s", checkpoint)
checkpoint = torch.load(checkpoint, map_location=self.device)
self.model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
else:
logger.info("Starting training from scratch")
logger.info("Using device %s", self.device)
def train_step(self, batch, request):
inputs = self.__collect_provided_inputs(batch)
inputs = {k: torch.as_tensor(v, device=self.device) for k, v in inputs.items()}
requested_outputs = self.__collect_requested_outputs(request)
# keys are argument names of model forward pass
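        # integer keys were stringified in ``__init__``, so positional
        # arguments are collected under "0", "1", ... until the first gap;
        # the remaining entries are passed as keyword arguments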
device_input_args = []
for i in range(len(inputs)):
key = f"{i}"
if key in inputs:
device_input_args.append(inputs.pop(key))
else:
break
device_input_kwargs = {k: v for k, v in inputs.items() if isinstance(k, str)}
# get outputs. Keys are tuple indices or model attr names as in self.outputs
self.optimizer.zero_grad()
model_outputs = self.model(*device_input_args, **device_input_kwargs)
if isinstance(model_outputs, tuple):
outputs = {i: model_outputs[i] for i in range(len(model_outputs))}
elif isinstance(model_outputs, torch.Tensor):
outputs = {0: model_outputs}
else:
            raise RuntimeError(
                "Torch train node only supports tuple and torch.Tensor return "
                f"types from model.forward(), not {type(model_outputs)}"
            )
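        # merge activations captured by the forward hooks (string keys) into
        # the outputs dict alongside the tuple-indexed return values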
outputs.update(self.intermediate_layers)
# Some inputs to the loss should come from the batch, not the model
provided_loss_inputs = self.__collect_provided_loss_inputs(batch)
device_loss_inputs = {
k: torch.as_tensor(v, device=self.device)
for k, v in provided_loss_inputs.items()
}
# Some inputs to the loss function should come from the outputs of the model
# Update device loss inputs with tensors from outputs if available
flipped_outputs = {v: outputs[k] for k, v in self.outputs.items()}
device_loss_inputs = {
k: flipped_outputs.get(v, device_loss_inputs.get(k))
for k, v in self.loss_inputs.items()
}
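        # positional loss arguments were stored under "loss_0", "loss_1", ...
        # in ``__init__``; collect them in order, then pass any remaining
        # string-keyed entries as keyword arguments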
device_loss_args = []
for i in range(len(device_loss_inputs)):
key = f"loss_{i}"
if key in device_loss_inputs:
device_loss_args.append(device_loss_inputs.pop(key))
else:
break
device_loss_kwargs = {}
for k, v in list(device_loss_inputs.items()):
if isinstance(k, str):
device_loss_kwargs[k] = device_loss_inputs.pop(k)
assert (
len(device_loss_inputs) == 0
), f"Not all loss inputs could be interpreted. Failed keys: {device_loss_inputs.keys()}"
self.retain_gradients(request, outputs)
logger.debug("model outputs: %s", {k: v.shape for k, v in outputs.items()})
logger.debug(
"loss inputs: %s %s",
[v.shape for v in device_loss_args],
{k: v.shape for k, v in device_loss_kwargs.items()},
)
loss = self.loss(*device_loss_args, **device_loss_kwargs)
loss.backward()
self.optimizer.step()
# add requested model outputs to batch
for array_key, array_name in requested_outputs.items():
spec = self.spec[array_key].copy()
spec.roi = request[array_key].roi
batch.arrays[array_key] = Array(
outputs[array_name].cpu().detach().numpy(), spec
)
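        # add requested gradients to the batch; ``.grad`` is available because
        # ``retain_grad()`` was called on these tensors before ``backward()``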
for array_name, array_key in self.gradients.items():
if array_key not in request:
continue
if isinstance(array_name, int):
tensor = outputs[array_name]
elif isinstance(array_name, str):
tensor = getattr(self.model, array_name)
else:
raise RuntimeError(
"only ints and strings are supported as gradients keys"
)
spec = self.spec[array_key].copy()
spec.roi = request[array_key].roi
batch.arrays[array_key] = Array(tensor.grad.cpu().detach().numpy(), spec)
batch.loss = loss.cpu().detach().numpy()
self.iteration += 1
batch.iteration = self.iteration
if batch.iteration % self.save_every == 0:
checkpoint_name = self._checkpoint_name(
self.checkpoint_basename, batch.iteration
)
logger.info("Creating checkpoint %s", checkpoint_name)
torch.save(
{
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
},
checkpoint_name,
)
if self.summary_writer and batch.iteration % self.log_every == 0:
self.summary_writer.add_scalar("loss", batch.loss, batch.iteration)
def __collect_requested_outputs(self, request):
array_outputs = {}
for output_name, array_key in self.outputs.items():
if array_key in request:
array_outputs[array_key] = output_name
return array_outputs
def __collect_provided_inputs(self, batch):
return self.__collect_provided_arrays(
{
k: v
for k, v in self.inputs.items()
if (isinstance(k, int) or k not in self.loss_inputs)
},
batch,
)
def __collect_provided_loss_inputs(self, batch):
return self.__collect_provided_arrays(
self.loss_inputs, batch, expect_missing_arrays=True
)
def __collect_provided_arrays(self, reference, batch, expect_missing_arrays=False):
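        # reference values may be ArrayKeys (looked up in the batch), numpy
        # arrays (passed through unchanged), or strings (read as an attribute
        # of the batch)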
arrays = {}
for array_name, array_key in reference.items():
if isinstance(array_key, ArrayKey):
msg = f"batch does not contain {array_key}, array {array_name} will not be set"
if array_key in batch.arrays:
arrays[array_name] = batch.arrays[array_key].data
elif not expect_missing_arrays:
                    logger.warning(msg)
else:
logger.debug(msg)
elif isinstance(array_key, np.ndarray):
arrays[array_name] = array_key
elif isinstance(array_key, str):
arrays[array_name] = getattr(batch, array_key)
else:
raise Exception(
"Unknown network array key {}, can't be given to "
"network".format(array_key)
)
return arrays