
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "usage_examples/graph/plot_residual_block.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_usage_examples_graph_plot_residual_block.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_usage_examples_graph_plot_residual_block.py:

Hiding Individual Neurons
=======================================

By default, ``graph_view`` draws a fully connected mesh between the neuron circles of every pair of
adjacent layers. That's accurate for a genuinely dense layer (e.g. Linear), but misleading for a
convolutional one: a Conv2d's real connectivity is spatially local, since each output position
sees only a small kernel window of its input, and the same kernel weights are reused at every
spatial position. A dense mesh conveys neither property. Setting ``show_neurons=False`` draws each
layer as a single box instead, which is the more honest representation for a conv-heavy model.
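
To make the contrast concrete, here is a minimal sketch in plain PyTorch (it does not use
visualtorch) that compares a dense layer with a 3x3 convolution over the same 8x16x16 feature map
used later in this example. The convolution's weight is one small kernel whose shape does not
depend on the spatial size, and nudging a single input pixel changes only a 3x3 neighborhood of
outputs in each channel. The helper ``n_params`` is hypothetical, and the kernel is filled with
ones purely so that the count is exact.

.. code-block:: Python

    import torch
    from torch import nn

    channels, height, width = 8, 16, 16
    n_features = channels * height * width  # 2048 values per sample

    # A genuinely dense layer: every input value feeds every output value.
    dense = nn.Linear(n_features, n_features)
    # A 3x3 convolution that keeps the (8, 16, 16) shape.
    conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)


    def n_params(module: nn.Module) -> int:
        """Count all trainable parameters (weights and biases)."""
        return sum(p.numel() for p in module.parameters())


    print(n_params(dense))  # 4196352 = 2048 * 2048 weights + 2048 biases
    print(n_params(conv))  # 584 = 8 * 8 * 3 * 3 weights + 8 biases
    print(tuple(conv.weight.shape))  # (8, 8, 3, 3): one kernel, reused at all 16x16 positions

    # Locality: nudge one input pixel and count how many output values move.
    with torch.no_grad():
        nn.init.ones_(conv.weight)  # every kernel tap nonzero, so the count is exact
        x = torch.zeros(1, channels, height, width)
        x_nudged = x.clone()
        x_nudged[0, 0, 8, 8] = 1.0
        changed = (conv(x_nudged) - conv(x)).abs() > 1e-3

    print(int(changed.sum()))  # 72 = 8 output channels x a 3x3 neighborhood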

The model used here is a classic ResNet-style residual block: Conv2d, BatchNorm2d and ReLU, then
Conv2d and BatchNorm2d again, with a plain identity shortcut added back in before a final ReLU.
Being conv-heavy and branching, it is a good stress test for both this setting and the
skip-connection routing of ``graph_view``. Conv2d is orange, BatchNorm2d is green, and ReLU is
sky blue.
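
The branching can also be inspected without drawing anything. As a sketch, the following
re-declares the same residual block, traces it with ``torch.fx.symbolic_trace``, and lists the
recorded operations in execution order. This is only a way to check the topology yourself; it is
not a claim about how ``graph_view`` builds its graph internally.

.. code-block:: Python

    import operator

    import torch
    from torch import fx, nn


    class ResidualBlock(nn.Module):
        """Same layers and forward pass as the residual block in this example."""

        def __init__(self, channels: int) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn1 = nn.BatchNorm2d(channels)
            self.relu = nn.ReLU()
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm2d(channels)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            identity = x
            out = self.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            return self.relu(out + identity)


    # Symbolic tracing records each submodule call and the tensor addition as graph nodes.
    traced = fx.symbolic_trace(ResidualBlock(channels=8))
    for node in traced.graph.nodes:
        print(node.op, node.target, [str(arg) for arg in node.args])

    # The shortcut is the single add node: one input is bn2's output, the other the raw input x.
    (add_node,) = [n for n in traced.graph.nodes if n.target is operator.add]
    print([str(arg) for arg in add_node.args])

The ``relu`` module shows up twice in the listing because the block reuses one ``nn.ReLU``
instance for both activations.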

.. GENERATED FROM PYTHON SOURCE LINES 16-62



.. image-sg:: /usage_examples/graph/images/sphx_glr_plot_residual_block_001.png
   :alt: plot residual block
   :srcset: /usage_examples/graph/images/sphx_glr_plot_residual_block_001.png
   :class: sphx-glr-single-img





.. code-block:: Python


    from collections import defaultdict

    import matplotlib.pyplot as plt
    import torch
    import visualtorch
    from torch import nn


    class ResidualBlock(nn.Module):
        """A classic ResNet-style block with a plain identity shortcut."""

        def __init__(self, channels: int) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn1 = nn.BatchNorm2d(channels)
            self.relu = nn.ReLU()
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm2d(channels)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Define the forward pass, with a skip connection around conv1/bn1/relu/conv2/bn2."""
            identity = x
            out = self.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            out = out + identity
            return self.relu(out)


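    model = ResidualBlock(channels=8)

    # (batch, channels, height, width); channels must match the block's 8 channels.
    input_shape = (1, 8, 16, 16)

    # defaultdict(dict) creates each layer type's style dict on first assignment.
    # The three fills are colors from the colorblind-safe Okabe-Ito palette.
    color_map: dict = defaultdict(dict)
    color_map[nn.Conv2d]["fill"] = "#E69F00"  # orange
    color_map[nn.BatchNorm2d]["fill"] = "#009E73"  # bluish green
    color_map[nn.ReLU]["fill"] = "#56B4E9"  # sky blue

    # show_neurons=False draws each layer as a single box instead of a column of neurons.
    img = visualtorch.render(
        model,
        input_shape,
        style="graph",
        show_neurons=False,
        color_map=color_map,
        layer_spacing=60,
    )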

    dpi = 150  # rendered at 2x this in the final doc build (savefig.dpi=300 in conf.py)
    plt.figure(figsize=(img.width / dpi, img.height / dpi), dpi=dpi)
    plt.imshow(img)
    plt.axis("off")
    plt.tight_layout()
    plt.show()
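
The figure-size arithmetic above makes the figure canvas, at ``dpi``, exactly as many pixels wide
and tall as the rendered image; saving at twice that dpi, as the doc build's ``savefig.dpi=300``
setting does, doubles both dimensions. A minimal sketch that checks this with the headless Agg
backend, using a blank NumPy array as a hypothetical 600x300 stand-in for the rendered diagram so
that it runs without visualtorch:

.. code-block:: Python

    import io

    import matplotlib

    matplotlib.use("Agg")  # headless backend, so the sketch runs without a display
    import matplotlib.pyplot as plt
    import numpy as np

    # Hypothetical stand-in for the rendered diagram: a blank 600 x 300 RGB image.
    img = np.full((300, 600, 3), 255, dtype=np.uint8)
    img_height, img_width = img.shape[:2]

    dpi = 150
    fig = plt.figure(figsize=(img_width / dpi, img_height / dpi), dpi=dpi)
    plt.imshow(img)
    plt.axis("off")

    # At the figure's own dpi, the canvas has the image's pixel dimensions.
    canvas_size = fig.canvas.get_width_height()
    print(canvas_size)  # (600, 300) as (width, height)

    # Saving at twice the dpi doubles both dimensions of the written PNG.
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=2 * dpi)
    buf.seek(0)
    saved_size = plt.imread(buf, format="png").shape[:2]
    print(saved_size)  # (600, 1200) as (rows, columns)
    plt.close(fig)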


.. _sphx_glr_download_usage_examples_graph_plot_residual_block.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_residual_block.ipynb <plot_residual_block.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_residual_block.py <plot_residual_block.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_residual_block.zip <plot_residual_block.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
