
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "usage_examples/lenet_style/plot_vit_lenet_style.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_usage_examples_lenet_style_plot_vit_lenet_style.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_usage_examples_lenet_style_plot_vit_lenet_style.py:

Vision Transformer (ViT)
=======================================

This example renders the same small Vision Transformer used in the ``graph`` and ``flow`` style ViT
examples, this time in ``lenet`` style. A Transformer layer's ``(seq_len, hidden_size)`` output has
no real spatial or channel structure, so ``hidden_size`` (the feature, or channel-like, dimension)
drives the stacked-plane depth the same way a CNN's channel count does. This keeps every layer's
look consistent throughout the diagram.

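The rule can be sketched in a few lines. This is a minimal illustration built around a hypothetical
``depth_dim`` helper that is not part of VisualTorch's API. It takes the channel dimension of a 4-D
``(N, C, H, W)`` activation and the last dimension of a 2-D or 3-D one, so the ``Conv2d`` patch
embedding and the encoder's tokens both come out at depth 64 for this model.

.. code-block:: Python

    def depth_dim(shape: tuple[int, ...]) -> int:
        """Return the channel-like size that sets a box's stacked-plane depth.

        Hypothetical helper illustrating the rule described above; it is
        not VisualTorch's internal implementation.
        """
        if len(shape) == 4:  # (N, C, H, W) conv activation: use channels
            return shape[1]
        if len(shape) in (2, 3):  # (N, features) or (N, seq_len, hidden_size): last dim
            return shape[-1]
        raise ValueError(f"unsupported rank {len(shape)} for shape {shape}")


    # Output shapes of the example ViT for a (1, 3, 32, 32) input.
    for shape in [(1, 64, 4, 4), (1, 16, 64), (1, 16, 256), (1, 10)]:
        print(shape, "->", depth_dim(shape))
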
``offset_z=1`` is set here because the default (``10``) is tuned for the modest channel counts
typical of CNN layers. Once every Transformer layer's ``hidden_size`` (or ``dim_feedforward``, 256
in this model) also drives depth, that default multiplies into a very wide image.

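To see why the default matters, here is back-of-the-envelope arithmetic under a deliberately
simplified assumption: each box's depth extent is taken to be proportional to its channel-like size
times ``offset_z``. That linear model, the hypothetical ``total_depth_extent`` helper, and the
per-module depth list (read off the call order of PyTorch's ``TransformerEncoderLayer``) are
illustrative assumptions, not VisualTorch's actual layout code.

.. code-block:: Python

    # Toy linear model, for illustration only: assume each box's depth extent is
    # (channel-like size) * offset_z. This is NOT VisualTorch's real layout formula.


    def total_depth_extent(depths: list[int], offset_z: int) -> int:
        """Summed depth extent across layers under the toy linear model."""
        return sum(d * offset_z for d in depths)


    # LeNet-5's convolution and pooling stages only: 6, 6, 16, 16 feature maps.
    lenet_depths = [6, 6, 16, 16]

    # This ViT: Conv2d patch embedding (64), then per encoder layer the attention,
    # dropout and norm (64 each), the 256-wide feed-forward Linear and its dropout,
    # then Linear, dropout and norm (64 each); finally the 10-way Linear head.
    encoder_layer = [64, 64, 64, 256, 256, 64, 64, 64]
    vit_depths = [64] + encoder_layer * 2 + [10]

    for offset_z in (10, 1):
        print(
            f"offset_z={offset_z}: "
            f"LeNet-like={total_depth_extent(lenet_depths, offset_z)}, "
            f"ViT={total_depth_extent(vit_depths, offset_z)}"
        )

The absolute numbers are meaningless under this toy model; only the ratios between them are
suggestive.
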
``show_dimension`` (on by default for this style) is turned off here: with this many similarly
narrow boxes packed close together, the shape labels overlap into an unreadable mess.

Note: VisualTorch traces the literal module-by-module computation, so the diagram shows the sequence
of layers that actually executes, not the conceptual, pedagogical diagram style used in ViT papers.

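VisualTorch's tracer is not reproduced here, but plain PyTorch can produce a similar executed-order
listing. This is a minimal sketch assuming only standard ``torch.nn`` APIs: it repeats the small ViT
so it runs on its own, attaches ``register_forward_hook`` callbacks to the layer types this example
colors, and records each call's output shape. ``trace_calls`` and ``TRACKED`` are hypothetical
names, and VisualTorch's own rendered sequence may differ in detail.

.. code-block:: Python

    import torch
    from torch import nn


    class VisionTransformer(nn.Module):
        """Same small ViT as in the example code, repeated so this block runs alone."""

        def __init__(self, img_size: int = 32, patch_size: int = 8, dim: int = 64,
                     depth: int = 2, heads: int = 4, num_classes: int = 10) -> None:
            super().__init__()
            num_patches = (img_size // patch_size) ** 2
            self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, dim))
            layer = nn.TransformerEncoderLayer(
                d_model=dim, nhead=heads, dim_feedforward=dim * 4, batch_first=True
            )
            self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
            self.head = nn.Linear(dim, num_classes)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.patch_embed(x).flatten(2).transpose(1, 2)
            x = self.encoder(x + self.pos_embed)
            return self.head(x.mean(dim=1))


    # Layer types the example colors; other modules are ignored.
    TRACKED = (nn.Conv2d, nn.MultiheadAttention, nn.Linear, nn.LayerNorm, nn.Dropout)


    def trace_calls(model: nn.Module, input_shape: tuple[int, ...]) -> list[tuple[str, tuple[int, ...]]]:
        """Record (class name, output shape) for each tracked module call, in execution order."""
        calls = []

        def hook(module, inputs, output):
            # MultiheadAttention returns (attn_output, attn_weights); keep the tensor.
            out = output[0] if isinstance(output, tuple) else output
            calls.append((type(module).__name__, tuple(out.shape)))

        handles = [m.register_forward_hook(hook) for m in model.modules() if isinstance(m, TRACKED)]
        try:
            # Stay in training mode: in eval mode PyTorch may use a fused fast path
            # for TransformerEncoderLayer that skips submodule calls (and their hooks).
            model.train()
            with torch.no_grad():
                model(torch.zeros(input_shape))
        finally:
            for h in handles:
                h.remove()
        return calls


    for name, shape in trace_calls(VisionTransformer(), (1, 3, 32, 32)):
        print(f"{name:<20} {shape}")

For this model the listing should start with the ``Conv2d`` patch embedding, then, for each of the
two encoder layers, the attention, dropout and layer-norm calls around the 256-wide feed-forward
``Linear`` pair, and end with the ``Linear`` head.
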
Conv2d is orange, MultiheadAttention is reddish purple, Linear is sky blue, LayerNorm is bluish
green, and Dropout is yellow.

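Because the color mapping exists only in prose, a legend can be drawn next to the diagram. This is a
minimal sketch using Matplotlib legend handles built from ``matplotlib.patches.Patch``;
``LAYER_COLORS`` is a hypothetical name, and its hex values are the same fills the example passes in
``color_map`` (they come from the Okabe-Ito colorblind-safe palette).

.. code-block:: Python

    import matplotlib

    matplotlib.use("Agg")  # headless backend so this runs without a display
    import matplotlib.pyplot as plt
    from matplotlib.patches import Patch

    # Same fills as the example's color_map, keyed here by display name.
    LAYER_COLORS = {
        "Conv2d": "#E69F00",
        "MultiheadAttention": "#CC79A7",
        "Linear": "#56B4E9",
        "LayerNorm": "#009E73",
        "Dropout": "#F0E442",
    }

    # One colored swatch per layer type, used as a legend proxy artist.
    handles = [Patch(facecolor=color, edgecolor="black", label=name)
               for name, color in LAYER_COLORS.items()]

    fig, ax = plt.subplots(figsize=(4, 1.5))
    ax.axis("off")
    legend = ax.legend(handles=handles, loc="center", ncol=2, frameon=False)

In the example code, the same ``handles`` list could instead be passed to
``plt.legend(handles=handles)`` right after ``plt.imshow(img)``.
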
.. GENERATED FROM PYTHON SOURCE LINES 22-89



.. image-sg:: /usage_examples/lenet_style/images/sphx_glr_plot_vit_lenet_style_001.png
   :alt: plot vit lenet style
   :srcset: /usage_examples/lenet_style/images/sphx_glr_plot_vit_lenet_style_001.png
   :class: sphx-glr-single-img





.. code-block:: Python


    from collections import defaultdict

    import matplotlib.pyplot as plt
    import torch
    import visualtorch
    from torch import nn


    class VisionTransformer(nn.Module):
        """A small Vision Transformer: patch embedding + positional embedding + Transformer encoder."""

        def __init__(
            self,
            img_size: int = 32,
            patch_size: int = 8,
            dim: int = 64,
            depth: int = 2,
            heads: int = 4,
            num_classes: int = 10,
        ) -> None:
            super().__init__()
            num_patches = (img_size // patch_size) ** 2  # 16 patches with the defaults
            # A Conv2d with stride == kernel_size embeds each non-overlapping patch.
            self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
            # Learned positional embedding, zero-initialized, one vector per patch.
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, dim))
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=dim, nhead=heads, dim_feedforward=dim * 4, batch_first=True
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
            self.head = nn.Linear(dim, num_classes)

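        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Split into patches, embed, add positional embedding, encode, then classify."""
            x = self.patch_embed(x)  # (B, 3, H, W) -> (B, dim, H // patch_size, W // patch_size)
            x = x.flatten(2).transpose(1, 2)  # -> (B, num_patches, dim): one token per patch
            x = x + self.pos_embed  # add positional embedding (broadcast over the batch)
            x = self.encoder(x)  # shape unchanged: (B, num_patches, dim)
            x = x.mean(dim=1)  # mean-pool over tokens (no CLS token here) -> (B, dim)
            return self.head(x)  # -> (B, num_classes) logits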


    model = VisionTransformer()

    input_shape = (1, 3, 32, 32)

    # Okabe-Ito colorblind-safe fills, keyed by layer class.
    color_map: dict = defaultdict(dict)
    color_map[nn.Conv2d]["fill"] = "#E69F00"
    color_map[nn.MultiheadAttention]["fill"] = "#CC79A7"
    color_map[nn.Linear]["fill"] = "#56B4E9"
    color_map[nn.LayerNorm]["fill"] = "#009E73"
    color_map[nn.Dropout]["fill"] = "#F0E442"

    img = visualtorch.render(
        model,
        input_shape,
        style="lenet",
        color_map=color_map,
        spacing=80,
        padding=60,
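        show_dimension=False,  # shape labels would overlap at this box density
        offset_z=1,  # the default (10) makes Transformer-sized depths far too wide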
    )

    dpi = 150  # rendered at 2x this in the final doc build (savefig.dpi=300 in conf.py)
    plt.figure(figsize=(img.width / dpi, img.height / dpi), dpi=dpi)
    plt.imshow(img)
    plt.axis("off")
    plt.tight_layout()
    plt.show()

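The ``dpi`` arithmetic in the plotting code sizes the figure so each image pixel maps to roughly one
figure pixel on screen, and the comment notes that ``savefig.dpi=300`` then doubles the saved
resolution. A small sketch of that arithmetic follows; ``display_geometry`` is a hypothetical helper,
the 1500 by 300 pixel image size is made up for illustration, and the small padding added by
``tight_layout`` is ignored.

.. code-block:: Python

    def display_geometry(width_px: int, height_px: int, dpi: int = 150, save_dpi: int = 300):
        """Return (figsize in inches, saved pixel size) for a 1:1 on-screen image.

        figsize = pixels / dpi, so the figure shows the image at native size at
        `dpi`; saving at `save_dpi` then scales pixels by save_dpi / dpi.
        """
        figsize = (width_px / dpi, height_px / dpi)
        saved_px = (round(figsize[0] * save_dpi), round(figsize[1] * save_dpi))
        return figsize, saved_px


    # Hypothetical rendered image size, for illustration only.
    figsize, saved_px = display_geometry(1500, 300)
    print("figsize (in):", figsize, "saved size (px):", saved_px)
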

.. _sphx_glr_download_usage_examples_lenet_style_plot_vit_lenet_style.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_vit_lenet_style.ipynb <plot_vit_lenet_style.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_vit_lenet_style.py <plot_vit_lenet_style.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_vit_lenet_style.zip <plot_vit_lenet_style.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
