Note
Go to the end to download the full example code.
Dark Background
background_fill isn’t limited to plain white — it also accepts a transparent color
(e.g. (0, 0, 0, 0)), which is useful for dropping a figure onto a paper or slide without a
white box around it, or an opaque dark color for a nicer look on dark-mode pages.
Note: when using a non-white background, also set an outline color for each layer type in
color_map, and set connector_fill for the lines between nodes. Both default to plain black,
which becomes invisible against a dark or black background.

from collections import defaultdict
import matplotlib.pyplot as plt
import visualtorch
from torch import nn
# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
)

# Bright fills with light outlines. Outlines are set explicitly for every layer
# type because they would otherwise stay black and disappear on a dark background.
color_map: dict = defaultdict(dict)
color_map[nn.Conv2d]["fill"] = "#00F5FF"
color_map[nn.Conv2d]["outline"] = "#E0FFFF"
color_map[nn.BatchNorm2d]["fill"] = "#FF10F0"
color_map[nn.BatchNorm2d]["outline"] = "#FFD1FA"
color_map[nn.ReLU]["fill"] = "#FCEE09"
color_map[nn.ReLU]["outline"] = "#FFFACD"

# Single source of truth for the background: it is passed both to the renderer and
# to the matplotlib figure below. A named color is accepted by both APIs.
background = "black"

input_shape = (1, 3, 32, 32)
img = visualtorch.render(
    model,
    input_shape=input_shape,
    style="graph",
    show_neurons=False,
    color_map=color_map,
    # Connector lines also default to black, so make them visible on the dark background.
    connector_fill="white",
    background_fill=background,
)

dpi = 150  # rendered at 2x this in the final doc build (savefig.dpi=300 in conf.py)
# Match the figure facecolor to the render's background: matplotlib's default white
# facecolor would otherwise show through the tight_layout() padding as a white frame
# around the dark image, defeating the point of this example.
plt.figure(
    figsize=(img.width / dpi, img.height / dpi),
    dpi=dpi,
    facecolor=background,
)
plt.imshow(img)
plt.axis("off")
plt.tight_layout()
plt.show()