
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "usage_examples/flow/plot_multi_input_flow.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_usage_examples_flow_plot_multi_input_flow.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_usage_examples_flow_plot_multi_input_flow.py:

Multi-Input Model (Siamese-style)
=======================================

This example renders the same two-branch model as the ``graph`` style's multi-input example,
this time in ``flow`` style. The model has an image branch (Conv2d followed by global average
pooling) and a tabular-vector branch (a small MLP), which are concatenated before a shared head.

Pass a tuple of per-tensor shapes as ``input_shape`` instead of a single flat shape, with one
shape per positional argument of ``forward()``, in order. Each input gets its own box at the
start of the diagram. With ``show_dimension=True``, a column that contains parallel boxes prints
all of their shapes together (joined by ``/``) so that the labels do not overlap.

The color map fills ``Conv2d`` layers in orange (``#E69F00``) and ``Linear`` layers in sky blue
(``#56B4E9``).
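
As a minimal sketch of the one-shape-per-argument rule, the snippet below pairs each entry of
``input_shape`` with the matching ``forward()`` parameter using only the standard library.
``TwoInputStub`` and ``pair_shapes_with_arguments`` are hypothetical names introduced for
illustration; they are not part of visualtorch, and the length check is only an assumption
about how one might validate the tuple before rendering.

.. code-block:: Python

    import inspect


    class TwoInputStub:
        """Hypothetical stand-in with the same forward() signature as the model below."""

        def forward(self, image, vector):
            raise NotImplementedError


    def pair_shapes_with_arguments(module, input_shape):
        """Map each positional forward() argument name to its shape, in order."""
        # forward is a bound method here, so `self` is already excluded.
        names = list(inspect.signature(module.forward).parameters)
        if len(names) != len(input_shape):
            msg = f"forward() takes {len(names)} inputs, got {len(input_shape)} shapes"
            raise ValueError(msg)
        return dict(zip(names, input_shape))


    pairs = pair_shapes_with_arguments(TwoInputStub(), ((1, 3, 16, 16), (1, 10)))
    print(pairs)

The first shape is matched to ``image`` and the second to ``vector``, which is the ordering
the example relies on.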

.. GENERATED FROM PYTHON SOURCE LINES 15-76



.. image-sg:: /usage_examples/flow/images/sphx_glr_plot_multi_input_flow_001.png
   :alt: plot multi input flow
   :srcset: /usage_examples/flow/images/sphx_glr_plot_multi_input_flow_001.png
   :class: sphx-glr-single-img





.. code-block:: Python


    from collections import defaultdict

    import matplotlib.pyplot as plt
    import torch
    import visualtorch
    from torch import nn


    class SiameseNet(nn.Module):
        """A two-branch model: an image branch and a tabular-vector branch, merged by concatenation."""

        def __init__(self) -> None:
            super().__init__()
            self.image_branch = nn.Sequential(
                nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
            )
            self.vector_branch = nn.Sequential(
                nn.Linear(10, 8),
                nn.ReLU(),
                nn.Linear(8, 8),
                nn.ReLU(),
            )
            self.head = nn.Linear(16, 4)

        def forward(self, image: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
            """Run each branch on its own input tensor, then concatenate and project."""
            image_features = self.image_branch(image)
            vector_features = self.vector_branch(vector)
            merged = torch.cat([image_features, vector_features], dim=1)
            return self.head(merged)


    model = SiameseNet()

    # One shape per forward() argument: (image, vector).
    input_shape = ((1, 3, 16, 16), (1, 10))

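    # Per-layer-type styling; defaultdict(dict) lets nested keys be assigned directly.
    color_map: dict = defaultdict(dict)
    color_map[nn.Conv2d]["fill"] = "#E69F00"  # orange
    color_map[nn.Linear]["fill"] = "#56B4E9"  # sky blue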

    img = visualtorch.render(
        model,
        input_shape,
        style="flow",
        color_map=color_map,
        scale_xy=3,
        spacing=15,
        show_dimension=True,
    )

    dpi = 150  # rendered at 2x this in the final doc build (savefig.dpi=300 in conf.py)
    plt.figure(figsize=(img.width / dpi, img.height / dpi), dpi=dpi)
    plt.imshow(img)
    plt.axis("off")
    plt.tight_layout()
    plt.show()

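To sanity-check the dimensions that appear in the diagram, the shapes can be traced by hand.
The sketch below does the arithmetic in plain Python, assuming the standard Conv2d output-size
formula with dilation 1; ``conv2d_output_hw`` is a hypothetical helper, and the exact label
text drawn by visualtorch is not reproduced here.

.. code-block:: Python

    def conv2d_output_hw(height, width, kernel_size, stride, padding):
        """Spatial size after a Conv2d with a square kernel and dilation 1."""
        out_h = (height + 2 * padding - kernel_size) // stride + 1
        out_w = (width + 2 * padding - kernel_size) // stride + 1
        return out_h, out_w


    batch, _, height, width = (1, 3, 16, 16)  # image input shape
    out_h, out_w = conv2d_output_hw(height, width, kernel_size=3, stride=1, padding=1)

    conv_shape = (batch, 8, out_h, out_w)  # Conv2d(3, 8, ...) keeps 16x16 with padding=1
    pooled_shape = (batch, 8, 1, 1)  # AdaptiveAvgPool2d(1)
    image_features = (batch, 8)  # Flatten
    vector_features = (batch, 8)  # Linear(10, 8) then Linear(8, 8)
    merged = (batch, image_features[1] + vector_features[1])  # cat along dim=1
    assert merged[1] == 16  # must equal in_features of the head
    output = (batch, 4)  # Linear(16, 4)

    print(conv_shape, pooled_shape, merged, output)

Each branch contributes 8 features, which is why the head is declared as ``nn.Linear(16, 4)``.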

.. _sphx_glr_download_usage_examples_flow_plot_multi_input_flow.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_multi_input_flow.ipynb <plot_multi_input_flow.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_multi_input_flow.py <plot_multi_input_flow.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_multi_input_flow.zip <plot_multi_input_flow.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
