
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "usage_examples/lenet_style/plot_unet_lenet_style.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_usage_examples_lenet_style_plot_unet_lenet_style.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_usage_examples_lenet_style_plot_unet_lenet_style.py:

U-Net
=======================================

This example renders the same small U-Net used in the ``graph`` and ``flow`` style examples,
this time in ``lenet`` style. The spatial size contracts and then expands while the channel
count grows and then shrinks, which naturally produces the classic U-Net silhouette.

``Conv2d`` is orange (``#E69F00``), ``ConvTranspose2d`` is green (``#009E73``), ``ReLU`` is
sky blue (``#56B4E9``), and ``MaxPool2d`` is reddish purple (``#CC79A7``).

.. GENERATED FROM PYTHON SOURCE LINES 10-65



.. image-sg:: /usage_examples/lenet_style/images/sphx_glr_plot_unet_lenet_style_001.png
   :alt: plot unet lenet style
   :srcset: /usage_examples/lenet_style/images/sphx_glr_plot_unet_lenet_style_001.png
   :class: sphx-glr-single-img





.. code-block:: Python


    from collections import defaultdict

    import matplotlib.pyplot as plt
    import torch
    import visualtorch
    from torch import nn


    class UNet(nn.Module):
        """A small U-Net with 2 encoder/decoder stages and skip connections."""

        def __init__(self) -> None:
            super().__init__()
            self.enc1 = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU())
            self.pool1 = nn.MaxPool2d(2)
            self.enc2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU())
            self.pool2 = nn.MaxPool2d(2)
            self.bottleneck = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU())
            self.up2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
            self.dec2 = nn.Sequential(nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU())
            self.up1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
            self.dec1 = nn.Sequential(nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU())
            self.final = nn.Conv2d(16, 2, kernel_size=1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Downsample through 2 encoder stages, then upsample back, concatenating each
            encoder stage's output into its corresponding decoder stage.
            """  # noqa: D205
            e1 = self.enc1(x)
            e2 = self.enc2(self.pool1(e1))
            b = self.bottleneck(self.pool2(e2))
            d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
            d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
            return self.final(d1)


    model = UNet()

    input_shape = (1, 3, 64, 64)

    color_map: dict = defaultdict(dict)
    color_map[nn.Conv2d]["fill"] = "#E69F00"
    color_map[nn.ConvTranspose2d]["fill"] = "#009E73"
    color_map[nn.ReLU]["fill"] = "#56B4E9"
    color_map[nn.MaxPool2d]["fill"] = "#CC79A7"

    img = visualtorch.render(model, input_shape, style="lenet", color_map=color_map, spacing=25)

    dpi = 150  # rendered at 2x this in the final doc build (savefig.dpi=300 in conf.py)
    plt.figure(figsize=(img.width / dpi, img.height / dpi), dpi=dpi)
    plt.imshow(img)
    plt.axis("off")
    plt.tight_layout()
    plt.show()
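
The U-shaped silhouette follows from how each stage changes the tensor shape. As a rough
sketch, the shapes can be traced with plain arithmetic, without PyTorch or visualtorch. The
helper names below (``conv_same``, ``pool``, ``upsample``, ``concat``, ``trace_unet_shapes``)
are hypothetical and exist only for this illustration. They assume the layer settings of the
``UNet`` class above: 3x3 convolutions with ``padding=1``, ``MaxPool2d(2)``, and transposed
convolutions with ``kernel_size=2`` and ``stride=2``.

.. code-block:: Python

    def conv_same(shape, out_channels):
        """3x3 conv with padding=1 (or a 1x1 conv): only the channel count changes."""
        _, height, width = shape
        return (out_channels, height, width)


    def pool(shape):
        """MaxPool2d(2): halve height and width, keep the channel count."""
        channels, height, width = shape
        return (channels, height // 2, width // 2)


    def upsample(shape, out_channels):
        """ConvTranspose2d(kernel_size=2, stride=2): double height and width."""
        _, height, width = shape
        return (out_channels, height * 2, width * 2)


    def concat(a, b):
        """Channel-wise concatenation (torch.cat on dim=1, batch axis omitted here)."""
        if a[1:] != b[1:]:
            raise ValueError(f"spatial sizes differ: {a[1:]} vs {b[1:]}")
        return (a[0] + b[0], a[1], a[2])


    def trace_unet_shapes(input_shape=(3, 64, 64)):
        """Return (stage, (channels, height, width)) pairs for the small U-Net."""
        e1 = conv_same(input_shape, 16)
        p1 = pool(e1)
        e2 = conv_same(p1, 32)
        p2 = pool(e2)
        b = conv_same(p2, 64)
        u2 = upsample(b, 32)
        c2 = concat(u2, e2)  # skip connection from enc2
        d2 = conv_same(c2, 32)
        u1 = upsample(d2, 16)
        c1 = concat(u1, e1)  # skip connection from enc1
        d1 = conv_same(c1, 16)
        out = conv_same(d1, 2)
        return [
            ("enc1", e1), ("pool1", p1), ("enc2", e2), ("pool2", p2),
            ("bottleneck", b), ("up2", u2), ("cat2", c2), ("dec2", d2),
            ("up1", u1), ("cat1", c1), ("dec1", d1), ("final", out),
        ]


    for stage, shape in trace_unet_shapes():
        print(f"{stage:<10} {shape}")

For the ``(1, 3, 64, 64)`` input used above, the spatial size goes 64 -> 32 -> 16 -> 32 -> 64,
and the bottleneck holds the most channels (64) at the smallest spatial size (16x16).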


.. _sphx_glr_download_usage_examples_lenet_style_plot_unet_lenet_style.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_unet_lenet_style.ipynb <plot_unet_lenet_style.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_unet_lenet_style.py <plot_unet_lenet_style.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_unet_lenet_style.zip <plot_unet_lenet_style.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
