
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "usage_examples/layered/plot_basic_custom.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_usage_examples_layered_plot_basic_custom.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_usage_examples_layered_plot_basic_custom.py:

Basic Custom
=======================================

Visualization of a basic custom model.

.. GENERATED FROM PYTHON SOURCE LINES 6-54



.. image-sg:: /usage_examples/layered/images/sphx_glr_plot_basic_custom_001.png
   :alt: plot basic custom
   :srcset: /usage_examples/layered/images/sphx_glr_plot_basic_custom_001.png
   :class: sphx-glr-single-img





.. code-block:: Python


    import matplotlib.pyplot as plt
    import torch
    import torch.nn.functional as func
    import visualtorch
    from torch import nn


    # Example of a simple CNN model
    class SimpleCNN(nn.Module):
        """Small convolutional network used as the subject of the visualization.

        Three 3x3 convolutions (3 -> 16 -> 32 -> 64 channels, padding 1), each
        followed by ReLU and 2x2 max pooling, feed two fully connected layers
        that end in 10 output values.
        """

        def __init__(self) -> None:
            super().__init__()
            # Convolutional stages, registered in the order they are applied.
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
            self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
            # fc1 expects 64 feature maps of 28x28, i.e. a 224x224 input halved
            # three times by the poolings in ``forward``.
            self.fc1 = nn.Linear(in_features=64 * 28 * 28, out_features=128)
            self.fc2 = nn.Linear(in_features=128, out_features=10)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Apply the conv/ReLU/pool stages, flatten, and run both linear layers."""
            for conv in (self.conv1, self.conv2, self.conv3):
                x = func.max_pool2d(func.relu(conv(x)), kernel_size=2, stride=2)
            # Collapse everything except the batch dimension before the linear layers.
            flat = x.view(x.size(0), -1)
            hidden = func.relu(self.fc1(flat))
            return self.fc2(hidden)


    # Instantiate the example network that will be visualized.
    model = SimpleCNN()

    # One 3-channel 224x224 input: matches conv1's 3 input channels and the
    # 64 * 28 * 28 features fc1 expects after three 2x2 poolings.
    input_shape = (1, 3, 224, 224)

    # Render the model for this input shape; the result is shown with imshow below.
    # NOTE(review): legend=True presumably adds a layer legend -- confirm in the visualtorch docs.
    img = visualtorch.layered_view(model, input_shape=input_shape, legend=True)

    # Hide the axes and tighten the layout, then display the rendered image.
    plt.axis("off")
    plt.tight_layout()
    plt.imshow(img)
    plt.show()


.. _sphx_glr_download_usage_examples_layered_plot_basic_custom.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_basic_custom.ipynb <plot_basic_custom.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_basic_custom.py <plot_basic_custom.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_basic_custom.zip <plot_basic_custom.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
