Custom Color

Custom Color

Visualization of custom colors. The synthetic input box can be recolored too: key it by `visualtorch.Input` in `color_map`, just like any real layer type. If left uncustomized, it would get the same default color as `Conv2d` in this example, because both would otherwise claim the same slot in the color wheel.

plot custom color
from collections import defaultdict

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

def _conv_stage(in_channels, out_channels):
    """Return one feature-extraction stage: 3x3 conv -> ReLU -> 2x2 max-pool.

    The padded 3x3 conv (stride 1) keeps height and width unchanged, and the
    2x2 / stride-2 pool halves them, so each stage halves the spatial size.
    """
    return (
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
    )


# A simple CNN assembled with nn.Sequential: three conv stages widening the
# channels 3 -> 16 -> 32 -> 64, then a two-layer classifier head.
# Three halvings take a 224x224 input down to 28x28 (224 / 2**3 == 28), so the
# first Linear layer expects 64 * 28 * 28 features. That number must be updated
# if `input_shape` further down is changed.
model = nn.Sequential(
    *_conv_stage(3, 16),
    *_conv_stage(16, 32),
    *_conv_stage(32, 64),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # 64 channels x 28 x 28 after three poolings of 224
    nn.ReLU(),
    nn.Linear(256, 10),  # assumes 10 output classes
)

# Fill color for each layer type to draw. visualtorch.Input is the key for the
# synthetic input box. The hex values come from the Okabe-Ito palette.
_layer_fills = {
    visualtorch.Input: "#D55E00",  # vermillion
    nn.Conv2d: "#E69F00",  # orange
    nn.ReLU: "#56B4E9",  # sky blue
    nn.MaxPool2d: "#CC79A7",  # reddish purple
    nn.Flatten: "#009E73",  # bluish green
    nn.Linear: "#0072B2",  # blue
}

# defaultdict(dict) creates each per-type style dict on first access, so
# ["fill"] can be assigned without pre-creating the inner dict.
color_map: dict = defaultdict(dict)
for _layer_type, _fill in _layer_fills.items():
    color_map[_layer_type]["fill"] = _fill

# Shape of the dummy input passed to the renderer. Its 3 channels match the
# first Conv2d, and 224x224 is the size the first Linear layer was computed for.
input_shape = (1, 3, 224, 224)
img = visualtorch.render(
    model,
    input_shape=input_shape,
    style="flow",
    color_map=color_map,
)

# Size the figure to the rendered image's pixel dimensions at this dpi.
dpi = 150  # rendered at 2x this in the final doc build (savefig.dpi=300 in conf.py)
fig, ax = plt.subplots(figsize=(img.width / dpi, img.height / dpi), dpi=dpi)
ax.imshow(img)
ax.set_axis_off()
fig.tight_layout()
plt.show()

Gallery generated by Sphinx-Gallery