"""
Core registration estimator for Twocan multimodal image registration.
This module implements the main RegEstimator class, which provides a
scikit-learn compatible interface for feature-based image registration
using OpenCV's ORB detector and affine transformation estimation.
"""
from typing import Optional, Tuple, Dict, Union, Any
import cv2
import numpy as np
from abc import ABC
from skimage import transform
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted
from .utils import stretch_255
[docs]
class RegEstimator(TransformerMixin, BaseEstimator, ABC):
"""A scikit-learn compatible estimator for multimodal image registration.
This class implements feature-based image registration using OpenCV's ORB
(Oriented FAST and Rotated BRIEF) detector and a partial affine transformation
model. It follows scikit-learn's estimator API with fit, transform, and
fit_transform methods, making it easy to integrate into machine learning pipelines.
The registration process consists of:
1. Feature detection using ORB on both images
2. Feature matching using brute-force Hamming distance
3. Affine transformation estimation using RANSAC
4. Image transformation using the estimated parameters
Parameters
----------
registration_max_features : int, default=10000
Maximum number of features to detect in each image using ORB.
Higher values can improve registration accuracy but increase computation time.
registration_percentile : float, default=0.9
Percentile of features to keep after sorting by match quality (0-1).
Only the top percentile of matches by distance are used for transformation
estimation, which helps remove outliers.
Attributes
----------
M_ : np.ndarray
The estimated 2x3 affine transformation matrix after fitting.
Shape is (2, 3) representing the transformation [R|t] where R is
rotation/scaling and t is translation.
y_shape_ : Tuple[int, int]
Shape (height, width) of the target image used during fitting.
Used as the default output shape for transformations.
Notes
-----
The estimator automatically converts input images to 8-bit grayscale for
feature detection using the stretch_255 utility function. This ensures
consistent feature detection regardless of input image dynamic range.
The partial affine transformation model allows rotation, scaling, and
translation but not shearing, which is appropriate for most microscopy
registration tasks where imaging geometry is approximately preserved.
For best results:
- Ensure sufficient overlap between images
- Use images with distinct features (not uniform regions)
- Consider preprocessing to enhance relevant structures
- Adjust max_features based on image complexity and computational budget
"""
def __init__(self, registration_max_features: int = 10000, registration_percentile: float = 0.9):
self.registration_max_features = registration_max_features
self.registration_percentile = registration_percentile
[docs]
def fit(self, X: np.ndarray, y: np.ndarray) -> 'RegEstimator':
"""Estimate the affine transformation matrix between source (X) and target (y) images.
This method detects features in both images using ORB, matches them, and
estimates the best affine transformation that maps source features to
target features using OpenCV's robust estimation.
Parameters
----------
X : np.ndarray
Source image to be registered. Shape can be (H, W) for single-channel
or (C, H, W) for multi-channel. If multi-channel, all channels are
summed for feature detection.
y : np.ndarray
Target (reference) image to register to. Shape can be (H, W) for
single-channel or (C, H, W) for multi-channel. If multi-channel,
all channels are summed for feature detection.
Returns
-------
self : RegEstimator
The fitted estimator with estimated transformation matrix in ``self.M_``.
Raises
------
cv2.error
If affine transformation cannot be estimated, typically due to
insufficient or poorly matched features.
Notes
-----
The fitting process:
1. Convert images to 8-bit for ORB compatibility
2. Detect up to max_features keypoints in each image
3. Compute ORB descriptors for each keypoint
4. Match descriptors using brute-force Hamming distance
5. Keep top percentile of matches by distance
6. Estimate partial affine transformation using RANSAC
The method uses OpenCV's estimateAffinePartial2D which finds the optimal
similarity transformation (rotation, scaling, translation) rather than
a full affine transformation. This is more robust for most registration
scenarios.
"""
X = stretch_255(X.copy())
y = stretch_255(y.copy())
# orb detector
orb = cv2.ORB_create(self.registration_max_features, fastThreshold=0, edgeThreshold=0)
(kpsA, descsA) = orb.detectAndCompute(X, None)
(kpsB, descsB) = orb.detectAndCompute(y, None)
# match the features
method = cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_HAMMING
matcher = cv2.DescriptorMatcher_create(method)
matches = matcher.match(descsA, descsB, None)
# sort the matches by their distance (the smaller the distance, the "more similar" the features are)
matches = sorted(matches, key=lambda x:x.distance)
# keep only the top percentile matches
keep = int(len(matches) * self.registration_percentile)
matches = matches[:keep]
ptsA = np.array([kpsA[m.queryIdx].pt for m in matches])
ptsB = np.array([kpsB[m.trainIdx].pt for m in matches])
# register
M, _mask = cv2.estimateAffinePartial2D(ptsA, ptsB)
self.M_ = M
self.y_shape_ = y.shape
return self
[docs]
def score(self, source: np.ndarray, target: np.ndarray) -> Dict[str, float]:
"""Calculate registration quality metrics between source and target images.
This method computes various metrics to assess the quality of registration
between binary or continuous-valued images. Metrics are calculated only
in regions where both images have valid data after transformation.
Parameters
----------
source : np.ndarray
Source image, shape (H, W). Should be the same image used for fitting
or a similar image from the same modality.
target : np.ndarray
Target image, shape (H, W). Should be the same image used for fitting
or a similar image from the same modality.
Returns
-------
Dict[str, float]
Dictionary containing registration quality metrics:
- 'and' : float
Count of pixels where both source and target are positive
(logical AND operation). Higher values indicate better overlap.
- 'or' : float
Count of pixels where either source or target is positive
(logical OR operation).
- 'xor' : float
Count of pixels where source and target disagree
(logical XOR operation). Lower values indicate better agreement.
- 'iou' : float
Intersection over Union ratio (and/or). Values range from 0-1
with 1 indicating perfect overlap. Returns 0.0 if no positive pixels exist.
- 'source_sum' : float
Sum of all source pixel intensities in the overlap region.
- 'target_sum' : float
Sum of all target pixel intensities in the overlap region.
Notes
-----
Metrics are computed only in the intersection region where both images
have valid data after transformation. This ensures fair comparison and
avoids edge effects from the transformation.
For binary images, the metrics have intuitive interpretations:
- IoU is the standard Jaccard index
- 'and' counts overlapping positive pixels
- 'xor' counts disagreement pixels
For continuous images, the logical operations are applied after
implicit conversion to boolean (non-zero values are True).
"""
assert source.ndim==2
assert target.ndim==2
stack = self.transform(source,target)
# restrict to shared area of union
source_mask = np.ones(source.shape)
target_mask = np.ones(target.shape)
stack_mask = self.transform(source_mask,target_mask).sum(0) >1
stack = stack[:,stack_mask]
# Calculate logical operations
logical_and = np.logical_and(stack[0], stack[1])
logical_or = np.logical_or(stack[0], stack[1])
logical_xor = np.logical_xor(stack[0], stack[1])
and_sum = logical_and.sum()
or_sum = logical_or.sum()
xor_sum = logical_xor.sum()
# Handle divide by zero case for IoU
iou = (and_sum / or_sum) if or_sum > 0 else 0.0
return({
'and': and_sum,
'or': or_sum,
'xor': xor_sum,
'iou': iou,
'source_sum': stack[0].sum(),
'target_sum': stack[1].sum()
})