Training the StarGAN

Attention

It is recommended to use the YAML configuration method to train the conversion model.

In this tutorial, we go over the basics of how to train a (slightly modified) StarGAN for use in QuAC.

Defining the dataset

The data is expected to be in the form of image files with a directory structure denoting the classification. For example:

data_folder/
    crow/
        crow1.png
        crow2.png
    raven/
        raven1.png
        raven2.png
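The class labels come from the subdirectory names, so it can help to check the layout before training. The sketch below uses only the standard library. The helper name `count_images_per_class` and the set of accepted extensions are illustrative assumptions, not part of QuAC.

```python
from pathlib import Path

# File extensions treated as images; an assumption for this sketch.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tif", ".tiff"}


def count_images_per_class(data_folder):
    """Return {class_name: number_of_images} for each subdirectory of data_folder."""
    counts = {}
    for class_dir in sorted(Path(data_folder).iterdir()):
        if not class_dir.is_dir():
            continue  # stray files at the top level are not classes
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts
```

For the layout above, `count_images_per_class("data_folder")` returns `{"crow": 2, "raven": 2}`. A class with very few images, or a misspelled directory that shows up as an extra class, is easy to spot this way.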

The training dataset class, TrainingDataset, is defined in quac.training.data. It needs two directories: a source directory and a reference directory. These directories can be the same.

The validation dataset will need the same information.

For example:

from quac.training.data import TrainingDataset, ValidationData

training_directory = "path/to/training/data"
validation_directory = "path/to/validation/data"

# Setup data for training
dataset = TrainingDataset(
    source=training_directory,
    reference=training_directory,
    img_size=128,
    batch_size=4,
    num_workers=4,
)

# Setup data for validation
val_dataset = ValidationData(
    source=validation_directory,
    reference=validation_directory,
    img_size=128,
    batch_size=16,
    num_workers=16,
)
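The source and reference roles come from StarGAN v2, which this model builds on: each source image is converted towards the class of a randomly chosen reference image, whose style guides the conversion. The sketch below illustrates that pairing on plain lists of `(path, label)` tuples. It is an assumption about the sampling scheme made for illustration, not QuAC's data loader, and the function name `sample_pair` is hypothetical.

```python
import random


def sample_pair(source_items, reference_items, rng=random):
    """Pick a source image and two reference images that share one target class.

    source_items / reference_items: lists of (path, label) tuples.
    Returns (source_path, source_label, ref_path, ref_path2, target_label).
    """
    src_path, src_label = rng.choice(source_items)
    # Choose the target class from the reference set, then two references of it.
    # StarGAN v2 draws a second reference for its diversity-sensitive loss.
    target_label = rng.choice(sorted({label for _, label in reference_items}))
    candidates = [p for p, label in reference_items if label == target_label]
    ref_path, ref_path2 = rng.choice(candidates), rng.choice(candidates)
    return src_path, src_label, ref_path, ref_path2, target_label
```

Passing the same directory as source and reference simply means both lists are drawn from the same pool of images.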

Defining the models

The models can be built with the build_model function in quac.training.stargan.

from quac.training.stargan import build_model

nets, nets_ema = build_model(
    img_size=128,  # Images are made square; should match the img_size of the datasets
    style_dim=64,  # The size of the style vector
    input_dim=1,  # Number of channels in the input
    latent_dim=16,  # The size of the random latent
    num_domains=4,  # Number of classes
    single_output_style_encoder=False,
)

# Alternatively, when using the YAML configuration, the same parameters can be
# unpacked from the loaded experiment configuration:
# nets, nets_ema = build_model(**experiment.model.model_dump())

If using multiple or specific GPUs, it may be necessary to add the gpu_ids argument.

The nets_ema are a copy of the nets that is not trained directly; instead, its weights are kept as an exponential moving average of the weights of the nets. The sub-networks of both can be accessed in a dictionary-like manner.
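The moving-average update itself is short. The sketch below stores parameters as plain floats in dictionaries keyed by sub-network name, standing in for real network weights. The function name `update_ema` is hypothetical, and the default decay `beta=0.999` is the value used in the original StarGAN v2 code, assumed rather than confirmed for QuAC.

```python
def update_ema(nets, nets_ema, beta=0.999):
    """In-place EMA update: ema <- beta * ema + (1 - beta) * current.

    nets / nets_ema map sub-network names to {parameter_name: value} dicts.
    """
    for name, params in nets.items():
        ema_params = nets_ema[name]
        for key, value in params.items():
            ema_params[key] = beta * ema_params[key] + (1.0 - beta) * value


# Usage: one update pulls the averaged weight 10% of the way towards the current one.
nets = {"generator": {"w": 1.0}}
nets_ema = {"generator": {"w": 0.0}}
update_ema(nets, nets_ema, beta=0.9)
```

Because each update moves the average only slightly, the averaged weights smooth out step-to-step fluctuations of training; in StarGAN v2 these averaged copies are the ones used when generating images for evaluation.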

Creating a logger

# Example using WandB
logger = Logger.create(
    log_type="wandb",
    project="project-name",
    name="experiment name",
    tags=["experiment", "project", "test", "quac", "stargan"],
    hparams={  # this holds all of the hyperparameters you want to store for your run
        "hyperparameter_key": "Hyperparameter values"
    },
)

# TODO example using tensorboard
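The hparams dictionary in the example is flat. If your hyperparameters live in nested dictionaries, such as the output of `model_dump()` on the configuration objects, a small helper can flatten them first; TensorBoard's hyperparameter view, for one, expects flat key/value pairs. The helper `flatten_hparams` below is a hypothetical sketch, not part of QuAC, and joins nested keys with dots.

```python
def flatten_hparams(config, prefix=""):
    """Flatten nested dicts into {"a.b.c": value} pairs suitable for logging."""
    flat = {}
    for key, value in config.items():
        full_key = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten_hparams(value, full_key))  # recurse into sub-sections
        else:
            flat[full_key] = value
    return flat


hparams = flatten_hparams({
    "model": {"img_size": 128, "style_dim": 64},
    "solver": {"lr": 1e-4, "beta1": 0.5},
})
```

Here `hparams` becomes `{"model.img_size": 128, "model.style_dim": 64, "solver.lr": 0.0001, "solver.beta1": 0.5}`, which can be passed as the `hparams` argument.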

Defining the Solver

It is now time to initiate the Solver object, which will do the bulk of the work in training.

solver = Solver(
    nets,
    nets_ema,
    # Checkpointing
    checkpoint_dir="path/to/store/checkpoints",
    # Parameters for the Adam optimizers
    lr=1e-4,
    beta1=0.5,
    beta2=0.99,
    weight_decay=0.1,
    # The logger created above
    run=logger,
)

# Alternatively, when using the YAML configuration:
# solver = Solver(nets, nets_ema, **experiment.solver.model_dump(), run=logger)
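To make the roles of `lr`, `beta1`, `beta2` and `weight_decay` concrete, here is a single Adam step on a scalar parameter in plain Python. It assumes the L2-style weight decay of PyTorch's `torch.optim.Adam` (the decay term is added to the gradient) and the usual `eps=1e-8`; the Solver may use a different variant, such as decoupled AdamW-style decay. The function name `adam_step` is hypothetical.

```python
import math


def adam_step(param, grad, state, lr=1e-4, beta1=0.5, beta2=0.99,
              weight_decay=0.1, eps=1e-8):
    """Return the updated parameter after one Adam step (L2-style weight decay).

    `state` holds the running moments "m", "v" and the step count "t".
    """
    grad = grad + weight_decay * param  # L2 penalty folded into the gradient
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad  # first moment
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad ** 2  # second moment
    m_hat = state["m"] / (1 - beta1 ** state["t"])  # bias corrections
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


state = {"m": 0.0, "v": 0.0, "t": 0}
p = adam_step(1.0, grad=0.2, state=state)
```

On the first step the bias corrections make the update almost exactly `lr` in size, regardless of the gradient's scale. A low `beta1` such as 0.5 makes the first moment follow recent gradients closely, a common choice for GAN training.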

Training

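Once we’ve created the solver, we also need to define how we’re going to train and validate. This is done through three different configurations: ValConfig, LossConfig, and RunConfig.

The ValConfig determines how validation will be done. In particular, it points to the pretrained classifier checkpoint used to check the converted images, and sets how images are normalized before being passed to that classifier.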

from quac.training.options import ValConfig

val_config = ValConfig(
    classifier_checkpoint="/path/to/classifier/",
    # The values below are the defaults
    val_batch_size=32,
    num_outs_per_domain=10,
    mean=0.5,
    std=0.5,
    grayscale=True,
)
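With `mean=0.5` and `std=0.5`, pixel values in [0, 1] are mapped to [-1, 1] via `(x - mean) / std` before they reach the classifier. Assuming the classifier was trained on inputs normalized the same way, these values should match the ones used for that training. A plain-Python sketch of the transform and its inverse (function names are illustrative):

```python
def normalize(x, mean=0.5, std=0.5):
    """Map a pixel value from [0, 1] to the classifier's input range."""
    return (x - mean) / std


def denormalize(y, mean=0.5, std=0.5):
    """Invert normalize(), e.g. to turn a generated image back into [0, 1]."""
    return y * std + mean


print([normalize(v) for v in (0.0, 0.5, 1.0)])  # [-1.0, 0.0, 1.0]
```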
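The LossConfig sets the weights of the individual loss terms, and the RunConfig controls the length of training and how often to log, save checkpoints, and run validation.

loss_config = LossConfig(
    lambda_ds=0.,
    lambda_reg=1.,
    lambda_sty=1.,
    lambda_cyc=1.,
)

run_config = RunConfig(
    # All of these are the defaults
    resume_iter=0,
    total_iter=100000,
    log_every=1000,
    save_every=10000,
    eval_every=10000,
)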
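In StarGAN v2, which this model builds on, the `lambda_*` weights combine the loss terms as sketched below: the generator minimizes the adversarial loss plus weighted style-reconstruction and cycle-consistency terms, minus the weighted diversity term, while `lambda_reg` weights the R1 regularization of the discriminator. The function names and plain-float inputs are assumptions for illustration, and the modified QuAC version may differ in detail. The `events_at` helper shows which periodic actions the RunConfig values would trigger, assuming iterations are counted from 1 and an action fires whenever the iteration is a multiple of its interval.

```python
def generator_loss(adv, sty, ds, cyc, lambda_sty=1.0, lambda_ds=0.0, lambda_cyc=1.0):
    """StarGAN v2-style weighted generator objective (the diversity term is subtracted)."""
    return adv + lambda_sty * sty - lambda_ds * ds + lambda_cyc * cyc


def discriminator_loss(adv_real, adv_fake, r1, lambda_reg=1.0):
    """Adversarial terms on real and fake images plus the weighted R1 regularization."""
    return adv_real + adv_fake + lambda_reg * r1


def events_at(step, log_every=1000, save_every=10000, eval_every=10000):
    """Which periodic actions fire at a given (1-based) training iteration."""
    events = []
    if step % log_every == 0:
        events.append("log")
    if step % save_every == 0:
        events.append("save")
    if step % eval_every == 0:
        events.append("eval")
    return events
```

With `lambda_ds=0.`, as in the example above, the diversity term drops out of the generator objective entirely. With the default RunConfig values, iteration 10000 would log, save, and evaluate, while iteration 1000 would only log.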

Finally, we can train the model!

solver.train(dataset, val_config)

All results will be stored in the checkpoint_dir passed to the Solver above.

Once your model is trained, you can move on to generating images with it.