Vision Transformer (ViT)


This is the same small Vision Transformer used in the ViT examples for the graph and flow styles, rendered in the lenet style instead. A Transformer layer's output has shape (seq_len, hidden_size), with no spatial height/width axes, so hidden_size (the feature dimension, analogous to channels) drives the depth of the stacked planes just as a CNN layer's channel count does. As a result, every layer in the diagram is drawn by the same rule.
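A quick way to see why hidden_size plays the channel role is to follow the patch embedding's output through the same reshape used in forward. The sketch below uses plain PyTorch and the example's default sizes (32x32 input, 8x8 patches, dim=64). It checks that each token is exactly one grid cell's channel vector:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Same patch embedding as the example: 8x8 patches of a 32x32 RGB image, 64 output channels.
patch_embed = nn.Conv2d(3, 64, kernel_size=8, stride=8)
x = torch.randn(1, 3, 32, 32)

feature_map = patch_embed(x)                     # (1, 64, 4, 4): channels x 4x4 patch grid
tokens = feature_map.flatten(2).transpose(1, 2)  # (1, 16, 64): seq_len x hidden_size

print(tuple(feature_map.shape), tuple(tokens.shape))

# Token k is the channel vector at grid cell (row, col) = divmod(k, 4), so the
# conv's channel axis becomes the hidden_size axis seen by every later layer.
grid_width = feature_map.shape[-1]
for k in range(tokens.shape[1]):
    row, col = divmod(k, grid_width)
    assert torch.equal(tokens[0, k], feature_map[0, :, row, col])
```

Every Transformer layer after this point keeps that last axis: 64 in most layers, and 256 (dim_feedforward) inside the feed-forward block. That is why the last axis can stand in for depth.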

offset_z=1 is set here because the default (10) is tuned for the modest channel counts typical of CNN layers. Once every Transformer layer's hidden_size or dim_feedforward also drives depth, that offset multiplies into a very wide image.
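The following back-of-the-envelope arithmetic illustrates how the offset "multiplies". It rests on a hypothetical layout model: each layer's planes are assumed to shift sideways by offset_z pixels per unit of depth. This is not VisualTorch's actual layout code, and the depth list only approximates the layers traced for this model.

```python
# HYPOTHETICAL toy layout model, not VisualTorch's code: each layer's stacked planes
# are assumed to shift sideways by offset_z pixels per unit of depth.
def toy_skew_width(depths: list[int], offset_z: int) -> int:
    """Total sideways skew (pixels) summed over all layers under the toy model."""
    return sum(depth * offset_z for depth in depths)


# Approximate depths along this example's execution order (an assumption):
# Conv2d (64 channels); per encoder layer: self_attn, dropout1, norm1 (64),
# linear1, dropout (256 = dim_feedforward), linear2, dropout2, norm2 (64);
# then the classification head (10 classes).
encoder_layer_depths = [64, 64, 64, 256, 256, 64, 64, 64]
depths = [64] + encoder_layer_depths * 2 + [10]

for offset_z in (10, 1):
    print(f"offset_z={offset_z:>2}: {toy_skew_width(depths, offset_z)} px of skew")
```

Under this toy model the skew term scales linearly with offset_z. Dropping it from 10 to 1 shrinks that term tenfold, from 18660 px to 1866 px.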

Note: VisualTorch traces the literal module-by-module computation, so the diagram shows the sequence of layers that actually executes, not the conceptual, pedagogical diagram style used in ViT papers. show_dimension (on by default for this style) is turned off here. With this many similarly narrow boxes packed close together, the shape labels overlap and become unreadable.
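The literal execution order can also be inspected without VisualTorch, using ordinary PyTorch forward hooks. Here is a minimal sketch on a single encoder layer of the same size. The helper trace_execution is hypothetical, written only for this illustration.

```python
import torch
from torch import nn


def trace_execution(module: nn.Module, x: torch.Tensor) -> list[tuple[str, tuple[int, ...]]]:
    """Run module on x; record (submodule name, output shape) as each submodule finishes."""
    calls: list[tuple[str, tuple[int, ...]]] = []
    handles = []
    for name, sub in module.named_modules():
        if name == "":
            continue  # skip the root module itself

        def hook(mod, inputs, output, name=name):
            # MultiheadAttention returns (attn_output, attn_weights); keep the output tensor.
            out = output[0] if isinstance(output, tuple) else output
            calls.append((name, tuple(out.shape)))

        handles.append(sub.register_forward_hook(hook))
    try:
        module(x)
    finally:
        for handle in handles:
            handle.remove()  # leave the module without leftover hooks
    return calls


layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=256, batch_first=True)
layer.train()  # training mode rules out PyTorch's fused inference fast path, which skips submodules

calls = trace_execution(layer, torch.randn(1, 16, 64))
for name, shape in calls:
    print(f"{name:<10} {shape}")
```

MultiheadAttention's out_proj never appears in the trace. It is a Linear child module, but MultiheadAttention passes its weight and bias to the functional attention routine instead of calling it. The ReLU is missing for a similar reason: TransformerEncoderLayer applies it as a function (F.relu by default), not as a module.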

Conv2d is orange, MultiheadAttention is reddish purple, Linear is sky blue, LayerNorm is bluish green, and Dropout is yellow.

[Figure: the ViT model rendered in lenet style]
from collections import defaultdict

import matplotlib.pyplot as plt
import torch
import visualtorch
from torch import nn


class VisionTransformer(nn.Module):
    """A small Vision Transformer: patch embedding + positional embedding + Transformer encoder."""

    def __init__(
        self,
        img_size: int = 32,
        patch_size: int = 8,
        dim: int = 64,
        depth: int = 2,
        heads: int = 4,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim * 4, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Split into patches, embed, add positional embedding, encode, then classify."""
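        x = self.patch_embed(x)  # (B, 3, H, W) -> (B, dim, H/p, W/p), one cell per patch
        x = x.flatten(2).transpose(1, 2)  # -> (B, num_patches, dim): one token per patch
        x = x + self.pos_embed  # learned positional embedding, broadcast over the batch
        x = self.encoder(x)  # shape unchanged: (B, num_patches, dim)
        x = x.mean(dim=1)  # average-pool over tokens (no class token) -> (B, dim)
        return self.head(x)  # -> (B, num_classes)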


model = VisionTransformer()

input_shape = (1, 3, 32, 32)

# Fill colors from the Okabe-Ito colorblind-safe palette, keyed by layer class
color_map: dict = defaultdict(dict)
color_map[nn.Conv2d]["fill"] = "#E69F00"
color_map[nn.MultiheadAttention]["fill"] = "#CC79A7"
color_map[nn.Linear]["fill"] = "#56B4E9"
color_map[nn.LayerNorm]["fill"] = "#009E73"
color_map[nn.Dropout]["fill"] = "#F0E442"

img = visualtorch.render(
    model,
    input_shape,
    style="lenet",
    color_map=color_map,
    spacing=80,
    padding=60,
    show_dimension=False,
    offset_z=1,
)

dpi = 150  # rendered at 2x this in the final doc build (savefig.dpi=300 in conf.py)
plt.figure(figsize=(img.width / dpi, img.height / dpi), dpi=dpi)
plt.imshow(img)
plt.axis("off")
plt.tight_layout()
plt.show()
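The dpi comment is plain arithmetic. figsize is the image's pixel size divided by the display dpi, so saving at a higher dpi scales the pixel size by save_dpi / display_dpi. Below is a sketch with a hypothetical helper and a made-up image size. It ignores any cropping that a tight bounding box might apply when saving.

```python
def saved_pixel_size(img_w: int, img_h: int, display_dpi: int = 150, save_dpi: int = 300) -> tuple[int, int]:
    """Pixel size of a figure sized as (img_w / display_dpi, img_h / display_dpi) inches, saved at save_dpi."""
    fig_w_in = img_w / display_dpi  # same figsize formula as plt.figure(...) above
    fig_h_in = img_h / display_dpi
    return round(fig_w_in * save_dpi), round(fig_h_in * save_dpi)


# Hypothetical 1200 x 450 px render: 8 x 3 inches, saved at 300 dpi -> 2400 x 900 px (2x).
print(saved_pixel_size(1200, 450))
```

Sizing the figure from the image's own pixel dimensions keeps the diagram at its native resolution when displayed at 150 dpi, and the 300 dpi save doubles it without resampling artifacts from a mismatched aspect ratio.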
