How to generate images from a pre-trained network

Defining the dataset

Images are generated one source-target pair at a time, so we point to the subdirectory that holds the source class of interest. In the example below, we use the validation data, and the source class is class 0, which has no diabetic retinopathy.

```python
from pathlib import Path
from quac.generate import load_data

img_size = 224
data_directory = Path("/path/to/directory/holding/the/data/source_class")
dataset = load_data(data_directory, img_size, grayscale=False)
```
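Because generation runs one source-target pair at a time, covering every class means looping over all ordered pairs of distinct classes. With five classes (matching `num_domains=5` used later), that is 5 × 4 = 20 pairs. The sketch below assumes a hypothetical directory layout `<root>/<split>/<class>/` (for example `val/0/`); the helper names, `root`, and `split` are illustrative placeholders, not a layout quac requires.

```python
from itertools import permutations
from pathlib import Path


def source_target_pairs(num_classes):
    """Every ordered (source, target) pair of distinct class indices."""
    return list(permutations(range(num_classes), 2))


def source_directory(root, split, source):
    """Hypothetical layout: <root>/<split>/<source class>/."""
    return Path(root) / split / str(source)


pairs = source_target_pairs(5)
print(len(pairs))  # 20, i.e. 5 sources x 4 targets each
for source, target in pairs[:2]:
    # Each pair would get its own run of the generation loop below.
    print(source, target, source_directory("/data", "val", source))
```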

Loading the classifier

Next, we load the pre-trained classifier and wrap it with the correct pre-processing step. The classifier is expected to be saved as a TorchScript checkpoint, which lets us use it without redefining the Python class it was created from.

The wrapper re-normalizes images to the range the classifier expects. It assumes the images come from a StarGAN trained with quac, and therefore have values in [-1, 1]. In this example, the pre-trained classifier expects ImageNet normalization, so we pass the ImageNet mean and standard deviation.

Finally, we define the device and whether to put the classifier in eval mode.

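```python
import torch
from quac.generate import load_classifier

# Placeholder: path to the classifier saved as a TorchScript checkpoint
classifier_checkpoint = Path("/path/to/classifier/checkpoint")

# ImageNet normalization expected by this particular classifier
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

classifier = load_classifier(
    classifier_checkpoint, mean=mean, std=std, eval=True, device=device
)
```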
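Assuming the wrapper does what is described above, the re-normalization amounts to two affine steps per channel: (x + 1) / 2 maps a StarGAN output from [-1, 1] to [0, 1], and (x - mean) / std then applies the ImageNet standardization with the `mean` and `std` defined above. The pure-Python sketch applies both steps to a single RGB pixel; `renormalize_pixel` is a hypothetical name used only to show the arithmetic, not quac's wrapper.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def renormalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Map one RGB pixel from [-1, 1] to ImageNet-normalized values."""
    out = []
    for x, m, s in zip(rgb, mean, std):
        x01 = (x + 1.0) / 2.0      # [-1, 1] -> [0, 1]
        out.append((x01 - m) / s)  # per-channel standardization
    return tuple(out)


# A StarGAN value of 0.0 becomes 0.5 in [0, 1] before standardization.
print(renormalize_pixel((0.0, 0.0, 0.0)))
```

On whole image tensors, the same computation is usually done at once by broadcasting per-channel `mean` and `std` over the height and width dimensions.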

Inference from random latents

The StarGAN can take its style from one of two sources. The first, and simplest, is a random latent vector from which the style is created.
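Assuming the model follows the StarGAN v2 design, a random latent vector z drawn from a standard normal distribution is passed, together with the target domain, through a mapping network that outputs a style code, so each new z gives a different style for the same target. The `latent_dim` and `style_dim` arguments of `load_stargan` presumably set these two lengths. In the sketch below, the learned mapping network is replaced by a hypothetical `toy_mapping` (one fixed random linear map per domain), purely to show the shapes and the role of each parameter; it is not the network quac loads.

```python
import random

LATENT_DIM = 16   # length of the random latent vector z
STYLE_DIM = 64    # length of the style code
NUM_DOMAINS = 5   # number of classes the StarGAN was trained on


def make_toy_mapping(seed=0):
    """Hypothetical stand-in for a mapping network: one random linear map per domain."""
    rng = random.Random(seed)
    return [
        [[rng.gauss(0.0, 1.0) for _ in range(LATENT_DIM)] for _ in range(STYLE_DIM)]
        for _ in range(NUM_DOMAINS)
    ]


def toy_mapping(weights, z, domain):
    """Map latent z to a style code for the given target domain."""
    return [sum(w * zi for w, zi in zip(row, z)) for row in weights[domain]]


weights = make_toy_mapping()
rng = random.Random(42)
z1 = [rng.gauss(0.0, 1.0) for _ in range(LATENT_DIM)]
z2 = [rng.gauss(0.0, 1.0) for _ in range(LATENT_DIM)]
s1 = toy_mapping(weights, z1, domain=1)
s2 = toy_mapping(weights, z2, domain=1)
print(len(s1))   # 64
print(s1 != s2)  # different latents give different styles for the same target
```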

Loading the StarGAN

```python
from quac.generate import load_stargan

latent_model_checkpoint_dir = Path("/path/to/directory/holding/the/stargan/checkpoints")

# These values must match the configuration the StarGAN was trained with
inference_model = load_stargan(
    latent_model_checkpoint_dir,
    img_size=img_size,
    input_dim=1,
    style_dim=64,
    latent_dim=16,
    num_domains=5,
    checkpoint_iter=100000,
    kind="latent",
)
```
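The keyword arguments passed to `load_stargan` describe the trained architecture, so the same values are needed again when loading the reference-based model. One way to keep them in a single place is a small configuration object, sketched here with a hypothetical `StarGANConfig` dataclass; only `load_stargan` and its keyword names come from the example above, and the `check_target` guard is an illustrative addition.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class StarGANConfig:
    """Hypothetical container for the load_stargan keyword arguments used above."""
    img_size: int = 224
    input_dim: int = 1
    style_dim: int = 64
    latent_dim: int = 16
    num_domains: int = 5
    checkpoint_iter: int = 100000

    def check_target(self, target):
        """Raise if the target class index is outside the trained domains."""
        if not 0 <= target < self.num_domains:
            raise ValueError(f"target {target} not in [0, {self.num_domains})")


config = StarGANConfig()
config.check_target(1)
kwargs = asdict(config)
print(sorted(kwargs))
# With quac installed, both models could then share the same settings:
# latent_model = load_stargan(latent_model_checkpoint_dir, **kwargs, kind="latent")
# reference_model = load_stargan(latent_model_checkpoint_dir, **kwargs, kind="reference")
```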

Running the image generation

Finally, we can run the image generation.

```python
from quac.generate import get_counterfactual
from torchvision.utils import save_image
from tqdm import tqdm

output_directory = Path("/path/to/output/latent/source_class/target_class/")
output_directory.mkdir(parents=True, exist_ok=True)  # save_image needs an existing directory

for x, name in tqdm(dataset):
    xcf = get_counterfactual(
        classifier,
        inference_model,
        x,
        target=1,  # index of the target class
        kind="latent",
        device=device,
        max_tries=10,
        batch_size=10,
    )
    # For example, you can save the images here
    save_image(xcf, output_directory / name)
```
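Not every generated image is classified as the target, so the generation has to search. A common approach, and presumably what `max_tries` and `batch_size` control, is to generate a batch of candidates, keep one that the classifier assigns to the target class, and retry otherwise. The sketch below shows that sample-and-filter loop in plain Python, with hypothetical `generate` and `classify` callables standing in for the StarGAN and the classifier; it is not quac's implementation of `get_counterfactual`.

```python
import random


def find_counterfactual(x, target, generate, classify, max_tries=10, batch_size=10):
    """Sample-and-filter search for a counterfactual of x (sketch).

    generate(x, target, n) -> list of n candidate images (hypothetical)
    classify(image) -> predicted class index (hypothetical)
    Returns the first candidate predicted as `target`, or None if none is found.
    """
    for _ in range(max_tries):
        for candidate in generate(x, target, batch_size):
            if classify(candidate) == target:
                return candidate
    return None


# Toy stand-ins: an "image" is a single number, and the classifier thresholds it.
rng = random.Random(0)


def toy_generate(x, target, n):
    return [x + rng.uniform(-1.0, 1.0) for _ in range(n)]


def toy_classify(value):
    return 1 if value > 0.5 else 0


cf = find_counterfactual(0.0, target=1, generate=toy_generate, classify=toy_classify)
print(cf)
```

Returning `None` rather than raising leaves it to the caller to decide whether to skip, log, or retry an image for which no counterfactual was found.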

Inference using a reference dataset

The alternative is to take the style from an image of the target class, which the StarGAN's StyleEncoder turns into a style. The structure of the code is similar to the above, with a few key differences.

Generating the reference dataset

First, we load the reference images, that is, the images of the target class.

```python
reference_data_directory = Path("/path/to/directory/holding/the/data/target_class")
reference_dataset = load_data(reference_data_directory, img_size, grayscale=False)
```
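In reference mode, each candidate's style comes from an image of the target class, so the reference dataset acts as a pool of style sources. A simple way to draw a batch of references is uniform sampling with replacement, sketched below with a hypothetical `sample_references` helper over any indexable dataset of `(image, name)` items; how quac actually selects references may differ.

```python
import random


def sample_references(dataset_ref, batch_size, rng=None):
    """Draw batch_size reference items uniformly at random, with replacement."""
    if rng is None:
        rng = random.Random()
    indices = [rng.randrange(len(dataset_ref)) for _ in range(batch_size)]
    return [dataset_ref[i] for i in indices]


# Toy stand-in for reference_dataset: (image, name) pairs.
toy_reference = [(f"image_{i}", f"ref_{i}.png") for i in range(3)]
batch = sample_references(toy_reference, batch_size=10, rng=random.Random(0))
print(len(batch))  # 10
```

Sampling with replacement means a batch can be larger than the reference set itself, as in this toy example with 3 references and a batch of 10.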

Loading the StarGAN

This time, we create a ReferenceInferenceModel by passing `kind="reference"`.

```python
inference_model = load_stargan(
    latent_model_checkpoint_dir,  # same StarGAN checkpoints as above
    img_size=img_size,
    input_dim=1,
    style_dim=64,
    latent_dim=16,
    num_domains=5,
    checkpoint_iter=100000,
    kind="reference",
)
```

Running the image generation

Finally, we combine the two: we change the `kind` passed to `get_counterfactual` and give it the reference dataset to draw styles from.

```python
from quac.generate import get_counterfactual
from torchvision.utils import save_image
from tqdm import tqdm

output_directory = Path("/path/to/output/reference/source_class/target_class/")
output_directory.mkdir(parents=True, exist_ok=True)  # save_image needs an existing directory

for x, name in tqdm(dataset):
    xcf = get_counterfactual(
        classifier,
        inference_model,
        x,
        target=1,
        kind="reference",  # Change the kind of inference being done
        dataset_ref=reference_dataset,  # Add the reference dataset
        device=device,
        max_tries=10,
        batch_size=10,
    )
    # For example, you can save the images here
    save_image(xcf, output_directory / name)
```