Note
Go to the end to download the full example code.
Vision Transformer (ViT)
This is the same small Vision Transformer used in the graph- and flow-style ViT examples, rendered
here in the lenet style instead. A Transformer layer's (seq_len, hidden_size) output has no real
spatial or channel structure, so hidden_size (the feature dimension, which plays the role of channels)
drives the depth of the stacked planes just as a CNN's channel count does. This keeps every layer's
appearance consistent throughout the diagram.
offset_z=1 is set here because the default (10) is tuned for the modest channel counts typical of
CNN layers. Once every Transformer layer's hidden_size or dim_feedforward also drives depth, that
default multiplies into a very wide image.
Note: VisualTorch traces the literal module-by-module computation, so the diagram shows the layers
in the order they actually execute, not the conceptual, pedagogical diagrams typical of ViT papers.
show_dimension (on by default for this style) is turned off here. With this many similarly narrow
boxes packed closely together, the shape labels overlap and become unreadable.
Conv2d is orange, MultiheadAttention is reddish purple, Linear is sky blue, LayerNorm is bluish green, and Dropout is yellow.

from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import visualtorch
from torch import nn
class VisionTransformer(nn.Module):
    """Compact Vision Transformer used as the subject of this diagram.

    Pipeline: a strided ``Conv2d`` cuts the image into non-overlapping square
    patches and projects each one to ``dim`` features, a learned positional
    embedding is added to the resulting tokens, a stack of
    ``nn.TransformerEncoderLayer`` blocks encodes them, and the token average is
    mapped to class logits by a single ``nn.Linear``.

    Args:
        img_size: Side length, in pixels, of the square image the positional
            embedding is sized for.
        patch_size: Side length of each square patch; used as both the kernel
            size and the stride of the patch-embedding convolution.
        dim: Token embedding width (the encoder's ``d_model``). The encoder's
            feed-forward width is ``4 * dim``.
        depth: Number of stacked encoder layers.
        heads: Attention heads per encoder layer (PyTorch requires ``dim`` to be
            divisible by this).
        num_classes: Number of output logits.
    """

    def __init__(
        self,
        img_size: int = 32,
        patch_size: int = 8,
        dim: int = 64,
        depth: int = 2,
        heads: int = 4,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        # The patches tile the image on a square grid, giving one token per grid cell.
        patches_per_side = img_size // patch_size
        token_count = patches_per_side * patches_per_side

        # kernel == stride, so each output position covers exactly one patch.
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        # Learned positional embedding, zero-initialised, broadcast across the batch.
        self.pos_embed = nn.Parameter(torch.zeros(1, token_count, dim))
        # The template layer passed here is cloned ``depth`` times by nn.TransformerEncoder.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim,
                nhead=heads,
                dim_feedforward=4 * dim,
                batch_first=True,
            ),
            num_layers=depth,
        )
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of 3-channel images.

        The input's patch grid must produce as many tokens as ``pos_embed``
        holds, e.g. ``(batch, 3, img_size, img_size)``. The result has shape
        ``(batch, num_classes)``.
        """
        patch_grid = self.patch_embed(x)  # (batch, dim, rows, cols)
        flat = patch_grid.flatten(2)  # (batch, dim, rows * cols)
        tokens = flat.transpose(1, 2)  # (batch, rows * cols, dim), matching batch_first=True
        tokens = tokens + self.pos_embed
        encoded = self.encoder(tokens)
        pooled = encoded.mean(dim=1)  # average over the token axis
        return self.head(pooled)
# Model under visualization, plus the shape of the single image traced through it.
model = VisionTransformer()
input_shape = (1, 3, 32, 32)

# Fill colour per layer type (Okabe-Ito colour-blind-friendly hues). Because this is a
# defaultdict, looking up an unlisted layer type yields an empty style dict; presumably
# visualtorch then falls back to its own default fill (not verified here).
color_map: dict = defaultdict(
    dict,
    {
        layer_type: {"fill": hex_colour}
        for layer_type, hex_colour in (
            (nn.Conv2d, "#E69F00"),  # orange
            (nn.MultiheadAttention, "#CC79A7"),  # reddish purple
            (nn.Linear, "#56B4E9"),  # sky blue
            (nn.LayerNorm, "#009E73"),  # bluish green
            (nn.Dropout, "#F0E442"),  # yellow
        )
    },
)

img = visualtorch.render(
    model,
    input_shape,
    style="lenet",
    color_map=color_map,
    # Layout spacing between boxes and around the whole diagram.
    spacing=80,
    padding=60,
    # Keep plane depth shallow: every Transformer layer's width also drives depth here.
    offset_z=1,
    # Shape labels collide when this many narrow boxes sit side by side.
    show_dimension=False,
)

# Size the figure canvas to the rendered image's pixel dimensions at this dpi.
dpi = 150  # rendered at 2x this in the final doc build (savefig.dpi=300 in conf.py)
plt.figure(dpi=dpi, figsize=(img.width / dpi, img.height / dpi))
plt.imshow(img)
plt.axis("off")
plt.tight_layout()
plt.show()