
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "usage_examples/flow/plot_unet_flow.py"
.. LINE NUMBERS ARE GIVEN BELOW.


.. rst-class:: sphx-glr-example-title

.. _sphx_glr_usage_examples_flow_plot_unet_flow.py:

U-Net
=======================================

The same small U-Net as in the ``graph`` style example, rendered in ``flow`` style instead. Its
contracting-then-expanding shape (channels grow while the spatial size shrinks through the encoder,
then the reverse through the decoder) naturally produces the classic U-Net silhouette.
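
As a rough illustration of where the silhouette comes from, the sketch below traces the
(channels, height, width) shape through the same layers with plain arithmetic instead of running the
model. It assumes the layer behaviour of the code further down: ``padding=1`` 3x3 convolutions (and
the final 1x1 convolution) keep the spatial size, ``MaxPool2d(2)`` halves it,
``ConvTranspose2d(kernel_size=2, stride=2)`` doubles it, and ``torch.cat(..., dim=1)`` adds channel
counts. The helper names are illustrative only and are not part of ``visualtorch``.

.. code-block:: Python

    Shape = tuple[int, int, int]  # (channels, height, width), batch dimension dropped


    def conv_keep_size(shape: Shape, out_channels: int) -> Shape:
        """3x3 conv with padding=1 (or a 1x1 conv): spatial size unchanged."""
        _, h, w = shape
        return (out_channels, h, w)


    def pool2(shape: Shape) -> Shape:
        """MaxPool2d(2): halves height and width, keeps channels."""
        c, h, w = shape
        return (c, h // 2, w // 2)


    def up2(shape: Shape, out_channels: int) -> Shape:
        """ConvTranspose2d(kernel_size=2, stride=2): doubles height and width."""
        _, h, w = shape
        return (out_channels, h * 2, w * 2)


    def concat(a: Shape, b: Shape) -> Shape:
        """torch.cat along dim=1: channel counts add, spatial sizes must match."""
        if a[1:] != b[1:]:
            raise ValueError(f"skip connection size mismatch: {a} vs {b}")
        return (a[0] + b[0], a[1], a[2])


    x = (3, 64, 64)
    e1 = conv_keep_size(x, 16)
    e2 = conv_keep_size(pool2(e1), 32)
    b = conv_keep_size(pool2(e2), 64)
    d2 = conv_keep_size(concat(up2(b, 32), e2), 32)
    d1 = conv_keep_size(concat(up2(d2, 16), e1), 16)
    out = conv_keep_size(d1, 2)

    for name, shape in [("input", x), ("enc1", e1), ("enc2", e2), ("bottleneck", b),
                        ("dec2", d2), ("dec1", d1), ("output", out)]:
        print(f"{name:>10}: {shape}")

The spatial size goes 64, 32, 16, 32, 64 while the channel count peaks at 64 in the bottleneck, which
is the contracting-then-expanding profile described above.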

The symmetry relies on ``show_input`` (enabled by default). Without it, the diagram's left edge would
start at the first encoder block's *output* rather than the raw input, so the silhouette would be
lopsided instead of mirroring the output on the right.
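
A minimal sketch of that asymmetry, comparing the shapes at the two edges of the diagram (batch
dimension dropped). The shapes come from the model below; treating ``enc1``'s 16-channel output as the
first drawn column when ``show_input`` is off follows the paragraph above, and ``edge_shapes`` is a
hypothetical helper, not a ``visualtorch`` function.

.. code-block:: Python

    Shape = tuple[int, int, int]  # (channels, height, width)

    INPUT: Shape = (3, 64, 64)         # input_shape without the batch dimension
    ENC1_OUTPUT: Shape = (16, 64, 64)  # after the first Conv2d(3, 16) + ReLU
    FINAL_OUTPUT: Shape = (2, 64, 64)  # after the final Conv2d(16, 2, kernel_size=1)


    def edge_shapes(show_input: bool) -> tuple[Shape, Shape]:
        """Return the (left, right) edge shapes of the diagram for a given show_input flag."""
        left = INPUT if show_input else ENC1_OUTPUT
        return left, FINAL_OUTPUT


    for flag in (True, False):
        left, right = edge_shapes(flag)
        print(f"show_input={flag}: left {left}, right {right}")

With the input drawn, both edges are low-channel 64x64 tensors (3 and 2 channels); without it, a
16-channel tensor on the left faces a 2-channel one on the right.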

Conv2d is orange, ConvTranspose2d is green, ReLU is sky blue, and MaxPool2d is reddish purple.

.. GENERATED FROM PYTHON SOURCE LINES 13-68



.. image-sg:: /usage_examples/flow/images/sphx_glr_plot_unet_flow_001.png
   :alt: plot unet flow
   :srcset: /usage_examples/flow/images/sphx_glr_plot_unet_flow_001.png
   :class: sphx-glr-single-img





.. code-block:: Python


    from collections import defaultdict

    import matplotlib.pyplot as plt
    import torch
    import visualtorch
    from torch import nn


    class UNet(nn.Module):
        """A small U-Net with 2 encoder/decoder stages and skip connections."""

        def __init__(self) -> None:
            super().__init__()
            self.enc1 = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU())
            self.pool1 = nn.MaxPool2d(2)
            self.enc2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU())
            self.pool2 = nn.MaxPool2d(2)
            self.bottleneck = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU())
            self.up2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
            self.dec2 = nn.Sequential(nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU())
            self.up1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
            self.dec1 = nn.Sequential(nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU())
            self.final = nn.Conv2d(16, 2, kernel_size=1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Downsample through 2 encoder stages, then upsample back, concatenating each
            encoder stage's output into its corresponding decoder stage.
            """  # noqa: D205
            e1 = self.enc1(x)
            e2 = self.enc2(self.pool1(e1))
            b = self.bottleneck(self.pool2(e2))
            d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
            d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
            return self.final(d1)


    model = UNet()

    input_shape = (1, 3, 64, 64)

    color_map: dict = defaultdict(dict)
    color_map[nn.Conv2d]["fill"] = "#E69F00"
    color_map[nn.ConvTranspose2d]["fill"] = "#009E73"
    color_map[nn.ReLU]["fill"] = "#56B4E9"
    color_map[nn.MaxPool2d]["fill"] = "#CC79A7"

    img = visualtorch.render(model, input_shape, style="flow", color_map=color_map, scale_xy=3, spacing=15)

    dpi = 150  # rendered at 2x this in the final doc build (savefig.dpi=300 in conf.py)
    plt.figure(figsize=(img.width / dpi, img.height / dpi), dpi=dpi)
    plt.imshow(img)
    plt.axis("off")
    plt.tight_layout()
    plt.show()
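
The figure-sizing lines above set ``figsize`` in inches as the image's pixel size divided by ``dpi``,
so at ``dpi=150`` the figure has the same pixel dimensions as the rendered image. Per the code
comment, the documentation build saves figures at ``savefig.dpi=300``, which doubles each dimension.
A quick check of that arithmetic, using a hypothetical 900x600 image in place of the real render (it
tracks the figure size only, not how much of the figure the image occupies after ``tight_layout``):

.. code-block:: Python

    def saved_size(img_width: int, img_height: int, figure_dpi: int, save_dpi: int) -> tuple[int, int]:
        """Pixel size of a saved figure whose figsize is (width / figure_dpi, height / figure_dpi)."""
        fig_w_in = img_width / figure_dpi   # figure width in inches
        fig_h_in = img_height / figure_dpi  # figure height in inches
        return round(fig_w_in * save_dpi), round(fig_h_in * save_dpi)


    print(saved_size(900, 600, figure_dpi=150, save_dpi=150))  # (900, 600): one image pixel per figure pixel
    print(saved_size(900, 600, figure_dpi=150, save_dpi=300))  # (1800, 1200): the 2x doc-build output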


.. _sphx_glr_download_usage_examples_flow_plot_unet_flow.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_unet_flow.ipynb <plot_unet_flow.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_unet_flow.py <plot_unet_flow.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_unet_flow.zip <plot_unet_flow.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
