Multi-Output Layer Shapes

Multi-Output Layer Shapes

A leaf layer’s forward() doesn’t always return a single tensor: nn.LSTM returns (output, (h_n, c_n)) — the full sequence of hidden states, plus the final hidden and cell states. With show_dimension=True, the shape of every one of those output tensors is printed, not just the first, so when a downstream layer consumes h_n instead of output (as this model does), its actual input shape is still shown.

LSTM is sky blue and Linear is bluish green.

plot multi output shapes
from collections import defaultdict

import matplotlib.pyplot as plt
import torch
import visualtorch
from torch import nn


class SequenceClassifier(nn.Module):
    """LSTM sequence classifier whose linear head consumes the final hidden state.

    The head is fed ``h_n`` (the last hidden state returned by the LSTM)
    rather than the per-timestep output sequence.
    """

    def __init__(self) -> None:
        super().__init__()
        hidden_size = 128
        self.lstm = nn.LSTM(input_size=64, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` with the LSTM and classify from ``h_n``; the sequence output is ignored."""
        lstm_result = self.lstm(x)
        # nn.LSTM returns (output, (h_n, c_n)); only h_n is used here.
        final_hidden = lstm_result[1][0]
        # Drop the leading axis of h_n (size 1 here: default single layer, one direction).
        logits = self.fc(final_hidden.squeeze(0))
        return logits


# Instantiate the example network that will be drawn.
model = SequenceClassifier()

# Shape handed to the renderer; the last axis (64) matches the LSTM's input_size.
# Presumably (batch, sequence length, features), given batch_first=True on the LSTM.
input_shape = (1, 7, 64)

# Fill colour per layer type; any layer type not listed falls back to an empty dict.
color_map: dict = defaultdict(
    dict,
    {
        nn.LSTM: {"fill": "#56B4E9"},
        nn.Linear: {"fill": "#009E73"},
    },
)

# Draw the model as a graph with per-layer output dimensions annotated.
img = visualtorch.render(
    model,
    input_shape,
    style="graph",
    color_map=color_map,
    show_neurons=False,
    show_dimension=True,
    layer_spacing=100,
)

dpi = 150  # rendered at 2x this in the final doc build (savefig.dpi=300 in conf.py)
# Size the figure in inches so that, at `dpi`, it spans the image's pixel dimensions.
width_in, height_in = img.width / dpi, img.height / dpi
plt.figure(figsize=(width_in, height_in), dpi=dpi)
plt.imshow(img)
plt.axis("off")
plt.tight_layout()
plt.show()

Gallery generated by Sphinx-Gallery