Evaluation
Now that we have the generated images and the attributions, let’s run the QuAC evaluation to obtain and score our final counterfactuals.
Once again, we will need the classifier: the quality of a counterfactual is judged by the change it causes in the classifier’s output. We also want to use the correct classifier transform, so we will define it here.
from torchvision import transforms

from quac.generate import load_classifier

classifier_checkpoint = "path/to/classifier/checkpoint"
classifier = load_classifier(
    checkpoint_path=classifier_checkpoint
)

# Defining the transform
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Grayscale(),
        transforms.Resize(128),
        transforms.Normalize(0.5, 0.5),
    ]
)
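As a quick check on what the last step of this transform does: `transforms.Normalize(0.5, 0.5)` subtracts a mean of 0.5 and divides by a standard deviation of 0.5, so pixel values that `ToTensor` scaled into [0, 1] end up in [-1, 1]. The plain-Python sketch below reproduces only that arithmetic. The `normalize` helper is illustrative and is not part of QuAC or torchvision.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Same arithmetic as transforms.Normalize(mean, std), for a single value."""
    return (pixel - mean) / std


# ToTensor scales 8-bit pixel values into [0, 1] ...
scaled = [value / 255 for value in (0, 255)]
# ... and Normalize(0.5, 0.5) then maps that range onto [-1, 1].
normalized = [normalize(value) for value in scaled]
print(normalized)  # [-1.0, 1.0]
```

If your classifier was trained with different preprocessing, the transform used here should mirror it, since the scores depend on the classifier seeing inputs in the range it expects.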
Let’s run the evaluation for the discriminative version of integrated gradients. If you have been following the tutorials exactly, you will also have run the vanilla version of integrated gradients. Just swap the attribution directory in the code below to run the vanilla version instead!
# Defining processors and evaluators
from quac.evaluation import Processor, Evaluator

attribution_method_name = "discriminative_ig"
data_directory = "path/to/data/directory"
counterfactual_directory = "path/to/counterfactual/directory"
attribution_directory = "path/to/attributions/directory/" + attribution_method_name

evaluator = Evaluator(
    classifier,
    source_directory=data_directory,
    counterfactual_directory=counterfactual_directory,
    attribution_directory=attribution_directory,
    transform=transform
)
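To evaluate the vanilla version instead, only the attribution directory needs to change. The sketch below builds one directory per method name. Note that `"vanilla_ig"` is an assumed name used for illustration, so substitute whatever name you gave the vanilla attributions when you generated them.

```python
# Assumed layout: one sub-directory per attribution method.
# "vanilla_ig" is a placeholder name, not confirmed by this tutorial.
attribution_root = "path/to/attributions/directory/"
method_names = ["discriminative_ig", "vanilla_ig"]

attribution_directories = {
    name: attribution_root + name for name in method_names
}

for name, directory in attribution_directories.items():
    # Pass `directory` as attribution_directory when building the Evaluator.
    print(name, "->", directory)
```

Keeping the method name in a single variable also keeps the attribution and report directories in step with each other.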
To run the evaluation, we will need to define a processor. This is the object that takes an attribution map and turns it into a binary attribution mask. QuAC provides a default processor that will work for most cases. Finally, we’ll need a place to store the results.
report_directory = "path/to/store/reports/" + attribution_method_name

# Run QuAC evaluation on your attribution and store a report
report = evaluator.quantify(processor=Processor())
# The report will be stored based on the processor's name, which is "default" by default
report.store(report_directory)
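For intuition about the two ideas used above, a processor that turns an attribution map into a binary mask and a score based on the change in the classifier’s output, here is a toy sketch in plain Python. It is not QuAC’s `Processor` or `Evaluator`: the thresholding rule, the `blend` step that copies masked pixels from the counterfactual into the source, and the `toy_classifier` are all assumptions chosen to keep the example small.

```python
def to_binary_mask(attribution, threshold):
    """Mark pixels whose absolute attribution is at least `threshold` times the peak."""
    peak = max(abs(v) for row in attribution for v in row)
    if peak == 0:
        # No attribution anywhere: return an empty mask.
        return [[0 for _ in row] for row in attribution]
    return [[1 if abs(v) >= threshold * peak else 0 for v in row] for row in attribution]


def blend(source, counterfactual, mask):
    """Copy the masked pixels from the counterfactual into the source image."""
    return [
        [c if m else s for s, c, m in zip(s_row, c_row, m_row)]
        for s_row, c_row, m_row in zip(source, counterfactual, mask)
    ]


def toy_classifier(image):
    """Stand-in for a real classifier: mean intensity as the score for one class."""
    pixels = [v for row in image for v in row]
    return sum(pixels) / len(pixels)


# Tiny 2x2 "images" and a made-up attribution map, for illustration only.
source = [[0.0, 0.0], [0.0, 0.0]]
counterfactual = [[1.0, 0.2], [0.0, 1.0]]
attribution = [[0.9, 0.1], [0.0, 0.6]]

mask = to_binary_mask(attribution, threshold=0.5)
hybrid = blend(source, counterfactual, mask)

# Change in classifier output caused by swapping in only the masked pixels.
change = toy_classifier(hybrid) - toy_classifier(source)
# Fraction of the image that had to change to get there.
mask_fraction = sum(v for row in mask for v in row) / (len(mask) * len(mask[0]))

print(mask, change, mask_fraction)
```

In this toy setup, swapping in half of the pixels moves the toy classifier’s output by 0.5. The real evaluation works with your trained classifier and the stored attributions instead.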
Done! Now all that is left is to go through the report and visualize your final results.