Vision Transformer (ViT)

This example renders a small Vision Transformer. A Conv2d patch embedding splits the image into non-overlapping patches and projects each one to a vector, a learned positional embedding is added to the resulting sequence, a Transformer encoder processes the patch tokens, and a linear head classifies their mean.
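The patch-embedding step is easiest to follow through tensor shapes. The sketch below is a standalone illustration, not part of the gallery script. It reuses the model's default sizes (32x32 input, 8x8 patches, 64-dimensional embedding), and the names `grid` and `tokens` are introduced here for illustration only.

```python
import torch
from torch import nn

img_size, patch_size, dim = 32, 8, 64  # defaults of the model in this example

# kernel_size == stride, so the convolution visits each 8x8 patch exactly once.
patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)

# Stand-in for the learned positional embedding: one vector per patch.
num_patches = (img_size // patch_size) ** 2  # (32 // 8) ** 2 = 16
pos_embed = torch.zeros(1, num_patches, dim)

x = torch.randn(1, 3, img_size, img_size)
grid = patch_embed(x)                     # (1, 64, 4, 4): one vector per patch
tokens = grid.flatten(2).transpose(1, 2)  # (1, 16, 64): a sequence of patch tokens
tokens = tokens + pos_embed               # broadcast over the batch dimension

print(grid.shape, tokens.shape)
```

With these sizes the image becomes a 4x4 grid of patches, so the encoder sees a sequence of 16 tokens of width 64.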

Note: VisualTorch traces the literal module-by-module computation, so the figure shows the layers that actually execute (Conv2d, MultiheadAttention, Linear, LayerNorm, Dropout, …). It does not reproduce the conceptual diagrams common in ViT papers (a patch-grid illustration, a box labeled “Transformer Encoder x N”, and so on), which are hand-designed for exposition rather than traced from a real forward pass.
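To get a feel for what "module-by-module" means, one can record which submodules run during a forward pass of a single encoder layer. The sketch below uses plain PyTorch forward hooks. It is an illustration only and makes no claim about how VisualTorch itself performs its tracing; the names `tracked`, `executed`, and `record` are hypothetical helpers.

```python
import torch
from torch import nn

# One encoder layer with the same sizes as the model in this example.
layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=256, batch_first=True)

tracked = (nn.MultiheadAttention, nn.Linear, nn.LayerNorm, nn.Dropout)
executed = []


def record(module, args, output):
    """Forward hook: note the class name of every tracked module that runs."""
    executed.append(type(module).__name__)


# Attach the hook to every tracked submodule, keeping the handles for cleanup.
handles = [m.register_forward_hook(record) for m in layer.modules() if isinstance(m, tracked)]

layer(torch.randn(1, 16, 64))  # 16 patch tokens of width 64; training mode by default

for handle in handles:
    handle.remove()

print(executed)
```

The recorded list is the kind of flat layer sequence the figure draws, repeated once per encoder layer, in place of a single "Transformer Encoder" box.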

Conv2d is orange, MultiheadAttention is reddish purple, Linear is sky blue, LayerNorm is bluish green, and Dropout is yellow.

from collections import defaultdict

import matplotlib.pyplot as plt
import torch
import visualtorch
from torch import nn


class VisionTransformer(nn.Module):
    """A small Vision Transformer: patch embedding + positional embedding + Transformer encoder."""

    def __init__(
        self,
        img_size: int = 32,
        patch_size: int = 8,
        dim: int = 64,
        depth: int = 2,
        heads: int = 4,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim * 4, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Split into patches, embed, add positional embedding, encode, then classify."""
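        x = self.patch_embed(x)  # (B, 3, H, W) -> (B, dim, H / patch, W / patch)
        x = x.flatten(2).transpose(1, 2)  # -> (B, num_patches, dim)
        x = x + self.pos_embed  # positional embedding broadcasts over the batch
        x = self.encoder(x)
        x = x.mean(dim=1)  # average over patch tokens (no class token is used)
        return self.head(x)  # (B, num_classes)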


model = VisionTransformer()

input_shape = (1, 3, 32, 32)

color_map: dict = defaultdict(dict)
color_map[nn.Conv2d]["fill"] = "#E69F00"
color_map[nn.MultiheadAttention]["fill"] = "#CC79A7"
color_map[nn.Linear]["fill"] = "#56B4E9"
color_map[nn.LayerNorm]["fill"] = "#009E73"
color_map[nn.Dropout]["fill"] = "#F0E442"

img = visualtorch.render(model, input_shape, style="graph", show_neurons=False, color_map=color_map, layer_spacing=60)

dpi = 150  # rendered at 2x this in the final doc build (savefig.dpi=300 in conf.py)
plt.figure(figsize=(img.width / dpi, img.height / dpi), dpi=dpi)
plt.imshow(img)
plt.axis("off")
plt.tight_layout()
plt.show()
