Training the StarGAN
Attention
It is recommended to use the YAML configuration method to train the conversion model.
In this tutorial, we go over the basics of how to train a (slightly modified) StarGAN for use in QuAC.
Defining the dataset
The data is expected to be image files arranged in a directory structure that denotes the classification: each sub-directory holds the images of one class. For example:
data_folder/
    crow/
        crow1.png
        crow2.png
    raven/
        raven1.png
        raven2.png
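Each sub-directory name acts as a class label. The sketch below illustrates only that convention. It builds the example tree in a temporary directory and maps it to (image, label) pairs with `discover_classes`, a hypothetical helper that is not part of QuAC. It assumes labels are assigned in sorted order of directory name, as torchvision's ImageFolder does. QuAC's own ordering may differ.

```python
from pathlib import Path
import tempfile


def discover_classes(data_folder):
    """Hypothetical helper: treat each sub-directory of data_folder as a class.

    Returns the sorted class names and a list of (image_path, label) pairs,
    where label is the index of the class name in the sorted list.
    """
    root = Path(data_folder)
    classes = sorted(p.name for p in root.iterdir() if p.is_dir())
    samples = []
    for label, name in enumerate(classes):
        for img in sorted((root / name).glob("*.png")):
            samples.append((img, label))
    return classes, samples


# Recreate the example tree from above (empty files stand in for images).
with tempfile.TemporaryDirectory() as tmp:
    for cls in ("crow", "raven"):
        (Path(tmp) / cls).mkdir()
        for i in (1, 2):
            (Path(tmp) / cls / f"{cls}{i}.png").touch()
    classes, samples = discover_classes(tmp)

print(classes)  # ['crow', 'raven']
print([(p.name, y) for p, y in samples])
# [('crow1.png', 0), ('crow2.png', 0), ('raven1.png', 1), ('raven2.png', 1)]
```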
A training dataset, TrainingDataset, is defined in quac.training.data. It needs two directories: a source and a reference. These directories can be the same.
The validation dataset (ValidationData) needs the same information.
For example:
# ValidationData is assumed to live in the same module as TrainingDataset
from quac.training.data import TrainingDataset, ValidationData

training_directory = "path/to/training/data"
validation_directory = "path/to/validation/data"

# Setup data for training
dataset = TrainingDataset(
    source=training_directory,
    reference=training_directory,
    img_size=128,
    batch_size=4,
    num_workers=4,
)

# Setup data for validation
val_dataset = ValidationData(
    source=validation_directory,
    reference=validation_directory,
    img_size=128,
    batch_size=16,
    num_workers=16,
)
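A source image is the one being converted, and a reference image shows the target class. The sketch below is loosely modelled on the reference loader in the original StarGAN v2 code and is not QuAC's implementation. It draws one source sample and two reference images that share a target label using `sample_source_and_reference`, a hypothetical helper. It also hints at why the two directories can be the same: both are simply pools of labelled images.

```python
import random


def sample_source_and_reference(samples, rng):
    """Hypothetical sampler. `samples` is a list of (image, label) pairs,
    with at least two images per class.

    Returns one (source, source_label) pair and a
    (reference, second_reference, target_label) triple whose two
    reference images belong to the same class.
    """
    x_src, y_src = rng.choice(samples)
    # The target class is chosen by drawing a random reference image...
    x_ref, y_ref = rng.choice(samples)
    # ...then a second, different image is drawn from that same class.
    same_class = [x for x, y in samples if y == y_ref and x != x_ref]
    x_ref2 = rng.choice(same_class)
    return (x_src, y_src), (x_ref, x_ref2, y_ref)


samples = [("crow1.png", 0), ("crow2.png", 0), ("raven1.png", 1), ("raven2.png", 1)]
rng = random.Random(0)
source, reference = sample_source_and_reference(samples, rng)
print(source, reference)
```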
Defining the models
The models can be built using the build_model function in quac.training.stargan.
from quac.training.stargan import build_model

nets, nets_ema = build_model(
    img_size=128,  # Images are made square; should match the datasets' img_size
    style_dim=64,  # The size of the style vector
    input_dim=1,  # Number of channels in the input
    latent_dim=16,  # The size of the random latent
    num_domains=4,  # Number of classes
    single_output_style_encoder=False,
)

# Alternatively, when using the YAML configuration, where `experiment`
# is the loaded configuration object:
nets, nets_ema = build_model(**experiment.model.model_dump())
If using multiple or specific GPUs, it may be necessary to pass the gpu_ids argument to build_model.
The nets_ema are a copy of the nets that is not trained directly; instead, their weights are an exponential moving average of the weights of the nets. The sub-networks of both can be accessed like a dictionary.
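The EMA update can be sketched in plain Python. This illustrates the general technique and is not QuAC's code. Parameters are stored as lists of floats in dicts keyed by name, and `update_ema` is a hypothetical helper. The default beta=0.999 is the value used by the original StarGAN v2 code; QuAC may or may not use the same value.

```python
def update_ema(ema_params, params, beta=0.999):
    """Update ema_params in place: ema <- beta * ema + (1 - beta) * param.

    Both arguments map parameter names to lists of floats of equal length.
    """
    for name, values in params.items():
        ema_params[name] = [
            beta * e + (1.0 - beta) * v for e, v in zip(ema_params[name], values)
        ]
    return ema_params


params = {"w": [1.0, 2.0]}  # weights of the trained network
ema = {"w": [0.0, 0.0]}     # its EMA copy
update_ema(ema, params, beta=0.5)  # small beta only to make the effect visible
print(ema)  # {'w': [0.5, 1.0]}
```

A beta close to 1 makes the EMA weights change slowly, which smooths out the noise of individual training steps.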
Creating a logger
# Example using WandB
logger = Logger.create(
    log_type="wandb",
    project="project-name",
    name="experiment name",
    tags=["experiment", "project", "test", "quac", "stargan"],
    hparams={  # this holds all of the hyperparameters you want to store for your run
        "hyperparameter_key": "Hyperparameter values"
    },
)

# TODO example using tensorboard
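Hyperparameters may be held in nested dictionaries, such as the output of a pydantic model's `model_dump()` used with the YAML configuration above. In that case, a small helper can flatten them into the single-level dict passed as `hparams`. The `flatten` function below is a hypothetical sketch, not part of QuAC.

```python
def flatten(config, prefix=""):
    """Flatten a nested dict into dotted keys.

    For example, {"model": {"img_size": 128}} becomes {"model.img_size": 128}.
    """
    flat = {}
    for key, value in config.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            # Recurse into sub-sections, extending the dotted prefix
            flat.update(flatten(value, prefix=name + "."))
        else:
            flat[name] = value
    return flat


config = {"model": {"img_size": 128, "style_dim": 64}, "solver": {"lr": 1e-4}}
hparams = flatten(config)
print(hparams)  # {'model.img_size': 128, 'model.style_dim': 64, 'solver.lr': 0.0001}
```

With a YAML configuration this could be called as something like `flatten(experiment.model_dump())`, assuming `experiment` is a pydantic model.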
Defining the Solver
It is now time to instantiate the Solver object, which will do the bulk of the work in training.
solver = Solver(
    nets,
    nets_ema,
    # Checkpointing
    checkpoint_dir="path/to/store/checkpoints",
    # Parameters for the Adam optimizers
    lr=1e-4,
    beta1=0.5,
    beta2=0.99,
    weight_decay=0.1,
    # Logger created above
    run=logger,
)

# Alternatively, when using the YAML configuration:
solver = Solver(nets, nets_ema, **experiment.solver.model_dump(), run=logger)
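To show what the four optimizer parameters do, here is a single Adam step for one scalar weight in plain Python. It assumes PyTorch-style Adam, where weight_decay is added to the gradient as an L2 penalty, as in the original StarGAN v2 code. The eps=1e-8 value (PyTorch's default) and the `adam_step` helper are assumptions for illustration.

```python
import math


def adam_step(param, grad, state, lr=1e-4, beta1=0.5, beta2=0.99,
              weight_decay=0.1, eps=1e-8):
    """One Adam update for a single scalar weight (hypothetical sketch).

    `state` holds the step count "t" and the moment estimates "m" and "v".
    """
    grad = grad + weight_decay * param  # L2 weight decay added to the gradient
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad       # first moment
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad ** 2  # second moment
    m_hat = state["m"] / (1 - beta1 ** state["t"])  # bias corrections
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


state = {"t": 0, "m": 0.0, "v": 0.0}
w = 1.0
for _ in range(3):
    # Zero loss gradient: only weight decay acts, shrinking w towards 0
    w = adam_step(w, grad=0.0, state=state)
print(w)
```

On the first step the bias-corrected ratio m_hat / sqrt(v_hat) is close to plus or minus 1, so the weight moves by roughly lr regardless of the gradient's scale.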
Training
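Once we’ve created the solver, we also need to define how we’re going to train and validate. This is done through three configurations: ValConfig, LossConfig and RunConfig.
The ValConfig determines how validation will be done. In particular, it points to the pretrained classifier (classifier_checkpoint) used to evaluate the converted images.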
from quac.training.options import ValConfig

val_config = ValConfig(
    classifier_checkpoint="/path/to/classifier/",
    # The values below are the defaults
    val_batch_size=32,
    num_outs_per_domain=10,
    mean=0.5,
    std=0.5,
    grayscale=True,
)
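The mean, std and grayscale options presumably describe how images are preprocessed before they reach the classifier, so they should match the settings the classifier was trained with. The hypothetical `normalize` helper below sketches that preprocessing for a few RGB pixels with values in [0, 1]. The grayscale conversion uses PIL's "L"-mode luma weights, which is an assumption about how QuAC converts images. With mean=0.5 and std=0.5, the values end up in [-1, 1].

```python
def normalize(pixels, mean=0.5, std=0.5, grayscale=True):
    """Hypothetical preprocessing sketch.

    `pixels` is a list of (r, g, b) tuples with values in [0, 1]. Optionally
    convert each pixel to a single grayscale channel, then standardize every
    channel value with (x - mean) / std.
    """
    if grayscale:
        # ITU-R 601-2 luma weights, as used by PIL's "L" mode conversion
        channels = [(0.299 * r + 0.587 * g + 0.114 * b,) for r, g, b in pixels]
    else:
        channels = pixels
    return [tuple((c - mean) / std for c in px) for px in channels]


black_and_white = normalize([(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)])
print(black_and_white)  # approximately [(-1.0,), (1.0,)]
```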
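The LossConfig sets the weights of the individual loss terms, and the RunConfig controls the length of training and how often to log, save checkpoints, and run validation.
loss_config = LossConfig(
    lambda_ds=0.0,  # diversity sensitivity
    lambda_reg=1.0,  # R1 regularization of the discriminator
    lambda_sty=1.0,  # style reconstruction
    lambda_cyc=1.0,  # cycle consistency
)

run_config = RunConfig(
    # All of these are default
    resume_iter=0,
    total_iter=100000,
    log_every=1000,
    save_every=10000,
    eval_every=10000,
)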
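These weight names follow StarGAN v2. The sketch below shows how the weights would combine, assuming QuAC keeps StarGAN v2's objective. The helper functions are hypothetical and the loss values are made-up placeholders. The diversity-sensitive term is subtracted because the generator is encouraged to maximize it. Setting lambda_ds=0., as above, turns that term off.

```python
def generator_loss(adv, sty, ds, cyc, lambda_sty=1.0, lambda_ds=0.0, lambda_cyc=1.0):
    """StarGAN v2-style generator objective (hypothetical sketch)."""
    return adv + lambda_sty * sty - lambda_ds * ds + lambda_cyc * cyc


def discriminator_loss(adv_real, adv_fake, r1, lambda_reg=1.0):
    """Adversarial terms plus the R1 gradient penalty on real images."""
    return adv_real + adv_fake + lambda_reg * r1


print(generator_loss(adv=0.75, sty=0.25, ds=0.5, cyc=0.5))  # 1.5
```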
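Finally, we can train the model!
from quac.training.options import ValConfig

# loss_config, run_config and val_dataset are defined above; depending on
# your version of QuAC, Solver.train may also expect them as arguments.
solver.train(dataset, val_config)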
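The RunConfig values presumably drive the training loop roughly as in the sketch below, which only computes when logging, checkpointing and evaluation would happen. The hypothetical `schedule` generator counts iterations from 1, like the `(i + 1) % ...` checks in the original StarGAN v2 loop. QuAC's exact behaviour may differ.

```python
def schedule(resume_iter=0, total_iter=100000, log_every=1000,
             save_every=10000, eval_every=10000):
    """Yield (iteration, actions) for every iteration where something
    besides a plain training step happens. Iterations count from 1."""
    for i in range(resume_iter + 1, total_iter + 1):
        actions = []
        if i % log_every == 0:
            actions.append("log")
        if i % save_every == 0:
            actions.append("save")
        if i % eval_every == 0:
            actions.append("eval")
        if actions:
            yield i, actions


events = list(schedule(total_iter=20, log_every=5, save_every=10, eval_every=20))
print(events)
# [(5, ['log']), (10, ['log', 'save']), (15, ['log']), (20, ['log', 'save', 'eval'])]
```

With the defaults, this gives 100 logging events and 10 checkpoints and evaluations over 100,000 iterations.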
All results will be stored in the checkpoint_dir passed to the Solver above.
Once your model is trained, you can move on to generating images with it.