Note
Go to the end to download the full example code.
U-Net
A small U-Net: two downsampling encoder stages, a bottleneck, and two upsampling decoder stages, with a skip connection concatenating each encoder stage’s output into its corresponding decoder stage. Both skip connections are genuine bypasses (they go around the pooling/bottleneck path), so they are correctly routed above the diagram; because the two spans overlap, they are drawn nested.
Conv2d is orange, ConvTranspose2d is green, ReLU is sky blue, and MaxPool2d is reddish purple.

from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import visualtorch
from torch import nn
class UNet(nn.Module):
    """A small U-Net with 2 encoder/decoder stages and skip connections.

    The channel width doubles at each encoder stage and at the bottleneck
    (``base_channels`` -> ``2 * base_channels`` -> ``4 * base_channels``) and
    is halved again on the way up. With the default arguments the layers are
    exactly those of the original hard-coded example (3 -> 16 -> 32 -> 64 ->
    32 -> 16 -> 2).

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels produced by the final 1x1 convolution.
        base_channels: Number of feature channels of the first encoder stage.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 2, base_channels: int = 16) -> None:
        super().__init__()
        width1 = base_channels
        width2 = base_channels * 2
        width3 = base_channels * 4

        # Encoder: 3x3 convolutions keep the spatial size, each pool halves it.
        self.enc1 = nn.Sequential(nn.Conv2d(in_channels, width1, kernel_size=3, padding=1), nn.ReLU())
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = nn.Sequential(nn.Conv2d(width1, width2, kernel_size=3, padding=1), nn.ReLU())
        self.pool2 = nn.MaxPool2d(2)

        self.bottleneck = nn.Sequential(nn.Conv2d(width2, width3, kernel_size=3, padding=1), nn.ReLU())

        # Decoder: each transposed convolution doubles the spatial size. The
        # following convolution takes twice the upsampled channel count because
        # the matching encoder output is concatenated onto it in ``forward``.
        self.up2 = nn.ConvTranspose2d(width3, width2, kernel_size=2, stride=2)
        self.dec2 = nn.Sequential(nn.Conv2d(width2 * 2, width2, kernel_size=3, padding=1), nn.ReLU())
        self.up1 = nn.ConvTranspose2d(width2, width1, kernel_size=2, stride=2)
        self.dec1 = nn.Sequential(nn.Conv2d(width1 * 2, width1, kernel_size=3, padding=1), nn.ReLU())

        self.final = nn.Conv2d(width1, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, bottleneck and decoder with skip connections.

        Downsample through 2 encoder stages, then upsample back, concatenating
        each encoder stage's output into its corresponding decoder stage along
        the channel dimension.

        Args:
            x: Input batch of shape ``(N, in_channels, H, W)``. ``H`` and ``W``
                should be divisible by 4: the input is halved twice and doubled
                twice, and the upsampled maps must match the encoder maps in
                size for the concatenations to succeed.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``.
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        b = self.bottleneck(self.pool2(e2))
        # Skip connections: concatenate on dim=1 (channels).
        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.final(d1)
# Network to draw, and the input shape handed to visualtorch alongside it.
model = UNet()
input_shape = (1, 3, 64, 64)

# Per-layer-type fill colours. A defaultdict is kept so that looking up a
# layer type without an entry yields an empty style dict instead of a KeyError.
color_map: dict = defaultdict(
    dict,
    {
        nn.Conv2d: {"fill": "#E69F00"},
        nn.ConvTranspose2d: {"fill": "#009E73"},
        nn.ReLU: {"fill": "#56B4E9"},
        nn.MaxPool2d: {"fill": "#CC79A7"},
    },
)

img = visualtorch.render(
    model,
    input_shape,
    style="graph",
    show_neurons=False,
    color_map=color_map,
    layer_spacing=60,
)

# The final doc build saves at 2x this value (savefig.dpi=300 in conf.py).
dpi = 150
# Size the figure in inches so that, at `dpi`, it matches the image's pixel size.
plt.figure(
    figsize=(img.width / dpi, img.height / dpi),
    dpi=dpi,
)
plt.imshow(img)
plt.axis("off")
plt.tight_layout()
plt.show()