Multi-Branch Merge (Inception-style)

Multi-Branch Merge (Inception-style)

The same Inception-style block as the graph style’s multi-branch example — four parallel branches (a 1x1 Conv2d followed by BatchNorm2d, a 1x1-then-3x3 conv, a 1x1-then-5x5 conv, and a max-pool-then-1x1-conv) that all read the same input and whose outputs are concatenated and fed into a shared 1x1 projection layer — rendered in lenet style instead.

Conv2d layers are orange, BatchNorm2d green, and MaxPool2d reddish purple. Shape labels are turned off here (show_dimension=False) because the parallel branches share a column, so their labels would otherwise be drawn on top of each other.

plot inception block lenet style
from collections import defaultdict

import matplotlib.pyplot as plt
import torch
import visualtorch
from torch import nn


class InceptionBlock(nn.Module):
    """Simplified Inception-style block: four parallel branches, concatenated, then projected.

    Every branch reads the same input and keeps its spatial size (all kernels
    run at stride 1 with "same" padding). Their outputs can therefore be
    stacked along the channel axis and passed through a 1x1 projection that
    keeps the concatenated width.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_1x1: Output channels of the 1x1 conv + BatchNorm branch.
        out_3x3: Channels of the 1x1 -> 3x3 conv branch.
        out_5x5: Channels of the 1x1 -> 5x5 conv branch.
        out_pool: Output channels of the 3x3 max-pool -> 1x1 conv branch.
    """

    def __init__(self, in_channels: int, out_1x1: int, out_3x3: int, out_5x5: int, out_pool: int) -> None:
        super().__init__()
        # Submodules are built in the order they are registered. This keeps
        # parameter initialisation order and named_children() order stable.
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, out_1x1, kernel_size=1),
            nn.BatchNorm2d(out_1x1),
        )
        self.branch2 = self._reduce_then_spatial(in_channels, out_3x3, 3)
        self.branch3 = self._reduce_then_spatial(in_channels, out_5x5, 5)
        self.branch4 = nn.Sequential(
            # stride=1, padding=1 keeps H and W unchanged so the pool branch lines up with the others.
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, out_pool, kernel_size=1),
        )
        merged_width = sum((out_1x1, out_3x3, out_5x5, out_pool))
        self.project = nn.Conv2d(merged_width, merged_width, kernel_size=1)

    @staticmethod
    def _reduce_then_spatial(c_in: int, c_out: int, k: int) -> nn.Sequential:
        """Build a 1x1 conv to ``c_out`` channels followed by a ``k``x``k`` "same"-padded conv."""
        # For the odd kernels used here (3, 5), k // 2 padding preserves H and W.
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=1),
            nn.Conv2d(c_out, c_out, kernel_size=k, padding=k // 2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Feed ``x`` through all four branches, concatenate on dim 1, and project.

        Assumes a batched (N, C, H, W) input, since dim 1 must be the channel axis.
        """
        branch_outputs = [
            branch(x) for branch in (self.branch1, self.branch2, self.branch3, self.branch4)
        ]
        return self.project(torch.cat(branch_outputs, dim=1))


# Four equal-width (8-channel) branches on a 16-channel input.
model = InceptionBlock(in_channels=16, out_1x1=8, out_3x3=8, out_5x5=8, out_pool=8)

# (batch, channels, height, width); the channel count matches in_channels above.
input_shape = (1, 16, 16, 16)

# Fill colour keyed by layer class.
_LAYER_FILLS = {
    nn.Conv2d: "#E69F00",  # orange
    nn.BatchNorm2d: "#009E73",  # green
    nn.MaxPool2d: "#CC79A7",  # reddish purple
}
color_map: dict = defaultdict(dict)
for _layer_cls, _fill in _LAYER_FILLS.items():
    color_map[_layer_cls]["fill"] = _fill

# show_dimension=False: parallel branches share a column, so their shape
# labels would otherwise be drawn on top of one another.
img = visualtorch.render(
    model,
    input_shape,
    style="lenet",
    color_map=color_map,
    scale_xy=1.5,
    show_dimension=False,
)

# Size the figure so the image is shown at its native pixel size at this dpi.
dpi = 150  # the final doc build saves at 2x this (savefig.dpi=300 in conf.py)
figure_size = (img.width / dpi, img.height / dpi)
plt.figure(figsize=figure_size, dpi=dpi)
plt.imshow(img)
plt.axis("off")
plt.tight_layout()
plt.show()

Gallery generated by Sphinx-Gallery