
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "usage_examples/graph/plot_residual_block.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_usage_examples_graph_plot_residual_block.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_usage_examples_graph_plot_residual_block.py:

Residual Block
=======================================

Visualization of a classic ResNet-style residual block: two Conv2d + BatchNorm2d pairs with a ReLU
between them, a plain identity shortcut added to the output of the second pair, and a final ReLU.

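In the rendered graph, Conv2d nodes are light orange (``#FFE4B5``), BatchNorm2d nodes are pale
green (``#98FB98``), and ReLU nodes are light salmon (``#FFA07A``).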

.. GENERATED FROM PYTHON SOURCE LINES 9-53



.. image-sg:: /usage_examples/graph/images/sphx_glr_plot_residual_block_001.png
   :alt: plot residual block
   :srcset: /usage_examples/graph/images/sphx_glr_plot_residual_block_001.png
   :class: sphx-glr-single-img





.. code-block:: Python


    from collections import defaultdict

    import matplotlib.pyplot as plt
    import torch
    import visualtorch
    from torch import nn


    class ResidualBlock(nn.Module):
        """A classic ResNet-style block with a plain identity shortcut."""

        def __init__(self, channels: int) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn1 = nn.BatchNorm2d(channels)
            self.relu = nn.ReLU()
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm2d(channels)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Define the forward pass, with a skip connection around conv1/bn1/relu/conv2/bn2."""
            identity = x
            out = self.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            out = out + identity
            return self.relu(out)


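    model = ResidualBlock(channels=8)

    # Input shape is (batch, channels, height, width); channels must match the block.
    input_shape = (1, 8, 16, 16)

    # Per-layer-type fill colors for the rendered nodes.
    color_map: dict = defaultdict(dict)
    color_map[nn.Conv2d]["fill"] = "#FFE4B5"
    color_map[nn.BatchNorm2d]["fill"] = "#98FB98"
    color_map[nn.ReLU]["fill"] = "#FFA07A"

    img = visualtorch.render(
        model,
        input_shape,
        style="graph",
        show_neurons=False,
        color_map=color_map,
        layer_spacing=60,
    )

    # Draw the rendered image first, then hide the axes and tighten the layout.
    plt.imshow(img)
    plt.axis("off")
    plt.tight_layout()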
    plt.show()
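
The addition ``out + identity`` in ``forward`` only works because both tensors have the same shape.
The sketch below checks the spatial part of that claim without needing PyTorch installed. It
assumes the output-size formula given in the ``nn.Conv2d`` documentation, and
``conv2d_output_size`` is a hypothetical helper written for this illustration, not part of PyTorch
or visualtorch.

.. code-block:: Python

    def conv2d_output_size(
        size: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
    ) -> int:
        """Return the output length of one spatial dimension of a 2D convolution.

        Hypothetical helper, for illustration only.
        """
        return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


    # Settings used by conv1 and conv2 above: kernel_size=3, padding=1, stride=1.
    print(conv2d_output_size(16, kernel_size=3, padding=1))  # 16

    # Without padding the feature map would shrink, and the plain shortcut
    # could no longer be added to the block output.
    print(conv2d_output_size(16, kernel_size=3))  # 14

Because both convolutions also keep the channel count at ``channels``, and BatchNorm2d and ReLU do
not change tensor shapes, the block output matches the input shape and the shortcut needs no
projection layer.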


.. _sphx_glr_download_usage_examples_graph_plot_residual_block.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_residual_block.ipynb <plot_residual_block.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_residual_block.py <plot_residual_block.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_residual_block.zip <plot_residual_block.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
