Note
Go to the end to download the full example code.
U-Net
This is the same small U-Net used in the graph-style example, rendered in flow style instead: the
contracting-then-expanding channel/spatial shape naturally produces the classic U-Net silhouette.
This relies on show_input (on by default): without it, the diagram’s left edge would start at
the first encoder block’s output, not the raw input, making the silhouette lopsided instead of
symmetric with the output on the right.
Conv2d is orange, ConvTranspose2d is green, ReLU is sky blue, and MaxPool2d is reddish purple.

from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import visualtorch
from torch import nn
class UNet(nn.Module):
    """A small U-Net with 2 encoder/decoder stages and skip connections.

    Channel widths go 3 -> 16 -> 32 -> 64 at the bottleneck and back down to
    16, followed by a 1x1 convolution that emits 2 output channels. Every 3x3
    convolution uses ``padding=1``, so only the pooling and transposed
    convolution layers change the spatial size.
    """

    def __init__(self) -> None:
        super().__init__()
        # Encoder: each stage is conv + ReLU, followed by a 2x max-pool.
        self.enc1 = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU())
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU())
        self.pool2 = nn.MaxPool2d(2)

        self.bottleneck = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU())

        # Decoder: each transposed conv doubles the spatial size and halves the
        # channels; the following conv takes twice that many input channels
        # because the matching encoder output is concatenated in ``forward``.
        self.up2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec2 = nn.Sequential(nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU())
        self.up1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.dec1 = nn.Sequential(nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU())

        self.final = nn.Conv2d(16, 2, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, bottleneck, and decoder with skip connections.

        Downsamples through 2 encoder stages, then upsamples back, concatenating
        each encoder stage's output into its corresponding decoder stage.

        Args:
            x: Input of shape ``(N, 3, H, W)``. ``H`` and ``W`` need to be
                multiples of 4; otherwise the upsampled maps no longer match
                the encoder outputs and the concatenations fail.

        Returns:
            Tensor of shape ``(N, 2, H, W)``.
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        b = self.bottleneck(self.pool2(e2))
        # Skip connections: concatenate along the channel dimension (dim=1).
        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.final(d1)
# Instantiate the network and describe the dummy input: one 3-channel 64x64 image.
model = UNet()
input_shape = (1, 3, 64, 64)

# Fill colour per layer type. The defaultdict creates each type's style dict
# on first access, so only the "fill" entry has to be set here.
color_map: dict = defaultdict(dict)
for layer_type, fill in (
    (nn.Conv2d, "#E69F00"),  # orange
    (nn.ConvTranspose2d, "#009E73"),  # green
    (nn.ReLU, "#56B4E9"),  # sky blue
    (nn.MaxPool2d, "#CC79A7"),  # reddish purple
):
    color_map[layer_type]["fill"] = fill

img = visualtorch.render(
    model,
    input_shape,
    style="flow",
    color_map=color_map,
    scale_xy=3,
    spacing=15,
)

dpi = 150  # rendered at 2x this in the final doc build (savefig.dpi=300 in conf.py)
# Figure size is given in inches, so divide the image's pixel dimensions by the dpi.
fig_size = (img.width / dpi, img.height / dpi)
plt.figure(figsize=fig_size, dpi=dpi)
plt.imshow(img)
plt.axis("off")
plt.tight_layout()
plt.show()