"""Vision Transformer (ViT)
=======================================

The same small Vision Transformer as the ``graph``/``flow`` styles' ViT examples, rendered in
``lenet`` style instead. A Transformer layer's ``(seq_len, hidden_size)`` shape has no real
spatial/channel structure, so ``hidden_size`` (the feature/channel-like dimension) drives the
stacked-plane depth the same way a CNN's channel count does, keeping every layer's look
consistent throughout the diagram.

``offset_z=1`` is set here: the default (``10``) is tuned for CNN layers' typically modest
channel counts, and would produce a very wide image once every Transformer layer's
``hidden_size``/``dim_feedforward`` also drives depth.

Note: VisualTorch traces the literal module-by-module computation, so this shows the real
executed sequence of layers, not the conceptual/pedagogical diagram style used in ViT papers.
``show_dimension`` (on by default for this style) is turned off here - with this many
similarly narrow boxes packed close together, the shape labels overlap into an unreadable mess.

Conv2d is orange, MultiheadAttention is reddish purple, Linear is sky blue, LayerNorm is bluish
green, and Dropout is yellow.
"""  # noqa: D205

from collections import defaultdict

import matplotlib.pyplot as plt
import torch
import visualtorch
from torch import nn


class VisionTransformer(nn.Module):
    """Compact ViT classifier built from a patch projection, position table, encoder and head.

    A strided convolution cuts the image into non-overlapping patches and projects each one to
    ``dim`` features. A learnable positional embedding (initialised to zeros) is added to the
    resulting token sequence, the tokens run through a stack of Transformer encoder layers, and
    the token-averaged representation is mapped to class logits by a linear layer.

    Args:
        img_size: Side length of the (square) input image, in pixels.
        patch_size: Side length of each square patch; also the convolution's kernel and stride.
        dim: Embedding width used for every token.
        depth: Number of stacked encoder layers.
        heads: Number of attention heads per encoder layer.
        num_classes: Size of the output logit vector.
    """

    def __init__(
        self,
        img_size: int = 32,
        patch_size: int = 8,
        dim: int = 64,
        depth: int = 2,
        heads: int = 4,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        patches_per_side = img_size // patch_size
        # kernel == stride, so every patch is projected to ``dim`` channels exactly once.
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        # One learnable position vector per patch, broadcast over the batch dimension.
        self.pos_embed = nn.Parameter(torch.zeros(1, patches_per_side**2, dim))
        block = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=4 * dim,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(block, num_layers=depth)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images."""
        feature_map = self.patch_embed(x)
        # (batch, dim, rows, cols) -> (batch, rows * cols, dim): one token per patch.
        tokens = feature_map.flatten(start_dim=2).transpose(1, 2)
        tokens = tokens + self.pos_embed
        encoded = self.encoder(tokens)
        # Average over the token axis instead of using a dedicated class token.
        pooled = encoded.mean(dim=1)
        return self.head(pooled)


# Build the network with its default hyper-parameters.
model = VisionTransformer()

# A single 3-channel 32x32 image: (batch, channels, height, width).
input_shape = (1, 3, 32, 32)

# Fill colour per layer type. The map stays a ``defaultdict`` so that layer types without an
# entry still resolve to an (empty) style dict.
layer_fills = {
    nn.Conv2d: "#E69F00",
    nn.MultiheadAttention: "#CC79A7",
    nn.Linear: "#56B4E9",
    nn.LayerNorm: "#009E73",
    nn.Dropout: "#F0E442",
}
color_map: dict = defaultdict(dict, {layer: {"fill": fill} for layer, fill in layer_fills.items()})

img = visualtorch.render(
    model,
    input_shape,
    style="lenet",
    color_map=color_map,
    show_dimension=False,
    offset_z=1,
    spacing=80,
    padding=60,
)

# The final doc build renders at 2x this value (savefig.dpi=300 in conf.py).
dpi = 150
# Size the figure in inches so that one image pixel maps to one figure pixel at ``dpi``.
width_in, height_in = img.width / dpi, img.height / dpi
plt.figure(figsize=(width_in, height_in), dpi=dpi)
plt.imshow(img)
plt.axis("off")
plt.tight_layout()
plt.show()
