
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "usage_examples/lenet_style/plot_multi_output_shapes_lenet_style.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_usage_examples_lenet_style_plot_multi_output_shapes_lenet_style.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_usage_examples_lenet_style_plot_multi_output_shapes_lenet_style.py:

Multi-Output Layer Shapes
=======================================

A leaf layer's ``forward()`` doesn't always return a single tensor. ``nn.LSTM``, for example, returns
``(output, (h_n, c_n))``: the full sequence of hidden states, plus the final hidden and cell
states. ``show_dimension`` defaults to ``True`` for this style, and the shape of every one of those
output tensors is printed, not just the first. That matters when a downstream layer consumes ``h_n``
instead of ``output``, as this model's ``nn.Linear`` does: the tensor it actually reads still has its
shape shown in the diagram.
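
To check those shapes independently of ``visualtorch``, here is a minimal sketch that calls a bare
``nn.LSTM`` and ``nn.Linear`` with the same sizes as the model below (an illustration, not part of the
original example). Note that ``batch_first=True`` only applies to ``output``: PyTorch keeps ``h_n`` and
``c_n`` in ``(num_layers * num_directions, batch, hidden_size)`` layout, which is why the model calls
``h_n.squeeze(0)`` before its ``nn.Linear``.

.. code-block:: Python

    import torch
    from torch import nn

    # Same sizes as the example's SequenceClassifier.
    lstm = nn.LSTM(input_size=64, hidden_size=128, batch_first=True)
    fc = nn.Linear(128, 10)

    x = torch.zeros(1, 7, 64)  # (batch, seq_len, input_size), matching input_shape

    with torch.no_grad():
        output, (h_n, c_n) = lstm(x)
        logits = fc(h_n.squeeze(0))  # squeeze drops the layer/direction axis: (1, 128)

    print(tuple(output.shape))  # (1, 7, 128): one hidden state per time step, batch first
    print(tuple(h_n.shape))     # (1, 1, 128): (num_layers * num_directions, batch, hidden_size)
    print(tuple(c_n.shape))     # (1, 1, 128): same layout as h_n
    print(tuple(logits.shape))  # (1, 10)

    # For a single-layer, unidirectional LSTM, h_n equals the last time step of output.
    print(torch.allclose(output[:, -1, :], h_n[0]))  # True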

``low_dim_orientation="x"`` is set here purely to make the illustration more compact and easier to
read. The default (``"z"``) stacks a layer's feature count as individual depth slices, which for the
64 input features and the ``hidden_size=128`` states here draws as a long diagonal chain rather than
a compact block. Both orientations render the same shapes correctly; the choice is purely about
legibility, like this example's ``spacing`` and ``padding`` values below.
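
To compare the two orientations directly, the sketch below renders the same model twice, side by
side, changing nothing but ``low_dim_orientation``. It reuses only the ``visualtorch.render``
arguments already shown in this example; the subplot grid and figure size are arbitrary choices for
illustration.

.. code-block:: Python

    from collections import defaultdict

    import matplotlib.pyplot as plt
    import torch
    import visualtorch
    from torch import nn


    class SequenceClassifier(nn.Module):
        """Same architecture as the example: LSTM(64 -> 128), then Linear(128 -> 10) on h_n."""

        def __init__(self) -> None:
            super().__init__()
            self.lstm = nn.LSTM(input_size=64, hidden_size=128, batch_first=True)
            self.fc = nn.Linear(128, 10)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            _output, (h_n, _c_n) = self.lstm(x)
            return self.fc(h_n.squeeze(0))


    color_map: dict = defaultdict(dict)
    color_map[nn.LSTM]["fill"] = "#56B4E9"
    color_map[nn.Linear]["fill"] = "#009E73"

    model = SequenceClassifier()

    # Render once per orientation: "z" is the default, "x" is what this example uses.
    images = {
        orientation: visualtorch.render(
            model,
            (1, 7, 64),
            style="lenet",
            color_map=color_map,
            spacing=250,
            padding=60,
            low_dim_orientation=orientation,
        )
        for orientation in ("z", "x")
    }

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (orientation, img) in zip(axes, images.items()):
        ax.imshow(img)
        ax.set_title(f'low_dim_orientation="{orientation}"')
        ax.axis("off")
    plt.tight_layout()
    plt.show()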

In the diagram, ``nn.LSTM`` is filled sky blue (``#56B4E9``) and ``nn.Linear`` bluish green (``#009E73``).

.. GENERATED FROM PYTHON SOURCE LINES 18-65



.. image-sg:: /usage_examples/lenet_style/images/sphx_glr_plot_multi_output_shapes_lenet_style_001.png
   :alt: plot multi output shapes lenet style
   :srcset: /usage_examples/lenet_style/images/sphx_glr_plot_multi_output_shapes_lenet_style_001.png
   :class: sphx-glr-single-img





.. code-block:: Python


    from collections import defaultdict

    import matplotlib.pyplot as plt
    import torch
    import visualtorch
    from torch import nn


    class SequenceClassifier(nn.Module):
        """A small LSTM-based classifier that reads the final hidden state, not the full sequence."""

        def __init__(self) -> None:
            super().__init__()
            self.lstm = nn.LSTM(input_size=64, hidden_size=128, batch_first=True)
            self.fc = nn.Linear(128, 10)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Run the LSTM, then classify from its final hidden state (h_n), not its full sequence output."""
            _output, (h_n, _c_n) = self.lstm(x)
            return self.fc(h_n.squeeze(0))


    model = SequenceClassifier()

    input_shape = (1, 7, 64)

    color_map: dict = defaultdict(dict)
    color_map[nn.LSTM]["fill"] = "#56B4E9"
    color_map[nn.Linear]["fill"] = "#009E73"

    img = visualtorch.render(
        model,
        input_shape,
        style="lenet",
        color_map=color_map,
        spacing=250,
        padding=60,
        low_dim_orientation="x",
    )

    # figsize is in inches, so at this dpi the figure canvas matches the image's pixel size;
    # the doc build saves figures at 2x this dpi (savefig.dpi=300 in conf.py)
    dpi = 150
    plt.figure(figsize=(img.width / dpi, img.height / dpi), dpi=dpi)
    plt.imshow(img)
    plt.axis("off")
    plt.tight_layout()
    plt.show()

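The ``figsize`` arithmetic at the end of the script works as follows: dividing the image's pixel size
by ``dpi`` gives inches, so at ``dpi=150`` the figure canvas is exactly the rendered image's size, and
at the doc build's ``savefig.dpi=300`` the same canvas is saved at twice that. A minimal sketch with a
hypothetical 1200 x 480 image (not the actual render size); ``figure_canvas_px`` is a hypothetical
helper, and it describes the whole figure canvas before any cropping the build may apply.

.. code-block:: Python

    def figure_canvas_px(img_width: int, img_height: int, dpi: int = 150, savefig_dpi: int = 300):
        """Return (figsize in inches, canvas px at dpi, canvas px at savefig_dpi)."""
        figsize = (img_width / dpi, img_height / dpi)  # inches, as passed to plt.figure
        on_screen = (round(figsize[0] * dpi), round(figsize[1] * dpi))
        saved = (round(figsize[0] * savefig_dpi), round(figsize[1] * savefig_dpi))
        return figsize, on_screen, saved


    # Hypothetical image size, for illustration only.
    figsize, on_screen, saved = figure_canvas_px(1200, 480)
    print(figsize)    # (8.0, 3.2)
    print(on_screen)  # (1200, 480)
    print(saved)      # (2400, 960)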

.. _sphx_glr_download_usage_examples_lenet_style_plot_multi_output_shapes_lenet_style.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_multi_output_shapes_lenet_style.ipynb <plot_multi_output_shapes_lenet_style.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_multi_output_shapes_lenet_style.py <plot_multi_output_shapes_lenet_style.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_multi_output_shapes_lenet_style.zip <plot_multi_output_shapes_lenet_style.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
