
Architecture Overview

The Inception-ResNet-v2 U-Net pairs an Inception-ResNet-v2 encoder with a U-Net-style decoder, combining multi-scale feature extraction and residual connections in the encoder with skip-connected upsampling in the decoder.

Key Features

  • Encoder: Inception-ResNet-v2 backbone
  • Inception blocks: Multi-scale convolutions (1×1, 3×3, 5×5)
  • Residual connections: Improved gradient flow
  • Block types: block35 (×10), block17 (×20), block8 (×10)
  • Decoder: U-Net-style upsampling with skip connections
  • Output: 2-class softmax segmentation

Inception-ResNet Block

The core building block combines Inception’s multi-scale processing with ResNet’s residual connections:
models/inception.py (88-166)
def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
    """Adds a Inception-ResNet block.
    This function builds 3 types of Inception-ResNet blocks mentioned
    in the paper, controlled by the `block_type` argument (which is the
    block name used in the official TF-slim implementation):
        - Inception-ResNet-A: `block_type='block35'`
        - Inception-ResNet-B: `block_type='block17'`
        - Inception-ResNet-C: `block_type='block8'`
    """
    if block_type == 'block35':
        branch_0 = conv2d_bn(x, 32, 1)
        branch_1 = conv2d_bn(x, 32, 1)
        branch_1 = conv2d_bn(branch_1, 32, 3)
        branch_2 = conv2d_bn(x, 32, 1)
        branch_2 = conv2d_bn(branch_2, 48, 3)
        branch_2 = conv2d_bn(branch_2, 64, 3)
        branches = [branch_0, branch_1, branch_2]
    elif block_type == 'block17':
        branch_0 = conv2d_bn(x, 192, 1)
        branch_1 = conv2d_bn(x, 128, 1)
        branch_1 = conv2d_bn(branch_1, 160, [1, 7])
        branch_1 = conv2d_bn(branch_1, 192, [7, 1])
        branches = [branch_0, branch_1]
    elif block_type == 'block8':
        branch_0 = conv2d_bn(x, 192, 1)
        branch_1 = conv2d_bn(x, 192, 1)
        branch_1 = conv2d_bn(branch_1, 224, [1, 3])
        branch_1 = conv2d_bn(branch_1, 256, [3, 1])
        branches = [branch_0, branch_1]
    else:
        raise ValueError('Unknown Inception-ResNet block type. '
                         'Expects "block35", "block17" or "block8", '
                         'but got: ' + str(block_type))

    block_name = block_type + '_' + str(block_idx)
    channel_axis = 3
    mixed = Concatenate(
        axis=channel_axis, name=block_name + '_mixed')(branches)
    up = conv2d_bn(mixed,
                   K.int_shape(x)[channel_axis],
                   1,
                   activation=None,
                   use_bias=True,
                   name=block_name + '_conv')

    x = Lambda(lambda inputs, scale: inputs[0] + inputs[1] * scale,
                      output_shape=K.int_shape(x)[1:],
                      arguments={'scale': scale},
                      name=block_name)([x, up])
    if activation is not None:
        x = Activation(activation, name=block_name + '_ac')(x)
    return x
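As a sanity check on the channel arithmetic, the block35 variant of this residual can be sketched in plain NumPy (the random tensors and the random matrix `W` stand in for the actual branch outputs and the 1×1 `up` convolution):

```python
import numpy as np

# block35 residual in plain NumPy: concatenate branch outputs on the
# channel axis, project back to the input width with a (random, stand-in)
# 1x1 kernel, then add a scaled residual -- the same arithmetic as the
# Concatenate/conv2d_bn/Lambda layers above.
x = np.random.rand(1, 35, 35, 320)          # block input
branch_0 = np.random.rand(1, 35, 35, 32)    # 1x1
branch_1 = np.random.rand(1, 35, 35, 32)    # 1x1 -> 3x3
branch_2 = np.random.rand(1, 35, 35, 64)    # 1x1 -> 3x3 -> 3x3

mixed = np.concatenate([branch_0, branch_1, branch_2], axis=3)
assert mixed.shape == (1, 35, 35, 128)

W = np.random.rand(128, 320)                # stand-in for the 1x1 'up' conv
up = mixed @ W                              # (1, 35, 35, 320)

scale = 0.17
out = x + scale * up                        # scaled residual connection
assert out.shape == x.shape
```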

Block Types

Block35 (Inception-ResNet-A)

Used in early layers (35×35 resolution). Three branches:
  • Branch 0: 1×1 conv (32 filters)
  • Branch 1: 1×1 → 3×3 conv (32 → 32 filters)
  • Branch 2: 1×1 → 3×3 → 3×3 conv (32 → 48 → 64 filters)
Scaling factor: 0.17. Repetitions: 10 blocks.

Block17 (Inception-ResNet-B)

Used in middle layers (17×17 resolution). Two branches:
  • Branch 0: 1×1 conv (192 filters)
  • Branch 1: 1×1 → 1×7 → 7×1 conv (128 → 160 → 192 filters)
Scaling factor: 0.1. Repetitions: 20 blocks.

Block8 (Inception-ResNet-C)

Used in late layers (8×8 resolution). Two branches:
  • Branch 0: 1×1 conv (192 filters)
  • Branch 1: 1×1 → 1×3 → 3×1 conv (192 → 224 → 256 filters)
Scaling factor: 0.2 (1.0, with no activation, for the final block). Repetitions: 10 blocks.
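The per-type configuration can be summarized in a small dictionary (values read off the `inception_resnet_block` code above; the channel sums are what enters the 1×1 projection):

```python
# Summary of the three block types, read off inception_resnet_block above:
# final channel count of each branch, residual scale, and repetitions.
blocks = {
    'block35': {'branches': [32, 32, 64], 'scale': 0.17, 'repeats': 10},
    'block17': {'branches': [192, 192],   'scale': 0.10, 'repeats': 20},
    'block8':  {'branches': [192, 256],   'scale': 0.20, 'repeats': 10},
}
# (The tenth block8 uses scale=1.0 with no activation.)

# Channels entering the 1x1 projection back to the input width:
concat_channels = {k: sum(v['branches']) for k, v in blocks.items()}
assert concat_channels == {'block35': 128, 'block17': 384, 'block8': 448}
```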

Complete Model Architecture

models/inception.py (169-271)
def get_inception_resnet_v2_unet_softmax(input_shape, weights='imagenet'):
    n_channel = 3
    n_class = 2
    img_input = Input(input_shape + (n_channel,))
    
    # Stem block: 35 x 35 x 192
    x = conv2d_bn(img_input, 32, 3, strides=2, padding='same')
    x = conv2d_bn(x, 32, 3, padding='same')
    x = conv2d_bn(x, 64, 3)
    conv1 = x
    x = MaxPooling2D(3, strides=2, padding='same')(x)
    x = conv2d_bn(x, 80, 1, padding='same')
    x = conv2d_bn(x, 192, 3, padding='same')
    conv2 = x
    x = MaxPooling2D(3, strides=2, padding='same')(x)

    # Mixed 5b (Inception-A block): 35 x 35 x 320
    branch_0 = conv2d_bn(x, 96, 1)
    branch_1 = conv2d_bn(x, 48, 1)
    branch_1 = conv2d_bn(branch_1, 64, 5)
    branch_2 = conv2d_bn(x, 64, 1)
    branch_2 = conv2d_bn(branch_2, 96, 3)
    branch_2 = conv2d_bn(branch_2, 96, 3)
    branch_pool = AveragePooling2D(3, strides=1, padding='same')(x)
    branch_pool = conv2d_bn(branch_pool, 64, 1)
    branches = [branch_0, branch_1, branch_2, branch_pool]
    channel_axis = 1 if K.image_data_format() == 'channels_first' else 3
    x = Concatenate(axis=channel_axis, name='mixed_5b')(branches)

    # 10x block35 (Inception-ResNet-A block): 35 x 35 x 320
    for block_idx in range(1, 11):
        x = inception_resnet_block(x,
                                   scale=0.17,
                                   block_type='block35',
                                   block_idx=block_idx)
    conv3 = x
    
    # Mixed 6a (Reduction-A block): 17 x 17 x 1088
    branch_0 = conv2d_bn(x, 384, 3, strides=2, padding='same')
    branch_1 = conv2d_bn(x, 256, 1)
    branch_1 = conv2d_bn(branch_1, 256, 3)
    branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='same')
    branch_pool = MaxPooling2D(3, strides=2, padding='same')(x)
    branches = [branch_0, branch_1, branch_pool]
    x = Concatenate(axis=channel_axis, name='mixed_6a')(branches)

    # 20x block17 (Inception-ResNet-B block): 17 x 17 x 1088
    for block_idx in range(1, 21):
        x = inception_resnet_block(x,
                                   scale=0.1,
                                   block_type='block17',
                                   block_idx=block_idx)
    conv4 = x
    
    # Mixed 7a (Reduction-B block): 8 x 8 x 2080
    branch_0 = conv2d_bn(x, 256, 1)
    branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding='same')
    branch_1 = conv2d_bn(x, 256, 1)
    branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding='same')
    branch_2 = conv2d_bn(x, 256, 1)
    branch_2 = conv2d_bn(branch_2, 288, 3)
    branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding='same')
    branch_pool = MaxPooling2D(3, strides=2, padding='same')(x)
    branches = [branch_0, branch_1, branch_2, branch_pool]
    x = Concatenate(axis=channel_axis, name='mixed_7a')(branches)

    # 10x block8 (Inception-ResNet-C block): 8 x 8 x 2080
    for block_idx in range(1, 10):
        x = inception_resnet_block(x,
                                   scale=0.2,
                                   block_type='block8',
                                   block_idx=block_idx)
    x = inception_resnet_block(x,
                               scale=1.,
                               activation=None,
                               block_type='block8',
                               block_idx=10)

    # Final convolution block: 8 x 8 x 1536
    x = conv2d_bn(x, 1536, 1, name='conv_7b')
    conv5 = x
    
    # U-Net Decoder with skip connections
    conv6 = conv_block(UpSampling2D()(conv5), 320)
    conv6 = concatenate([conv6, conv4], axis=-1)
    conv6 = conv_block(conv6, 320)

    conv7 = conv_block(UpSampling2D()(conv6), 256)
    conv7 = concatenate([conv7, conv3], axis=-1)  
    conv7 = conv_block(conv7, 256)

    conv8 = conv_block(UpSampling2D()(conv7), 128)
    conv8 = concatenate([conv8, conv2], axis=-1)
    conv8 = conv_block(conv8, 128)

    conv9 = conv_block(UpSampling2D()(conv8), 96)
    conv9 = concatenate([conv9, conv1], axis=-1)
    conv9 = conv_block(conv9, 96)

    conv10 = conv_block(UpSampling2D()(conv9), 64)
    conv10 = conv_block(conv10, 64)
    res = Conv2D(n_class, (1, 1), activation='softmax')(conv10)
    
    model = Model(img_input, res)

    return model
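The resolution bookkeeping can be verified with integer arithmetic; a sketch for a 256×256 input:

```python
# Resolution trace for a 256x256 input, mirroring the strided
# convolutions/poolings in the encoder and the UpSampling2D calls in the
# decoder above.
H = 256
conv1 = H // 2    # stem stride-2 conv            -> 128
conv2 = H // 4    # max pool                      -> 64
conv3 = H // 8    # max pool + block35 stage      -> 32
conv4 = H // 16   # mixed_6a reduction            -> 16
conv5 = H // 32   # mixed_7a reduction            -> 8

out = conv5
for _ in range(5):          # five UpSampling2D steps in the decoder
    out *= 2                # each doubles the resolution
assert out == H             # the softmax map matches the input size
```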

Network Structure

| Stage    | Block Type  | Blocks | Output Channels | Resolution  |
|----------|-------------|--------|-----------------|-------------|
| Stem     | Conv        | 3      | 64              | H/2 × W/2   |
| -        | Pool        | -      | 64              | H/4 × W/4   |
| Stem     | Conv        | 2      | 192             | H/4 × W/4   |
| -        | Pool        | -      | 192             | H/8 × W/8   |
| Mixed 5b | Inception-A | 1      | 320             | H/8 × W/8   |
| conv3    | Block35     | 10     | 320             | H/8 × W/8   |
| Mixed 6a | Reduction-A | 1      | 1088            | H/16 × W/16 |
| conv4    | Block17     | 20     | 1088            | H/16 × W/16 |
| Mixed 7a | Reduction-B | 1      | 2080            | H/32 × W/32 |
| conv5    | Block8      | 10     | 1536            | H/32 × W/32 |

Advantages of Inception-ResNet-v2

Multi-Scale Feature Extraction

  • Parallel branches: Captures features at different scales simultaneously
  • Factorized convolutions: 1×7 and 7×1 convs reduce parameters
  • Efficient computation: Smaller kernels with similar receptive fields
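The parameter saving from factorization is easy to quantify. Using the block17 branch channel counts from the code above (128 → 160 → 192) against a hypothetical single 7×7 conv mapping 128 → 192 channels (bias and batch-norm terms ignored):

```python
# Parameter count of block17's factorized 1x7 + 7x1 pair versus a
# hypothetical full 7x7 convolution with the same input/output widths.
c_in, c_mid, c_out = 128, 160, 192

full_7x7   = 7 * 7 * c_in * c_out                          # 1,204,224
factorized = 1 * 7 * c_in * c_mid + 7 * 1 * c_mid * c_out  # 358,400

assert factorized < full_7x7 / 3   # roughly a 3.4x parameter saving
```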

Residual Connections

  • Gradient flow: Scaled residual connections prevent vanishing gradients
  • Training stability: Easier optimization of very deep networks
  • Adaptive scaling: Different scale factors per block type

U-Net Integration

  • Skip connections: Preserves spatial details from encoder
  • Progressive reconstruction: Gradual upsampling to original resolution
  • Multi-level features: Combines semantic and spatial information
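One decoder step can be sketched in plain NumPy: 2× nearest-neighbour upsampling (UpSampling2D's default) of `conv5`, then channel-axis concatenation with `conv4` at the same scale. The `conv_block` that would reduce channels to 320 is faked here by slicing:

```python
import numpy as np

# Decoder step in NumPy: upsample conv5 (8x8) to conv4's scale (16x16),
# then concatenate along the channel axis, as in the decoder code above.
conv5 = np.random.rand(1, 8, 8, 1536)
conv4 = np.random.rand(1, 16, 16, 1088)

up = conv5.repeat(2, axis=1).repeat(2, axis=2)  # 8x8 -> 16x16
conv6 = up[..., :320]                # stand-in for conv_block(..., 320)
merged = np.concatenate([conv6, conv4], axis=-1)

assert merged.shape == (1, 16, 16, 1408)
```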

Model Weights

Pretrained Weights

The model can optionally load ImageNet-pretrained weights:
# With ImageNet pretraining (encoder only)
model = get_inception_resnet_v2_unet_softmax(
    input_shape=(256, 256), 
    weights='imagenet'
)

# Random initialization
model = get_inception_resnet_v2_unet_softmax(
    input_shape=(256, 256), 
    weights=None
)

DigiPathAI Weights

Task-specific weights are available:
  • digestpath_inception.h5: Trained on DigestPath dataset
  • paip_inception.h5: Trained on PAIP dataset
  • camelyon_inception.h5: Trained on Camelyon dataset

Input/Output Specifications

Input

  • Shape: (batch, height, width, 3)
  • Flexible dimensions: (None, None, 3) for variable sizes
  • Preprocessing: Standard ImageNet normalization
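A minimal preprocessing sketch, assuming the Inception-family convention of scaling pixels to [-1, 1] (as Keras's `preprocess_input` does for these models); verify against the project's own pipeline before relying on it:

```python
import numpy as np

# Scale uint8 pixel values [0, 255] to the [-1, 1] range expected by
# Inception-style backbones (assumed convention, not confirmed here).
patch = np.random.randint(0, 256, size=(1, 256, 256, 3)).astype('float32')
normalized = patch / 127.5 - 1.0

assert normalized.shape == (1, 256, 256, 3)
assert normalized.min() >= -1.0 and normalized.max() <= 1.0
```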

Output

  • Shape: (batch, height, width, 2)
  • Classes: [background, tissue]
  • Activation: Softmax probabilities

The Inception-ResNet-v2 U-Net excels at capturing multi-scale features through its parallel Inception branches, making it particularly effective for complex tissue structures with varying scales.
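The per-pixel softmax behaviour can be illustrated in plain NumPy: logits become probabilities that sum to 1 at every pixel, and an argmax over the channel axis yields a label map:

```python
import numpy as np

# Per-pixel 2-class softmax on a small fake logit map.
logits = np.random.randn(1, 4, 4, 2)
e = np.exp(logits - logits.max(axis=-1, keepdims=True))  # stable softmax
probs = e / e.sum(axis=-1, keepdims=True)

assert np.allclose(probs.sum(axis=-1), 1.0)
mask = probs.argmax(axis=-1)   # 0 = background, 1 = tissue
assert mask.shape == (1, 4, 4)
```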

Usage Example

from DigiPathAI.models.inception import get_inception_resnet_v2_unet_softmax
from DigiPathAI.helpers.utils import load_trained_models

# Create model architecture
model = get_inception_resnet_v2_unet_softmax(
    input_shape=(256, 256), 
    weights=None
)

# Load pretrained weights
model = load_trained_models(
    model='inception',
    path='~/.DigiPathAI/digestpath_models/digestpath_inception.h5',
    patch_size=256
)

# Predict on image patch
import numpy as np
image_patch = np.random.rand(1, 256, 256, 3)
prediction = model.predict(image_patch)
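A post-processing sketch for turning the softmax output into a binary mask; `prediction` is simulated here so the example runs without the model, but in practice it is the `model.predict(...)` result above:

```python
import numpy as np

# Simulated (1, 256, 256, 2) softmax output; normalize so each pixel's
# two class probabilities sum to 1, then threshold the tissue channel.
prediction = np.random.rand(1, 256, 256, 2)
prediction /= prediction.sum(axis=-1, keepdims=True)

tissue_prob = prediction[0, :, :, 1]          # class-1 probability map
mask = (tissue_prob > 0.5).astype(np.uint8)   # binary tissue mask

assert mask.shape == (256, 256)
assert set(np.unique(mask)) <= {0, 1}
```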

Performance Characteristics

  • Parameters: ~55M
  • Depth: 164 layers (stem + 40 inception-resnet blocks + decoder)
  • Memory: Higher memory footprint due to multiple branches
  • Speed: Moderate inference speed
  • Accuracy: Excellent for multi-scale features

Related Architectures

  • DenseNet U-Net: Dense connectivity for feature reuse
  • DeepLabv3+: Atrous spatial pyramid pooling
