# Source code for gunpowder.torch.nodes.train

import logging
import numpy as np

from gunpowder.array import ArrayKey, Array
from gunpowder.array_spec import ArraySpec
from gunpowder.ext import torch, tensorboardX, NoSuchModule
from gunpowder.nodes.generic_train import GenericTrain

from typing import Dict, Union, Optional, Any
import itertools

# Module-level logger, named after this module so handlers/levels can be
# configured per package through the standard ``logging`` hierarchy.
logger = logging.getLogger(__name__)


class Train(GenericTrain):
    """Torch implementation of :class:`gunpowder.nodes.GenericTrain`.

    Args:

        model (subclass of ``torch.nn.Module``):

            The model to train.

        loss:

            The torch loss to use.

        optimizer:

            The torch optimizer to use.

        inputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`):

            Dictionary from the names of input tensors (argument names of the
            ``forward`` method) in the model to array keys. Integer keys are
            passed as positional arguments, in order, starting from ``0``.

        loss_inputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`):

            Dictionary with the names of input variables to the loss function
            as keys, and ArrayKeys containing the desired data as values. Keys
            can be either strings or integers. If the key is an integer, it
            will be treated as a positional argument to the loss function, a
            string will be used as a named argument.

        outputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`):

            Dictionary from the names of tensors in the network to array
            keys. If the key is a string, the tensor will be retrieved
            by checking the model for an attribute with the key as its name.
            If the key is an integer, it is interpreted as a tuple index of
            the outputs of the network.
            New arrays will be generated by this node for each entry (if
            requested downstream).

        gradients (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`, optional):

            Dictionary from the names of tensors in the network to array
            keys. If the key is a string, the tensor will be retrieved
            by checking the model for an attribute with the key as its name.
            If the key is an integer, it is interpreted as a tuple index of
            the outputs of the network.
            Instead of the actual array, the gradient of the array with
            respect to the loss will be generated and saved.

        array_specs (``dict``, :class:`ArrayKey` -> :class:`ArraySpec`, optional):

            Used to set the specs of generated arrays (at the moment only
            ``output``). This is useful to set the ``voxel_size``, for example,
            if they differ from the voxel size of the input arrays. Only fields
            that are not ``None`` in the given :class:`ArraySpec` will be used.

        checkpoint_basename (``string``, optional):

            The basename used for checkpoint files. Defaults to ``model``.

        save_every (``int``, optional):

            After how many iterations to create a checkpoint to store the
            learnt weights.

        log_dir (``string``, optional):

            Directory for saving tensorboard summaries.

        log_every (``int``, optional):

            After how many iterations to write out tensorboard summaries.

        spawn_subprocess (``bool``, optional):

            Whether to run the ``train_step`` in a separate process. Default
            is false.

        device (``str``, optional):

            Accepts a cuda gpu specifically to train on (e.g. `cuda:1`,
            `cuda:2`), helps in multi-card systems. Defaults to ``cuda``.
            If CUDA is not available, the node falls back to ``cpu``.
    """

    # Internal prefix used to keep loss inputs apart from model inputs in the
    # single ``inputs`` dictionary handed to ``GenericTrain``. It is stripped
    # again before the values are passed to the loss function.
    _LOSS_PREFIX = "loss_"

    def __init__(
        self,
        model,
        loss,
        optimizer,
        inputs: Dict[Union[str, int], ArrayKey],
        outputs: Dict[Union[int, str], ArrayKey],
        loss_inputs: Dict[Union[int, str], ArrayKey],
        gradients: Optional[Dict[Union[int, str], ArrayKey]] = None,
        array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None,
        checkpoint_basename: str = "model",
        save_every: int = 2000,
        log_dir: Optional[str] = None,
        log_every: int = 1,
        spawn_subprocess: bool = False,
        device: str = "cuda",
    ):
        if not model.training:
            logger.warning(
                "Model is in evaluation mode during training. "
                "Consider using model.train()"
            )

        # A fresh dict per instance: a ``{}`` default in the signature would
        # be shared between every ``Train`` node created without gradients.
        if gradients is None:
            gradients = {}

        # Prefix loss inputs so that e.g. positional loss input ``0`` cannot
        # collide with positional model input ``0``.
        loss_inputs = {f"{self._LOSS_PREFIX}{k}": v for k, v in loss_inputs.items()}

        # Arrays produced by the model itself are not requested upstream.
        produced = list(outputs.values())
        all_inputs: Dict[Union[str, int], Any] = {
            f"{k}": v for k, v in inputs.items() if v not in produced
        }
        all_inputs.update({k: v for k, v in loss_inputs.items() if v not in produced})

        super().__init__(
            all_inputs,
            outputs,
            gradients,
            array_specs,
            spawn_subprocess=spawn_subprocess,
        )

        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.loss_inputs = loss_inputs
        self.checkpoint_basename = checkpoint_basename
        self.save_every = save_every
        self.log_every = log_every
        self.dev = device
        self.iteration = 0

        if not isinstance(tensorboardX, NoSuchModule) and log_dir is not None:
            self.summary_writer = tensorboardX.SummaryWriter(log_dir)
        else:
            self.summary_writer = None
            if log_dir is not None:
                logger.warning("log_dir given, but tensorboardX is not installed")

        # Outputs of named sub-modules, captured by forward hooks and keyed by
        # the attribute name used in ``outputs``.
        self.intermediate_layers: Dict[str, Any] = {}
        self.register_hooks()

    def register_hooks(self):
        """Attach a forward hook to every model attribute named in ``outputs``."""
        for key in self.outputs:
            if isinstance(key, str):
                layer = getattr(self.model, key)
                layer.register_forward_hook(self.create_hook(key))

    def create_hook(self, key):
        """Return a forward hook that stores a layer's output under ``key``."""

        def save_layer(module, input, output):
            self.intermediate_layers[key] = output

        return save_layer

    def _gradient_tensor(self, array_name, outputs):
        """Look up the tensor whose gradient is exposed under ``array_name``.

        Integer names index the model outputs, string names are attributes of
        the model.

        Raises:
            RuntimeError: If ``array_name`` is neither an ``int`` nor a ``str``.
        """
        if isinstance(array_name, int):
            return outputs[array_name]
        if isinstance(array_name, str):
            return getattr(self.model, array_name)
        raise RuntimeError("only ints and strings are supported as gradients keys")

    def retain_gradients(self, request, outputs):
        """Call ``retain_grad()`` on every tensor whose gradient was requested."""
        for array_name, array_key in self.gradients.items():
            if array_key not in request:
                continue
            self._gradient_tensor(array_name, outputs).retain_grad()

    def start(self):
        """Move model and loss to the target device and restore the latest checkpoint."""
        self.use_cuda = torch.cuda.is_available()

        # Issue: #188
        self.device = torch.device(self.dev if self.use_cuda else "cpu")

        try:
            self.model = self.model.to(self.device)
        except RuntimeError as e:
            raise RuntimeError(
                "Failed to move model to device. If you are using a child process "
                "to run your model, maybe you already initialized CUDA by sending "
                "your model to device in the main process."
            ) from e
        if isinstance(self.loss, torch.nn.Module):
            self.loss = self.loss.to(self.device)

        checkpoint, self.iteration = self._get_latest_checkpoint(
            self.checkpoint_basename
        )

        if checkpoint is not None:
            logger.info("Resuming training from iteration %d", self.iteration)
            logger.info("Loading %s", checkpoint)

            checkpoint = torch.load(checkpoint, map_location=self.device)
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        else:
            logger.info("Starting training from scratch")

        logger.info("Using device %s", self.device)

    @staticmethod
    def _pop_positional(named, prefix=""):
        """Pop ``prefix0``, ``prefix1``, ... from ``named`` and return them in order.

        Stops at the first missing index; entries that are not part of the
        consecutive run stay in ``named``.
        """
        args = []
        for index in itertools.count():
            key = f"{prefix}{index}"
            if key not in named:
                break
            args.append(named.pop(key))
        return args

    def train_step(self, batch, request):
        """Run one forward/backward/optimizer step and store the results in ``batch``.

        Sets ``batch.loss`` and ``batch.iteration``, adds requested output and
        gradient arrays to ``batch.arrays``, writes a checkpoint every
        ``save_every`` iterations and a tensorboard scalar every ``log_every``
        iterations (if a summary writer is available).

        Raises:
            RuntimeError: If the model returns something other than a tuple or
                a ``torch.Tensor``, or a gradients key is neither ``int`` nor
                ``str``.
        """
        inputs = self.__collect_provided_inputs(batch)
        inputs = {k: torch.as_tensor(v, device=self.device) for k, v in inputs.items()}
        requested_outputs = self.__collect_requested_outputs(request)

        # keys are argument names of model forward pass
        device_input_args = self._pop_positional(inputs)
        device_input_kwargs = {k: v for k, v in inputs.items() if isinstance(k, str)}

        # get outputs. Keys are tuple indices or model attr names as in self.outputs
        self.optimizer.zero_grad()
        model_outputs = self.model(*device_input_args, **device_input_kwargs)
        if isinstance(model_outputs, tuple):
            outputs = dict(enumerate(model_outputs))
        elif isinstance(model_outputs, torch.Tensor):
            outputs = {0: model_outputs}
        else:
            # Single message string: a stray comma here would turn the
            # exception payload into a tuple.
            raise RuntimeError(
                "Torch train node only supports return types of tuple "
                f"and torch.Tensor from model.forward(), not {type(model_outputs)}"
            )
        outputs.update(self.intermediate_layers)

        # Some inputs to the loss should come from the batch, not the model
        provided_loss_inputs = self.__collect_provided_loss_inputs(batch)
        batch_loss_inputs = {
            k: torch.as_tensor(v, device=self.device)
            for k, v in provided_loss_inputs.items()
        }

        # Some inputs to the loss function should come from the outputs of the model
        # Update device loss inputs with tensors from outputs if available
        flipped_outputs = {v: outputs[k] for k, v in self.outputs.items()}
        device_loss_inputs = {
            k: flipped_outputs.get(v, batch_loss_inputs.get(k))
            for k, v in self.loss_inputs.items()
        }

        prefix = self._LOSS_PREFIX
        device_loss_args = self._pop_positional(device_loss_inputs, prefix)

        # Remaining entries are named arguments. Strip the internal prefix so
        # the loss receives the name the user gave in ``loss_inputs``.
        device_loss_kwargs = {}
        for k in list(device_loss_inputs):
            if isinstance(k, str):
                name = k[len(prefix):] if k.startswith(prefix) else k
                device_loss_kwargs[name] = device_loss_inputs.pop(k)
        assert (
            len(device_loss_inputs) == 0
        ), f"Not all loss inputs could be interpreted. Failed keys: {device_loss_inputs.keys()}"

        self.retain_gradients(request, outputs)

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "model outputs: %s",
                {k: getattr(v, "shape", None) for k, v in outputs.items()},
            )
            logger.debug(
                "loss inputs: %s %s",
                [getattr(v, "shape", None) for v in device_loss_args],
                {k: getattr(v, "shape", None) for k, v in device_loss_kwargs.items()},
            )

        loss = self.loss(*device_loss_args, **device_loss_kwargs)
        loss.backward()
        self.optimizer.step()

        # add requested model outputs to batch
        for array_key, array_name in requested_outputs.items():
            spec = self.spec[array_key].copy()
            spec.roi = request[array_key].roi
            batch.arrays[array_key] = Array(
                outputs[array_name].cpu().detach().numpy(), spec
            )

        # add requested gradients to batch
        for array_name, array_key in self.gradients.items():
            if array_key not in request:
                continue
            tensor = self._gradient_tensor(array_name, outputs)
            spec = self.spec[array_key].copy()
            spec.roi = request[array_key].roi
            batch.arrays[array_key] = Array(tensor.grad.cpu().detach().numpy(), spec)

        batch.loss = loss.cpu().detach().numpy()
        self.iteration += 1
        batch.iteration = self.iteration

        if batch.iteration % self.save_every == 0:
            checkpoint_name = self._checkpoint_name(
                self.checkpoint_basename, batch.iteration
            )

            logger.info("Creating checkpoint %s", checkpoint_name)

            torch.save(
                {
                    "model_state_dict": self.model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                },
                checkpoint_name,
            )

        if self.summary_writer and batch.iteration % self.log_every == 0:
            self.summary_writer.add_scalar("loss", batch.loss, batch.iteration)

    def __collect_requested_outputs(self, request):
        """Map each requested output :class:`ArrayKey` to its model output name."""
        return {
            array_key: output_name
            for output_name, array_key in self.outputs.items()
            if array_key in request
        }

    def __collect_provided_inputs(self, batch):
        """Collect the model (non-loss) inputs that are present in ``batch``."""
        model_inputs = {
            k: v
            for k, v in self.inputs.items()
            if isinstance(k, int) or k not in self.loss_inputs
        }
        return self.__collect_provided_arrays(model_inputs, batch)

    def __collect_provided_loss_inputs(self, batch):
        """Collect the loss inputs that are present in ``batch``.

        Loss inputs produced by the model are expected to be missing here.
        """
        return self.__collect_provided_arrays(
            self.loss_inputs, batch, expect_missing_arrays=True
        )

    def __collect_provided_arrays(self, reference, batch, expect_missing_arrays=False):
        """Resolve the values of ``reference`` against ``batch``.

        Values may be an :class:`ArrayKey` (looked up in ``batch.arrays``), a
        ``numpy.ndarray`` (used as is) or a ``str`` (attribute of ``batch``).
        Array keys missing from the batch are skipped; this is logged as a
        warning unless ``expect_missing_arrays`` is set.

        Raises:
            Exception: If a value is of none of the supported types.
        """
        arrays = {}

        for array_name, array_key in reference.items():
            if isinstance(array_key, ArrayKey):
                msg = f"batch does not contain {array_key}, array {array_name} will not be set"
                if array_key in batch.arrays:
                    arrays[array_name] = batch.arrays[array_key].data
                elif not expect_missing_arrays:
                    logger.warning(msg)
                else:
                    logger.debug(msg)
            elif isinstance(array_key, np.ndarray):
                arrays[array_name] = array_key
            elif isinstance(array_key, str):
                arrays[array_name] = getattr(batch, array_key)
            else:
                raise Exception(
                    "Unknown network array key {}, can't be given to "
                    "network".format(array_key)
                )

        return arrays