Visualizing the results

In this tutorial, we will show you how to visualize the results of the attribution and evaluation steps. Make sure to modify the paths to the reports and the classifier to match your setup!

Obtaining the QuAC curves

Let’s start by loading the reports obtained in the previous step.

from quac.report import Report

report_directory = "/path/to/report/directory/"
reports = {
    method: Report(name=method)
    for method in [
        "DDeepLift",
        "DIntegratedGradients",
    ]
}

for method, report in reports.items():
    report.load(report_directory + method + "/default.json")
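If you would rather not concatenate strings, the same paths can be built with `pathlib`. The sketch below only constructs the paths; it does not call QuAC or touch the file system, and it assumes the same `<method>/default.json` layout used above. Whether `report.load` accepts a `Path` object is not confirmed here, so wrapping the path in `str(...)` is the safe choice.

```python
from pathlib import Path

report_directory = Path("/path/to/report/directory")
methods = ["DDeepLift", "DIntegratedGradients"]

# One report file per method, following the <method>/default.json layout
report_paths = {
    method: report_directory / method / "default.json" for method in methods
}

for method, path in report_paths.items():
    # str(path) gives a plain string suitable for APIs that expect one
    print(method, str(path))
```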

Next, we can plot the QuAC curves for each method. This allows us to get an idea of how well each method is performing, overall.

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
for method, report in reports.items():
    report.plot_curve(ax=ax)
# Add the legend
plt.legend()
plt.show()

Choosing the best attribution method for each sample

While one attribution method may be better than another on average, the best method for a given example may differ. We therefore record the best method for each example by comparing the QuAC scores across methods.

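import pandas as pd

# One column per method, one row per example
quac_scores = pd.DataFrame(
    {method: report.quac_scores for method, report in reports.items()}
)
# Name of the best-scoring method for each example, and the score it achieved
best_methods = quac_scores.idxmax(axis=1)
best_quac_scores = quac_scores.max(axis=1)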
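To make the selection step concrete, here is a small self-contained sketch with made-up scores. The method names match the ones above, but the numbers are purely illustrative and do not come from a real report. `idxmax(axis=1)` returns, for each row, the column label holding the highest value, and `max(axis=1)` returns that value.

```python
import pandas as pd

# Hypothetical QuAC scores for three samples (illustrative numbers only)
toy_scores = pd.DataFrame(
    {
        "DDeepLift": [0.90, 0.40, 0.55],
        "DIntegratedGradients": [0.70, 0.80, 0.50],
    }
)

# Column label of the row-wise maximum: the best method per sample
toy_best_methods = toy_scores.idxmax(axis=1)
# Row-wise maximum: the score achieved by that method
toy_best_scores = toy_scores.max(axis=1)

print(toy_best_methods.tolist())
print(toy_best_scores.tolist())
```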

We’ll also want to load the classifier at this point, so we can look at the classifications of the counterfactual images.

import torch

classifier = torch.jit.load("/path/to/classifier/model.pt")

Choosing the best examples

Next, we want to choose which example to look at, given the best method for each one. To do this, we order the examples by their best QuAC score, from highest to lowest, and then pick an example by its position in that ranking.

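import numpy as np

# Indices of the examples, sorted from highest to lowest QuAC score
order = np.argsort(best_quac_scores.to_numpy())[::-1]

# For example, choose the example at position 10 in that ranking (0-indexed)
idx = 10
# Get the report for the method that performed best on that example
report = reports[best_methods.iloc[order[idx]]]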
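Because `argsort` sorts in ascending order, it is easy to end up with a ranking that is not what you intended. The sketch below uses made-up scores to show one way to obtain indices ordered from best to worst: sort ascending, then reverse the resulting index array.

```python
import numpy as np

# Hypothetical best-per-sample QuAC scores (illustrative numbers only)
toy_scores = np.array([0.2, 0.9, 0.5, 0.7])

# argsort returns the indices that sort ascending; reversing gives best-first
toy_order = np.argsort(toy_scores)[::-1]

print(toy_order)              # indices from highest to lowest score
print(toy_scores[toy_order])  # the scores in descending order
```

With this ordering, `toy_order[0]` is the index of the highest-scoring sample, `toy_order[1]` the second highest, and so on.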

We will then load that example and its counterfactual from its path, and visualize it. We also want to see the classification of both the original and the counterfactual.

from PIL import Image

# Paths to the original image and to the generated (target) image
image_path, generated_path = report.paths[order[idx]], report.target_paths[order[idx]]
image, generated_image = Image.open(image_path), Image.open(generated_path)

# Classifier predictions for the original and the generated image
prediction = report.predictions[order[idx]]
target_prediction = report.target_predictions[order[idx]]

Loading the attribution

We next want to load the attribution for the example, and visualize it.

import numpy as np

attribution_path = report.attribution_paths[order[idx]]
attribution = np.load(attribution_path)

Getting the processor

We want to see the specific mask that was optimal in this case. To do this, we will need to get the optimal threshold, and get the processor used for masking.

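from quac.evaluation import Processor

gaussian_kernel_size = 11
struc = 10
# Optimal threshold for the selected example
thresh = report.optimal_thresholds()[order[idx]]
print(thresh)
processor = Processor(gaussian_kernel_size=gaussian_kernel_size, struc=struc)

mask, _ = processor.create_mask(attribution, thresh)
# Move the channel axis last so the mask matches the (height, width, channel) image layout
rgb_mask = mask.transpose(1, 2, 0)
# zero-out the green and blue channels
# (note that the blend below then only replaces values in the red channel)
rgb_mask[:, :, 1] = 0
rgb_mask[:, :, 2] = 0
# Masked regions come from the generated image, the rest from the original
counterfactual = np.array(generated_image) / 255 * rgb_mask + np.array(image) / 255 * (1.0 - rgb_mask)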
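The last line above is a per-pixel blend: where the mask is 1 the value comes from the generated image, where it is 0 it comes from the original, and values in between mix the two. A minimal sketch on tiny synthetic arrays (not real images or a real QuAC mask) illustrates the effect.

```python
import numpy as np

# Tiny synthetic "images" with values in [0, 1], shape (height, width, channels)
original = np.zeros((2, 2, 3))
generated = np.ones((2, 2, 3))

# Hypothetical mask: select only the top-left pixel, in all channels
toy_mask = np.zeros((2, 2, 3))
toy_mask[0, 0, :] = 1.0

# Masked pixels come from `generated`, the rest from `original`
blended = generated * toy_mask + original * (1.0 - toy_mask)

print(blended[0, 0])  # pixel taken from the generated image
print(blended[1, 1])  # pixel kept from the original image
```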

Let’s also get the classifier output for the counterfactual image.

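from scipy.special import softmax

device = "cuda" if torch.cuda.is_available() else "cpu"  # assumption: use a GPU if available
classifier = classifier.to(device)
classifier_output = classifier(
    torch.tensor(counterfactual).permute(2, 0, 1).float().unsqueeze(0).to(device)
)
counterfactual_prediction = softmax(classifier_output[0].detach().cpu().numpy())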
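The `softmax` call turns the classifier's raw outputs (logits) into probabilities that sum to one. As an illustration of what it computes, here is a small NumPy sketch; the function name `softmax_np` is made up for this example and is not part of QuAC.

```python
import numpy as np


def softmax_np(logits):
    """Numerically stable softmax over the last axis."""
    logits = np.asarray(logits, dtype=float)
    # Subtracting the maximum does not change the result but avoids overflow
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exponentials = np.exp(shifted)
    return exponentials / exponentials.sum(axis=-1, keepdims=True)


# Hypothetical logits for a three-class classifier
probabilities = softmax_np([2.0, 1.0, 0.1])
print(probabilities, probabilities.sum())
```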

Visualizing the results

Finally, we can visualize the results.

fig, axes = plt.subplots(2, 4)
axes[1, 0].imshow(image)
axes[0, 0].bar(np.arange(len(prediction)), prediction)
axes[1, 1].imshow(generated_image)
axes[0, 1].bar(np.arange(len(target_prediction)), target_prediction)
axes[0, 2].bar(np.arange(len(counterfactual_prediction)), counterfactual_prediction)
axes[1, 2].imshow(counterfactual)
axes[1, 3].imshow(rgb_mask)
axes[0, 3].axis("off")
fig.suptitle(f"QuAC Score: {report.quac_scores[order[idx]]}")
plt.show()

You can now see the original image, the generated image, the counterfactual image, and the mask, with the classifier's prediction plotted above each of the three images. From here, you can choose to visualize other examples, or save the images for later use.
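As one possible way to save an image for later use, the sketch below writes a floating-point array to a PNG file with Pillow. It uses a synthetic array as a stand-in for `counterfactual` and a temporary directory as the destination; both are assumptions made for the example, so substitute your own array and output path.

```python
import os
import tempfile

import numpy as np
from PIL import Image

# Synthetic stand-in for `counterfactual`: float values in [0, 1], shape (H, W, 3)
toy_counterfactual = np.linspace(0.0, 1.0, 4 * 4 * 3).reshape(4, 4, 3)

# Convert to 8-bit before saving; clip guards against values outside [0, 1]
as_uint8 = (np.clip(toy_counterfactual, 0.0, 1.0) * 255).round().astype(np.uint8)

# Hypothetical destination: a file in a fresh temporary directory
output_path = os.path.join(tempfile.mkdtemp(), "counterfactual.png")
Image.fromarray(as_uint8).save(output_path)

# Reload to confirm that the file round-trips
reloaded = np.array(Image.open(output_path))
print(reloaded.shape, reloaded.dtype)
```

PNG is lossless, so the reloaded array matches the 8-bit array that was saved.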