
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "usage_examples/graph/plot_inception_block.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_usage_examples_graph_plot_inception_block.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_usage_examples_graph_plot_inception_block.py:

Multi-Branch Merge (Inception-style)
=======================================

Visualization of an Inception-style block with four parallel branches: a 1x1 Conv2d followed by
BatchNorm2d, a 1x1-then-3x3 conv, a 1x1-then-5x5 conv, and a max-pool-then-1x1-conv. All four read
the same input, and their outputs are concatenated along the channel dimension and passed to a
shared 1x1 projection layer.

Unlike a residual block, which adds a single shortcut, this example draws several layers in
parallel within one column. Each column corresponds to one branch depth and holds one box for
every branch that is still active at that depth.

Conv2d is orange, BatchNorm2d is green, and MaxPool2d is reddish purple.
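
The column layout described above can be sketched in plain Python. This is only an illustration
of the idea, not visualtorch's actual layout algorithm: the ``branches`` dictionary and the
``columns`` list are hypothetical names, and the rendered image may contain additional nodes
(for example for the input or the concatenation) that this sketch ignores.

.. code-block:: Python

    # Layer types per branch, in execution order (taken from the example model).
    branches = {
        "branch1": ["Conv2d", "BatchNorm2d"],
        "branch2": ["Conv2d", "Conv2d"],
        "branch3": ["Conv2d", "Conv2d"],
        "branch4": ["MaxPool2d", "Conv2d"],
    }

    # One column per depth; a branch contributes a box only while it is still active.
    max_depth = max(len(layers) for layers in branches.values())
    columns = [
        [layers[depth] for layers in branches.values() if depth < len(layers)]
        for depth in range(max_depth)
    ]

    # The shared projection layer sits alone in the final column.
    columns.append(["Conv2d"])

    for index, column in enumerate(columns):
        print(index, column)
    # 0 ['Conv2d', 'Conv2d', 'Conv2d', 'MaxPool2d']
    # 1 ['BatchNorm2d', 'Conv2d', 'Conv2d', 'Conv2d']
    # 2 ['Conv2d']

Because every branch in this example is two layers deep, both branch columns hold four boxes. A
shorter branch would simply drop out of the later columns.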

.. GENERATED FROM PYTHON SOURCE LINES 14-74



.. image-sg:: /usage_examples/graph/images/sphx_glr_plot_inception_block_001.png
   :alt: plot inception block
   :srcset: /usage_examples/graph/images/sphx_glr_plot_inception_block_001.png
   :class: sphx-glr-single-img





.. code-block:: Python


    from collections import defaultdict

    import matplotlib.pyplot as plt
    import torch
    import visualtorch
    from torch import nn


    class InceptionBlock(nn.Module):
        """A simplified Inception-style block with four parallel branches."""

        def __init__(self, in_channels: int, out_1x1: int, out_3x3: int, out_5x5: int, out_pool: int) -> None:
            super().__init__()
            self.branch1 = nn.Sequential(
                nn.Conv2d(in_channels, out_1x1, kernel_size=1),
                nn.BatchNorm2d(out_1x1),
            )
            self.branch2 = nn.Sequential(
                nn.Conv2d(in_channels, out_3x3, kernel_size=1),
                nn.Conv2d(out_3x3, out_3x3, kernel_size=3, padding=1),
            )
            self.branch3 = nn.Sequential(
                nn.Conv2d(in_channels, out_5x5, kernel_size=1),
                nn.Conv2d(out_5x5, out_5x5, kernel_size=5, padding=2),
            )
            self.branch4 = nn.Sequential(
                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
                nn.Conv2d(in_channels, out_pool, kernel_size=1),
            )
            total_channels = out_1x1 + out_3x3 + out_5x5 + out_pool
            self.project = nn.Conv2d(total_channels, total_channels, kernel_size=1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Run every branch on the same input, then concatenate and project."""
            b1 = self.branch1(x)
            b2 = self.branch2(x)
            b3 = self.branch3(x)
            b4 = self.branch4(x)
            merged = torch.cat([b1, b2, b3, b4], dim=1)
            return self.project(merged)


    model = InceptionBlock(in_channels=16, out_1x1=8, out_3x3=8, out_5x5=8, out_pool=8)

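    # (batch, channels, height, width); channels must match in_channels above
    input_shape = (1, 16, 16, 16)

    # Fill color per layer type: orange, green, and reddish purple
    color_map: dict = defaultdict(dict)
    color_map[nn.Conv2d]["fill"] = "#E69F00"
    color_map[nn.BatchNorm2d]["fill"] = "#009E73"
    color_map[nn.MaxPool2d]["fill"] = "#CC79A7"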

    img = visualtorch.render(model, input_shape, style="graph", show_neurons=False, color_map=color_map, layer_spacing=60)

    dpi = 150  # rendered at 2x this in the final doc build (savefig.dpi=300 in conf.py)
    plt.figure(figsize=(img.width / dpi, img.height / dpi), dpi=dpi)
    plt.imshow(img)
    plt.axis("off")
    plt.tight_layout()
    plt.show()
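
The kernel sizes and padding values in the model above keep every branch at the input's spatial
size, which ``torch.cat`` along ``dim=1`` requires. The sketch below checks this with the standard
output-size formula for convolution and pooling layers (dilation 1), without needing PyTorch. The
helper ``conv_out_size`` and the dictionaries are hypothetical names introduced for illustration.

.. code-block:: Python

    def conv_out_size(size: int, kernel: int, padding: int = 0, stride: int = 1) -> int:
        """Output length along one spatial axis for a conv or pool layer (dilation 1)."""
        return (size + 2 * padding - kernel) // stride + 1


    # (kernel_size, padding) of each conv/pool layer per branch; all use stride 1.
    # BatchNorm2d is omitted because it does not change the tensor shape.
    branch_layers = {
        "branch1": [(1, 0)],
        "branch2": [(1, 0), (3, 1)],
        "branch3": [(1, 0), (5, 2)],
        "branch4": [(3, 1), (1, 0)],
    }
    # Output channels per branch, matching the constructor arguments in the example.
    branch_channels = {"branch1": 8, "branch2": 8, "branch3": 8, "branch4": 8}

    input_size = 16
    sizes = {}
    for name, layers in branch_layers.items():
        size = input_size
        for kernel, padding in layers:
            size = conv_out_size(size, kernel, padding)
        sizes[name] = size

    total_channels = sum(branch_channels.values())

    print(sizes)
    # {'branch1': 16, 'branch2': 16, 'branch3': 16, 'branch4': 16}
    print(total_channels)
    # 32

All four branches keep the 16x16 spatial size, so the concatenated tensor has 32 channels, which
is the ``total_channels`` value that the ``project`` layer expects.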


.. _sphx_glr_download_usage_examples_graph_plot_inception_block.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_inception_block.ipynb <plot_inception_block.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_inception_block.py <plot_inception_block.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_inception_block.zip <plot_inception_block.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
