How to generate images from a pre-trained network
Defining the dataset
We generate images one source-target pair at a time, so we point the data loader at the subdirectory holding the source class of interest. In the example below, we use the validation data, and our source class is class 0, which has no Diabetic Retinopathy.
```python
from pathlib import Path
from quac.generate import load_data

img_size = 224
data_directory = Path("/path/to/directory/holding/the/data/source_class")
dataset = load_data(data_directory, img_size, grayscale=False)
```
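Based on the generation loops later in this guide, the dataset yields `(image, filename)` pairs; a quick sanity check might look like this (a sketch, assuming that interface):

```python
# Peek at one sample: the dataset yields (image, filename) pairs,
# as used in the generation loops below.
x, name = next(iter(dataset))
print(x.shape, name)  # e.g. torch.Size([3, 224, 224]) for an RGB image
```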
Loading the classifier
Next, we load the pre-trained classifier and wrap it in the correct pre-processing step. The classifier is expected to be saved as a TorchScript checkpoint, which lets us use it without redefining the Python class it was created from.
The wrapper also re-normalizes images into the range the classifier expects. We assume the inputs come from the StarGAN trained with QuAC, so their values lie in [-1, 1]; in this example, the pre-trained classifier expects ImageNet-normalized images.
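Concretely, that re-normalization maps [-1, 1] back to [0, 1] and then standardizes with the classifier's statistics. A minimal sketch of the transform (our reading of what the wrapper does, not the exact quac implementation):

```python
import torch

def renormalize(x: torch.Tensor, mean, std) -> torch.Tensor:
    # Map a [-1, 1] image to the classifier's expected range.
    mean = torch.tensor(mean).view(-1, 1, 1)
    std = torch.tensor(std).view(-1, 1, 1)
    x = (x + 1.0) / 2.0      # [-1, 1] -> [0, 1]
    return (x - mean) / std  # e.g. ImageNet statistics
```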
Finally, we define the device and whether to put the classifier in eval mode.
```python
import torch

from quac.generate import load_classifier

classifier_checkpoint = Path("/path/to/classifier/checkpoint.pt")

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

classifier = load_classifier(
    classifier_checkpoint, mean=mean, std=std, eval=True, device=device
)
```
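Once wrapped, the classifier can be fed images straight from the generator's [-1, 1] range. A quick check (a sketch; the number of output classes depends on your checkpoint):

```python
with torch.no_grad():
    dummy = torch.rand(1, 3, img_size, img_size, device=device) * 2 - 1  # values in [-1, 1]
    scores = classifier(dummy)
print(scores.shape)  # (1, num_classes); 5 classes in this example
```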
Inference from random latents
The StarGAN model used to generate images can draw its style from two sources. The first and simplest is a random latent vector.
Loading the StarGAN
```python
from quac.generate import load_stargan

latent_model_checkpoint_dir = Path("/path/to/directory/holding/the/stargan/checkpoints")

inference_model = load_stargan(
    latent_model_checkpoint_dir,
    img_size=224,
    input_dim=1,
    style_dim=64,
    latent_dim=16,
    num_domains=5,
    checkpoint_iter=100000,
    kind="latent",
)
```
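Under the hood, latent inference follows the usual StarGAN recipe: a mapping network turns a random latent into a style code for the target domain, and the generator applies that style to the input image. A toy sketch of the idea (illustrative stand-ins only, not the quac internals):

```python
import torch
import torch.nn as nn

latent_dim, style_dim, num_domains = 16, 64, 5

# Toy stand-in: one style head per domain, as in a StarGAN mapping network.
mapping = nn.ModuleList(
    [nn.Linear(latent_dim, style_dim) for _ in range(num_domains)]
)

z = torch.randn(1, latent_dim)  # random latent
style = mapping[1](z)           # style code for target domain 1, shape (1, 64)
```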
Running the image generation
Finally, we can run the image generation.
```python
from quac.generate import get_counterfactual
from torchvision.utils import save_image
from tqdm import tqdm

output_directory = Path("/path/to/output/latent/source_class/target_class/")
output_directory.mkdir(parents=True, exist_ok=True)

for x, name in tqdm(dataset):
    xcf = get_counterfactual(
        classifier,
        inference_model,
        x,
        target=1,
        kind="latent",
        device=device,
        max_tries=10,
        batch_size=10,
    )
    # For example, you can save the images here
    save_image(xcf, output_directory / name)
```
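Optionally, you can verify that a counterfactual actually lands in the target class before saving it. A sketch, assuming `xcf` is a single `(C, H, W)` image as the `save_image` call above suggests:

```python
with torch.no_grad():
    prediction = classifier(xcf.unsqueeze(0).to(device)).argmax(dim=1)
print(prediction.item())  # should match the target, here 1
```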
Inference using a reference dataset
Alternatively, the StarGAN can generate the style from an image of the target class, using its StyleEncoder. Although the structure is similar to the above, there are a few key differences.
Generating the reference dataset
First, we need to load the reference images.
```python
reference_data_directory = Path("/path/to/directory/holding/the/data/target_class")
reference_dataset = load_data(reference_data_directory, img_size, grayscale=False)
```
Loading the StarGAN
This time, we create a `ReferenceInferenceModel` by setting `kind="reference"`.
```python
inference_model = load_stargan(
    latent_model_checkpoint_dir,
    img_size=224,
    input_dim=1,
    style_dim=64,
    latent_dim=16,
    num_domains=5,
    checkpoint_iter=100000,
    kind="reference",
)
```
Running the image generation
Finally, we combine the two by changing the `kind` in our counterfactual generation and passing it the reference dataset.
```python
from torchvision.utils import save_image

output_directory = Path("/path/to/output/reference/source_class/target_class/")
output_directory.mkdir(parents=True, exist_ok=True)

for x, name in tqdm(dataset):
    xcf = get_counterfactual(
        classifier,
        inference_model,
        x,
        target=1,
        kind="reference",               # Change the kind of inference being done
        dataset_ref=reference_dataset,  # Add the reference dataset
        device=device,
        max_tries=10,
        batch_size=10,
    )
    # For example, you can save the images here
    save_image(xcf, output_directory / name)
```
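To eyeball the results, it can help to save each source image next to its counterfactual. A sketch using torchvision's grid utilities (the `comparison_` filename is just an example):

```python
import torch
from torchvision.utils import make_grid

grid = make_grid(torch.stack([x, xcf]), nrow=2, normalize=True)
save_image(grid, output_directory / f"comparison_{name}")
```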