Discriminative attribution from Counterfactuals
Now that we have generated images, we will refine them into counterfactuals using discriminative attribution. Remember that although the conversion network is trained to keep as much of the image fixed as possible, it is not perfect. This means that the generated image may still differ from the query image in regions that did not need to change. Luckily, we have a classifier that can help us identify and keep only the changes that are necessary for its decision.
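To see why a plain pixel difference is not enough, here is a minimal numpy sketch. It assumes the query and generated images are same-shaped float arrays scaled to [0, 1]. The `changed_regions` helper and its `threshold` value are hypothetical and not part of QuAC. They flag every pixel that changed by more than the threshold, regardless of whether that change matters to the classifier.

```python
import numpy as np


def changed_regions(query, generated, threshold=0.1):
    """Return a boolean mask of pixels where `generated` differs from `query`.

    Hypothetical helper: both inputs are assumed to be float arrays of the
    same shape in [0, 1]; `threshold` is an illustrative value, not a QuAC default.
    """
    if query.shape != generated.shape:
        raise ValueError("query and generated must have the same shape")
    return np.abs(generated - query) > threshold


# Toy 4x4 example: the "generated" image changes two pixels, one of them only slightly.
query = np.zeros((4, 4))
generated = query.copy()
generated[1, 1] = 0.8   # a large change
generated[2, 3] = 0.05  # a small change, below the threshold
mask = changed_regions(query, generated)
print(int(mask.sum()))  # 1
```

A mask like this only measures how large each change is, not whether the classifier relies on it. That is the gap the classifier-based attribution below is meant to fill.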
The first thing that we want to do is load the classifier.
```python
from quac.generate import load_classifier

classifier_checkpoint = "path/to/classifier/checkpoint"

classifier = load_classifier(
    checkpoint_path=classifier_checkpoint
)
```
Next, we will define the attributions that we want to use. In this tutorial, we will use Discriminative Integrated Gradients, which uses the generated image as its baseline. For comparison, we will also use Vanilla Integrated Gradients, which uses a black image as its baseline. These attributions identify the regions of the image that are most important for the classifier's decision. Later, in the evaluation tutorial, we will process these attributions into masks and finally obtain our counterfactuals.
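The role of the baseline can be illustrated with a minimal sketch of integrated gradients on a toy logistic model whose gradient is known in closed form. This is not QuAC's `DIntegratedGradients` implementation. The weights, the four-pixel "images", and the `integrated_gradients` helper are all hypothetical. With the generated image as the baseline, pixels that the conversion left unchanged receive exactly zero attribution. A black baseline instead spreads attribution over every non-zero pixel the model is sensitive to.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Hypothetical toy "classifier": the score of one class for a flattened 4-pixel image.
w = np.array([2.0, -1.0, 0.0, 3.0])
b = -0.5


def f(x):
    return sigmoid(x @ w + b)


def grad_f(x):
    # Closed-form gradient of sigmoid(w.x + b) with respect to x.
    s = f(x)
    return s * (1.0 - s) * w


def integrated_gradients(x, baseline, steps=200):
    """Midpoint Riemann approximation of integrated gradients along the
    straight path from `baseline` to `x`."""
    alphas = (np.arange(steps) + 0.5) / steps
    grads = np.array([grad_f(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)


query = np.array([0.2, 0.9, 0.5, 0.1])
generated = np.array([0.2, 0.9, 0.7, 0.8])  # only pixels 2 and 3 were changed

vanilla = integrated_gradients(query, np.zeros_like(query))  # black-image baseline
discriminative = integrated_gradients(query, generated)      # generated-image baseline

print("vanilla:       ", vanilla)
print("discriminative:", discriminative)
```

In both cases the attributions sum, up to the Riemann approximation error, to the difference in model output between the input and its baseline (the completeness property of integrated gradients). The discriminative version therefore explains the change in score between the query and the generated image, rather than between the query and an all-black image.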
```python
# Parameters
attribution_directory = "path/to/store/attributions"

# Defining attributions; the dictionary keys name each attribution method
from quac.attribution import (
    DIntegratedGradients,
    VanillaIntegratedGradients,
    AttributionIO,
)
from torchvision import transforms

attributor = AttributionIO(
    attributions={
        "discriminative_ig": DIntegratedGradients(classifier),
        "vanilla_ig": VanillaIntegratedGradients(classifier),
    },
    output_directory=attribution_directory,
)
```
Next, we need to make sure that the images are processed the way the classifier expects. Here, we simply define a set of torchvision transforms, which we will pass to the attributor when we run it. Keep in mind that if you processed your data in a certain way when training your classifier, you will need to use the same processing here.
```python
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Grayscale(),
        transforms.Resize(128),
        transforms.Normalize(0.5, 0.5),
    ]
)
```
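As a sanity check on what this pipeline does to pixel values, the sketch below re-implements the relevant arithmetic in numpy (it does not call torchvision). For 8-bit images, `ToTensor` scales values from 0..255 to 0.0..1.0, and `Normalize(0.5, 0.5)` then computes `(x - 0.5) / 0.5`, so the classifier receives values in [-1, 1]. The helper names are hypothetical.

```python
import numpy as np


def to_unit_range(img_uint8):
    # Mimics the scaling ToTensor applies to 8-bit images: 0..255 -> 0.0..1.0
    return img_uint8.astype(np.float32) / 255.0


def normalize(x, mean=0.5, std=0.5):
    # Same arithmetic as transforms.Normalize(0.5, 0.5): (x - mean) / std
    return (x - mean) / std


pixels = np.array([0, 128, 255], dtype=np.uint8)
out = normalize(to_unit_range(pixels))
print(out)  # approximately [-1.0, 0.0039, 1.0]
```

If your classifier was trained on inputs in a different range, the attributions would be computed on inputs it never saw during training, which is why the processing has to match.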
Finally, let’s run the attributions.
```python
data_directory = "path/to/data/directory"
counterfactual_directory = "path/to/counterfactual/directory"

# This will run attributions and store all of the results in the output_directory
# Shows a progress bar
attributor.run(
    source_directory=data_directory,
    counterfactual_directory=counterfactual_directory,
    transform=transform,
)
```
If you look into the attribution_directory, you should see a set of attributions. They will be organized in the following way:
```
attribution_directory/
    attribution_method_name/
        source_class/
            target_class/
                image_name.npy
```
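Because every attribution is a plain `.npy` file, you can inspect them with numpy alone. Here is a minimal sketch, assuming exactly the layout above. `load_attributions` is a hypothetical helper, not part of QuAC, and the class and image names in the usage example are made up.

```python
import tempfile
from pathlib import Path

import numpy as np


def load_attributions(attribution_directory):
    """Collect every saved attribution into a dict keyed by
    (method, source_class, target_class, image_name).

    Hypothetical helper; assumes the layout
    attribution_directory/method/source_class/target_class/image_name.npy
    """
    root = Path(attribution_directory)
    attributions = {}
    for path in sorted(root.glob("*/*/*/*.npy")):
        method, source_class, target_class = path.relative_to(root).parts[:3]
        attributions[(method, source_class, target_class, path.stem)] = np.load(path)
    return attributions


# Usage example on a throwaway directory with one fake attribution.
with tempfile.TemporaryDirectory() as tmp:
    out_dir = Path(tmp) / "discriminative_ig" / "class_0" / "class_1"
    out_dir.mkdir(parents=True)
    np.save(out_dir / "image_001.npy", np.ones((1, 128, 128)))
    loaded = load_attributions(tmp)

print(list(loaded))  # [('discriminative_ig', 'class_0', 'class_1', 'image_001')]
```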
In the next tutorial, we will use these attributions to generate masks and finally get our counterfactuals.