
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "usage_examples/graph/plot_vit.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_usage_examples_graph_plot_vit.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_usage_examples_graph_plot_vit.py:

Vision Transformer (ViT)
=======================================

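A small Vision Transformer with three parts: a Conv2d patch embedding that splits the image into
non-overlapping patches and projects each patch to a vector, a learned positional embedding added
to the patch vectors, and a Transformer encoder over the resulting patch sequence. The encoder
output is mean-pooled over the patches and passed to a linear classification head.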
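
As a rough check of the shapes involved, the sketch below computes the patch grid and the token
sequence shape in plain Python. The helper name ``patch_sequence_shape`` is hypothetical (it is not
part of this example or of VisualTorch). It assumes a square image, no padding, and a kernel size
equal to the stride, and it uses the default values of the model in this example
(``img_size=32``, ``patch_size=8``, ``dim=64``).

.. code-block:: Python

    from __future__ import annotations


    def patch_sequence_shape(
        img_size: int = 32, patch_size: int = 8, dim: int = 64, batch: int = 1
    ) -> tuple[int, int, int]:
        """Return (batch, num_patches, dim) for a square image split into non-overlapping patches."""
        # Convolution output size without padding: floor((size - kernel) / stride) + 1,
        # here with kernel == stride == patch_size.
        per_side = (img_size - patch_size) // patch_size + 1
        num_patches = per_side * per_side
        return (batch, num_patches, dim)


    print(patch_sequence_shape())  # (1, 16, 64)

With the defaults, a 32 x 32 image yields a 4 x 4 grid, so the encoder sees a sequence of 16
tokens, each of width 64.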

Note: VisualTorch traces the literal module-by-module computation, so the figure shows the
sequence of layers that is actually executed (Conv2d, MultiheadAttention, Linear, LayerNorm,
Dropout, ...). It does not follow the conceptual diagram style used in ViT papers (a patch-grid
illustration, a box labeled "Transformer Encoder x N", and so on), which is hand-designed for
exposition rather than traced from a real forward pass.
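
To illustrate what "module by module" means, the following minimal sketch records the order in
which leaf modules run during one forward pass, using PyTorch forward hooks. It is an illustration
only and is not necessarily how VisualTorch performs its tracing; ``record_leaf_modules`` and the
toy model are hypothetical names introduced here.

.. code-block:: Python

    from __future__ import annotations

    import torch
    from torch import nn


    def record_leaf_modules(model: nn.Module, example: torch.Tensor) -> list[str]:
        """Run one forward pass and return the class names of leaf modules in execution order."""
        executed: list[str] = []
        handles = []
        for module in model.modules():
            # Leaf modules have no children; containers such as nn.Sequential are skipped.
            if next(module.children(), None) is None:
                handles.append(
                    module.register_forward_hook(
                        lambda mod, inputs, output: executed.append(type(mod).__name__)
                    )
                )
        try:
            with torch.no_grad():
                model(example)
        finally:
            # Always detach the hooks so the model is left unchanged.
            for handle in handles:
                handle.remove()
        return executed


    # Toy model: 8x8 patches of a 32x32 image give a 4x4 grid with 4 channels (64 values).
    toy = nn.Sequential(nn.Conv2d(3, 4, kernel_size=8, stride=8), nn.Flatten(), nn.Linear(4 * 4 * 4, 10))
    print(record_leaf_modules(toy, torch.zeros(1, 3, 32, 32)))  # ['Conv2d', 'Flatten', 'Linear']

Operations applied as plain tensor expressions, such as adding the positional embedding, are not
modules, so a hook-based listing like this one does not capture them.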

In the rendered graph, Conv2d nodes are orange, MultiheadAttention nodes are reddish purple,
Linear nodes are sky blue, LayerNorm nodes are bluish green, and Dropout nodes are yellow.
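
The color key corresponds to a nested mapping from layer type to style options. The sketch below
shows that pattern in plain Python, with string keys standing in for the ``nn`` classes so that it
runs without PyTorch. The hex values are the ones used in this example; ``layer_colors`` is a
hypothetical name.

.. code-block:: Python

    from collections import defaultdict

    # defaultdict(dict) creates the inner dict on first access, so a style option
    # can be assigned without initialising each layer entry first.
    layer_colors: dict = defaultdict(dict)
    layer_colors["Conv2d"]["fill"] = "#E69F00"  # orange
    layer_colors["MultiheadAttention"]["fill"] = "#CC79A7"  # reddish purple
    layer_colors["Linear"]["fill"] = "#56B4E9"  # sky blue
    layer_colors["LayerNorm"]["fill"] = "#009E73"  # bluish green
    layer_colors["Dropout"]["fill"] = "#F0E442"  # yellow

    for name, style in layer_colors.items():
        print(f"{name}: {style['fill']}")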

.. GENERATED FROM PYTHON SOURCE LINES 17-75



.. image-sg:: /usage_examples/graph/images/sphx_glr_plot_vit_001.png
   :alt: plot vit
   :srcset: /usage_examples/graph/images/sphx_glr_plot_vit_001.png
   :class: sphx-glr-single-img





.. code-block:: Python


    from collections import defaultdict

    import matplotlib.pyplot as plt
    import torch
    import visualtorch
    from torch import nn


    class VisionTransformer(nn.Module):
        """A small Vision Transformer: patch embedding + positional embedding + Transformer encoder."""

        def __init__(
            self,
            img_size: int = 32,
            patch_size: int = 8,
            dim: int = 64,
            depth: int = 2,
            heads: int = 4,
            num_classes: int = 10,
        ) -> None:
            super().__init__()
            num_patches = (img_size // patch_size) ** 2
            self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, dim))
            encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim * 4, batch_first=True)
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
            self.head = nn.Linear(dim, num_classes)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Split into patches, embed, add positional embedding, encode, then classify."""
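            x = self.patch_embed(x)  # (B, 3, H, W) -> (B, dim, H / patch, W / patch)
            x = x.flatten(2).transpose(1, 2)  # -> (B, num_patches, dim)
            x = x + self.pos_embed  # broadcast over the batch dimension
            x = self.encoder(x)  # shape unchanged: (B, num_patches, dim)
            x = x.mean(dim=1)  # mean-pool over patches -> (B, dim)
            return self.head(x)  # class logits: (B, num_classes)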


    model = VisionTransformer()

    input_shape = (1, 3, 32, 32)

    color_map: dict = defaultdict(dict)
    color_map[nn.Conv2d]["fill"] = "#E69F00"
    color_map[nn.MultiheadAttention]["fill"] = "#CC79A7"
    color_map[nn.Linear]["fill"] = "#56B4E9"
    color_map[nn.LayerNorm]["fill"] = "#009E73"
    color_map[nn.Dropout]["fill"] = "#F0E442"

    img = visualtorch.render(model, input_shape, style="graph", show_neurons=False, color_map=color_map, layer_spacing=60)

    dpi = 150  # rendered at 2x this in the final doc build (savefig.dpi=300 in conf.py)
    plt.figure(figsize=(img.width / dpi, img.height / dpi), dpi=dpi)
    plt.imshow(img)
    plt.axis("off")
    plt.tight_layout()
    plt.show()
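
The figure-size arithmetic at the end of the example can be checked in isolation. The sketch below
uses a hypothetical helper, ``figure_geometry``, and made-up image dimensions (1200 x 600 pixels)
purely for illustration; the ``save_dpi=300`` default mirrors the ``savefig.dpi=300`` setting
mentioned in the example's comment.

.. code-block:: Python

    def figure_geometry(width_px: int, height_px: int, dpi: int = 150, save_dpi: int = 300):
        """Return the figsize in inches and the pixel size of the saved figure."""
        # Dividing pixel dimensions by dpi gives a figure canvas with the same
        # pixel dimensions as the image when drawn at that dpi.
        figsize = (width_px / dpi, height_px / dpi)
        # Saving at a different dpi scales the pixel dimensions proportionally.
        saved_px = (round(figsize[0] * save_dpi), round(figsize[1] * save_dpi))
        return figsize, saved_px


    figsize, saved_px = figure_geometry(1200, 600)
    print(figsize)   # (8.0, 4.0)
    print(saved_px)  # (2400, 1200)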


.. _sphx_glr_download_usage_examples_graph_plot_vit.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_vit.ipynb <plot_vit.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_vit.py <plot_vit.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_vit.zip <plot_vit.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
