The classifier
Training a classifier
The purpose of QuAC is to explain the decisions of a pre-trained classifier. As such, you need a classifier before you can use QuAC.
Note
QuAC explanations will be at most as good as the classifier that they explain. We recommend a classifier that is at least (almost) as good as a human, though QuAC is ideal for situations where the classifier is better than humans! Have a look at the appendix in the pre-print for an example of how a bad classifier might lead to bad explanations.
We will need the classifier to be a PyTorch model.
If you don’t already have a classifier, here are some tutorials describing how to train one:
Attention
Pay particular attention to the data normalization you use when training your classifier. We will be chaining networks in QuAC, and incorrect data ranges at the input of any of these networks will lead to incorrect results.
We strongly recommend making sure that your classifier expects input data that lies in [-1, 1].
While you can set the conversion network to return data in a different range, the hyper-parameters in QuAC have been optimized for [-1, 1].
Generative adversarial networks are very finicky creatures, so if you use a different range you will likely have to do extensive tuning beyond the defaults.
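As a minimal sketch of the range requirement described above, the snippet below rescales image tensors from [0, 1] to [-1, 1] and checks the result. The helper names `to_minus_one_one` and `check_range` are hypothetical and not part of QuAC, and the random batch only stands in for your own data.

```python
import torch


def to_minus_one_one(images: torch.Tensor) -> torch.Tensor:
    """Map a float tensor with values in [0, 1] to the range [-1, 1]."""
    return images * 2.0 - 1.0


def check_range(images: torch.Tensor, low: float = -1.0, high: float = 1.0) -> bool:
    """Return True if every value of the tensor lies in [low, high]."""
    return bool(images.min() >= low and images.max() <= high)


# Stand-in for a batch of images already scaled to [0, 1]
batch = torch.rand(4, 3, 32, 32)
normalized = to_minus_one_one(batch)
in_range = check_range(normalized)
```

The same mapping is what a per-channel normalization with mean 0.5 and standard deviation 0.5 produces, since (x - 0.5) / 0.5 = 2x - 1. Whatever transform you choose, apply the identical one when training the classifier and when feeding data to it later.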
Compiling to torchscript
To avoid the need to modify the code for every new type of classifier, you must convert your model to TorchScript. This is a format that includes a description of the architecture, so the model can be loaded without its original Python class definition.
Modify and run the code below to do this.
import torch

# TODO set your checkpoint paths
input_checkpoint = "path/to/pytorch/model/checkpoint"
output_checkpoint = "path/to/store/jit-compiled/checkpoint"

model = ...  # TODO create your model as you do for training

# Load the weights from your desired checkpoint
model.load_state_dict(torch.load(input_checkpoint))

# Convert your model to TorchScript and save it
torch.jit.save(torch.jit.script(model), output_checkpoint)
Every time QuAC requires a classifier checkpoint, you should now point it to the output_checkpoint.
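Before handing the compiled checkpoint to QuAC, it can be worth confirming that it reloads and reproduces the original model's outputs. The sketch below is one way to do that check; `TinyClassifier` is a hypothetical stand-in for your own model, and the temporary file path replaces your real output_checkpoint.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for your own classifier."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Linear(4, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.features(x).flatten(1))


model = TinyClassifier().eval()
# Dummy input in [-1, 1], the range recommended for the classifier
x = torch.rand(2, 1, 28, 28) * 2 - 1

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "classifier_jit.pt")
    # Same conversion step as in the tutorial code
    torch.jit.save(torch.jit.script(model), path)
    # Reload without needing the TinyClassifier class definition
    loaded = torch.jit.load(path, map_location="cpu").eval()
    with torch.no_grad():
        expected = model(x)
        actual = loaded(x)

# The compiled model should reproduce the original outputs
outputs_match = torch.allclose(expected, actual, atol=1e-5)
```

If `outputs_match` is False for your own model, the usual suspects are a model left in training mode (dropout, batch normalization) or control flow that did not script as intended.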