
Repro: Network Dissection (Bau et al., 2017)

Mark Henry

Bau et al. (2017) present network dissection, a method for automatically identifying which channels of a convolutional network detect human-interpretable visual concepts, as defined by a dataset of labeled examples of those concepts.

Broden Dataset Examples
Examples of concepts from Broden, the dataset of visual concepts used by Bau et al. (2017)

My toy implementation of this paper runs in just a minute or two on CPU. It:

  1. trains a three-layer MNIST net
  2. creates a tiny dataset of simple line concepts
  3. uses network dissection to identify detectors for these simple line concepts

Here are the simple line primitives which the implementation searches for:

Line Primitives
Each tableau is composed of four primitives.
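A minimal sketch of how such tableaus might be generated, assuming 28x28 canvases (MNIST's size) split into four 14x14 quadrants with one primitive per quadrant. The quadrant layout, the drawing rules, and the helper names `draw_primitive` and `make_tableau` are hypothetical; only the concept names come from the log output below. The sketch covers just the three straight-line concepts; curve primitives such as "NW" are omitted. Alongside each image it returns a per-concept pixel mask, which is what network dissection compares channel activations against.

```python
import numpy as np

# Hypothetical primitive set: names match the post's log output, but the
# line length, thickness, and which way each diagonal leans are assumptions.
PRIMITIVES = ["Horizontal", "Diagonal Left", "Diagonal Right"]


def draw_primitive(name: str, size: int = 14) -> np.ndarray:
    """Return a size x size boolean mask containing one line primitive."""
    m = np.zeros((size, size), dtype=bool)
    idx = np.arange(2, size - 2)  # leave a 2-pixel margin around the line
    if name == "Horizontal":
        m[size // 2, idx] = True
    elif name == "Diagonal Left":   # assumed: top-left to bottom-right
        m[idx, idx] = True
    elif name == "Diagonal Right":  # assumed: bottom-left to top-right
        m[idx[::-1], idx] = True
    else:
        raise ValueError(f"unknown primitive: {name}")
    return m


def make_tableau(rng: np.random.Generator, size: int = 28):
    """Compose one tableau from four random primitives, one per quadrant.

    Returns (image, masks): a float image and a dict mapping each concept
    name to a boolean mask of the pixels where that concept appears.
    """
    half = size // 2
    image = np.zeros((size, size), dtype=np.float32)
    masks = {name: np.zeros((size, size), dtype=bool) for name in PRIMITIVES}
    for r0 in (0, half):
        for c0 in (0, half):
            name = PRIMITIVES[rng.integers(len(PRIMITIVES))]
            prim = draw_primitive(name, half)
            masks[name][r0:r0 + half, c0:c0 + half] |= prim
            image[r0:r0 + half, c0:c0 + half] = prim
    return image, masks
```

For example, `make_tableau(np.random.default_rng(0))` yields one 28x28 image plus the ground-truth masks needed to score every channel against every concept.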

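Network dissection thresholds each channel's activation map into a binary mask and scores its overlap with each concept's pixel mask using IoU (Intersection over Union). Channels with a high IoU for a particular concept are designated detectors of that concept: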

Analyzing conv1:
Channel 1 detects Horizontal (IoU: 0.146)
Channel 5 detects Diagonal Left (IoU: 0.214)
[...]
Channel 15 detects Diagonal Right (IoU: 0.125)

Analyzing conv2:
Channel 7 detects Diagonal Right (IoU: 0.206)
[...]
Channel 27 detects Diagonal Left (IoU: 0.117)

Analyzing conv3:
Channel 5 detects Diagonal Left (IoU: 0.112)
Channel 10 detects Diagonal Left (IoU: 0.111)
[...]
Channel 53 detects NW (IoU: 0.115)
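The scoring behind a listing like this can be sketched in NumPy, following the procedure described by Bau et al.: upsample each channel's activation map to input resolution, threshold it at a per-channel top quantile computed over the whole dataset, then compute a dataset-wide IoU against each concept's pixel mask. The defaults (top 0.5% of activations, IoU > 0.04) are the values reported in the paper; the thresholds used in this toy repro may differ, and the function name `dissect` and its input layout are assumptions. Nearest-neighbour upsampling stands in for the paper's bilinear interpolation.

```python
import numpy as np


def dissect(acts, concept_masks, quantile=0.005, iou_threshold=0.04):
    """Sketch of network dissection for a single layer.

    acts: float array (N, C, h, w), one layer's activations over N tableaus.
    concept_masks: dict mapping concept name -> bool array (N, H, W),
        where H and W are integer multiples of h and w.
    Returns {channel: (concept, iou)} for channels whose best IoU
    exceeds iou_threshold.
    """
    _, n_channels, h, w = acts.shape
    H, W = next(iter(concept_masks.values())).shape[1:]
    # Nearest-neighbour upsampling to input resolution (a simplification;
    # the paper upsamples with bilinear interpolation).
    up = acts.repeat(H // h, axis=2).repeat(W // w, axis=3)

    detectors = {}
    for k in range(n_channels):
        a_k = up[:, k]
        # Per-channel threshold T_k: keep the top `quantile` fraction of
        # this channel's activations across the whole dataset.
        t_k = np.quantile(a_k, 1.0 - quantile)
        m_k = a_k > t_k
        # Dataset-wide IoU between the channel's mask and each concept mask.
        scores = {
            name: (m_k & L).sum() / max((m_k | L).sum(), 1)
            for name, L in concept_masks.items()
        }
        best = max(scores, key=scores.get)
        if scores[best] > iou_threshold:
            detectors[k] = (best, float(scores[best]))
    return detectors
```

The quantile caps how large a channel's mask can be, so a channel only reaches a high IoU when its most active pixels line up with a concept that covers a similar share of the image. Since thin line primitives occupy only a small fraction of each tableau, the quantile may need tuning for a dataset like this one.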

Because the dataset and the network are so simple, only about 20 detectors are identified. Let's look at one detector from each of the three layers.

In the figures below, yellow highlights where the true feature is present, green highlights where the detector correctly activates on the feature, and purple indicates a false positive.
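This color coding can be reproduced from two boolean masks: the concept's ground-truth mask and the detector's thresholded activation mask. A minimal sketch; the exact RGB values, the helper name `overlay`, and the choice to let green overwrite yellow where both apply are assumptions.

```python
import numpy as np

# Assumed colors; the post's figures may use different shades.
YELLOW = (255, 220, 0)   # true feature present
GREEN = (0, 200, 0)      # detector fires on the true feature
PURPLE = (150, 0, 200)   # detector fires where there is no feature


def overlay(gt: np.ndarray, pred: np.ndarray) -> np.ndarray:
    """Color-code a detector mask against a ground-truth concept mask.

    gt, pred: boolean arrays of the same (H, W) shape.
    Returns an (H, W, 3) uint8 RGB image; unmarked pixels stay black.
    """
    rgb = np.zeros(gt.shape + (3,), dtype=np.uint8)
    rgb[gt] = YELLOW            # all ground-truth pixels start yellow
    rgb[gt & pred] = GREEN      # true positives overwrite yellow
    rgb[pred & ~gt] = PURPLE    # false positives
    return rgb
```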

This channel in the earliest layer is identified as a detector of horizontal lines. Note the small false positives on bottom edges.

Conv1 Channel 1 Performance

This channel in the middle layer reacts to "southeast curves," which curve from "south" to "east." Again, its interest is somewhat piqued by other types of diagonal lines.

Conv2 Channel 16 Performance

And finally, this channel in the third layer is a detector for "diagonal left" lines.

Conv3 Channel 10 Performance

I noticed that the curves start to be detected in layer 3, and the detectors there have almost no false positives.

On this toy setup, the technique successfully identifies, without manual inspection, which channels act as detectors for human-interpretable concepts.

The code is available on GitHub.