The first time I opened a SMOT4SB dataset video, I spent five minutes convinced the annotations were broken. The bounding boxes were there, labelled neatly in the ground truth files, but I couldn’t see what they were tracking. Then I zoomed in. And zoomed in again. At 400% magnification, I finally saw them: birds reduced to 4×4 pixel smudges against a 4K sky.

The SMOT4SB challenge (Small Multi-Object Tracking for Spotting Birds) asks you to detect and track birds in UAV footage. The catch: the objects you’re tracking are roughly 16 square pixels in a 3840×2160 frame. That’s 0.0002% of the image. Most object detectors can’t reliably see anything below about 32×32 pixels.

Here’s what I discovered building a tracking pipeline under strict constraints: the newest methods aren’t always the best fit. A Kalman filter from the 1960s outperformed approaches I expected to dominate. A simple change to the similarity metric improved tracking by 80%. And camera motion compensation, the technique that sounds obviously useful for drone footage, turned out to make no measurable difference.

The Problem Space

Let me be concrete about what “small” means here.

| Statistic | Value |
| --- | --- |
| Total videos | 211 |
| Total frames | 108,192 |
| Bird instances | 371,690 |
| Unique tracks | 2,240 |
| Resolution | 3840×2160 (4K) |
| Average object size | ~4×4 pixels (~16 sq. pixels) |

A typical bird in this dataset is four pixels wide. Your mouse cursor is larger. An entire bounding box would fit inside a single character in a terminal window.

This creates a cascade of problems:

  1. Detection failure: YOLOv8’s feature map stride limits reliable detection to ~32×32 pixel objects. Below that, objects blur into background noise.
  2. IoU collapse: Standard Intersection over Union becomes hypersensitive to localisation error. A 1-pixel diagonal shift drops IoU from 1.0 to 0.39. A 2-pixel shift: 0.14.
  3. Motion blur: Birds move. Drones move. Combined, the blur often exceeds the object size.
  4. Flocking dynamics: Birds cluster, occlude each other, and reappear in different configurations.

I set myself constraints: CPU-only execution, no fine-tuning, pretrained models only. This ruled out training a specialised small-object detector. I had to make existing tools work.

Why This Matters

This isn’t just an academic exercise. Drone-based wildlife monitoring is increasingly important for conservation, and the equipment constraints are real. Field researchers often work with laptops, not GPU clusters. Edge devices on drones have limited compute. A CPU-only solution that works is more valuable than a GPU solution that stays in the lab.

The same challenges appear in satellite imagery (tracking ships), medical imaging (tracking cells), and industrial inspection (tracking defects). Anywhere objects approach the limits of detectability, the standard toolbox breaks down in similar ways.

The Pipeline Architecture

I built a four-stage tracking-by-detection pipeline:

Input Frames (3840×2160)
[SAHI Slicer] → Divide into 640×640 patches (20% overlap)
[YOLOv8n] → Run inference on each patch
[NMS Merging] → Fuse detections across patches
[Kalman Filter] → Predict track positions
[Hungarian Algorithm] → Optimal bipartite matching
[Track Manager] → Initialise/terminate tracks
Output: Tracked bird trajectories with persistent IDs

The architecture is deliberately classical. Each component is interpretable. Each parameter is tunable. This matters when debugging why a 4-pixel bird lost its track ID.

Detection: Making YOLO See the Invisible

YOLOv8 couldn’t detect the birds. Running inference on full 4K frames returned zero detections. The objects were simply too small relative to the model’s internal resolution.

The solution was SAHI (Slicing Aided Hyper Inference). Instead of feeding the full frame, SAHI divides it into overlapping patches, runs detection on each patch independently, then merges the results with non-maximum suppression.

from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

detection_model = AutoDetectionModel.from_pretrained(
    model_type="yolov8",
    model_path="yolov8n.pt",
    confidence_threshold=0.05,  # Low threshold for tiny objects
    device="cpu"
)

result = get_sliced_prediction(
    image=frame,
    detection_model=detection_model,
    slice_height=640,
    slice_width=640,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2
)

The key insight: YOLOv8 processes its input at 640 pixels, so a full 4K frame gets downscaled roughly 6× (3840/640) and a 4-pixel bird shrinks below a single pixel. Sliced into 640×640 patches, each patch runs at native resolution: the bird stays 4 pixels wide and crosses the threshold from invisible to detectable.

The confidence threshold was counterintuitive. Standard practice suggests 0.3-0.5 for good precision. I used 0.05. The reasoning: with objects this small, you want high recall even at the cost of false positives. The tracker downstream can filter noise; it cannot recover missed detections.
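
Downstream, the tracker needs plain boxes and scores rather than SAHI’s result object. A minimal conversion sketch, assuming the standard object_prediction_list attribute of SAHI’s prediction result (the helper name and tuple layout are my own choice):

def predictions_to_boxes(result):
    """Convert a SAHI prediction result into (x1, y1, x2, y2, score) tuples."""
    boxes = []
    for pred in result.object_prediction_list:
        bbox = pred.bbox  # SAHI bounding box with minx/miny/maxx/maxy
        boxes.append((bbox.minx, bbox.miny, bbox.maxx, bbox.maxy, pred.score.value))
    return boxes

detections = predictions_to_boxes(result)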

The DotD Revelation

This is where things got interesting.

Standard multi-object tracking uses IoU (Intersection over Union) to measure similarity between predicted and detected bounding boxes. IoU works beautifully for normal-sized objects. For 4-pixel birds, it fails catastrophically.

The mathematics are unforgiving. Let me show you exactly why:

# For a 4×4 bounding box (16 square pixels):

# Perfect alignment
intersection = 4 * 4  # 16
union = 16 + 16 - 16  # 16
iou = 16 / 16  # = 1.0

# 1-pixel diagonal shift
intersection = 3 * 3  # 9
union = 16 + 16 - 9   # 23
iou = 9 / 23  # ≈ 0.39  (61% drop from a 1-pixel error!)

# 2-pixel diagonal shift
intersection = 2 * 2  # 4
union = 16 + 16 - 4   # 28
iou = 4 / 28  # ≈ 0.14  (86% drop)

A 1-pixel error cuts IoU by 61%. A 2-pixel error leaves 0.14, which for practical purposes is zero: the matcher treats a nearly-correct prediction as a complete mismatch.

Think of it this way: judging whether two postage stamps overlap by looking at them from across a football pitch. A millimetre of misalignment looks like a complete miss.

I replaced IoU with DotD (Dot Distance), a similarity metric based on centre distance with exponential decay:

import numpy as np

def compute_dotd(box_a: BoundingBox, box_b: BoundingBox) -> float:
    """Compute DotD similarity between two boxes.

    DotD = exp(-d / s) where:
    - d = Euclidean distance between centres
    - s = average object size

    Returns smooth similarity even for non-overlapping tiny objects.
    """
    c_a = box_a.centre
    c_b = box_b.centre
    distance = np.sqrt((c_a[0] - c_b[0])**2 + (c_a[1] - c_b[1])**2)

    avg_size = (np.sqrt(box_a.area) + np.sqrt(box_b.area)) / 2
    avg_size = max(avg_size, 1.0)  # Prevent division by zero

    return np.exp(-distance / avg_size)

The exponential decay provides graceful degradation. Here’s the same 2-pixel error, seen through each metric:

# For a 4-pixel object with 2-pixel centre error:

# IoU approach:
iou = 0.14  # Algorithm sees "no match", assigns wrong ID

# DotD approach:
distance = 2  # pixels
avg_size = 4  # pixels
dotd = exp(-2/4)  # = 0.61, algorithm sees "close match"

Same prediction error, completely different outcomes. IoU says “mismatch”. DotD says “probably the same bird”. The algorithm can still find the correct match.

I ran ablations comparing IoU against DotD with identical detection and tracking components:

| Metric | IoU-based | DotD-based | Improvement |
| --- | --- | --- | --- |
| SO-HOTA | 0.111 | 0.199 | +80% |
| Association Accuracy | 0.249 | 0.819 | +229% |
| ID Switches | 1.5 ± 2.1 | 0.2 ± 0.4 | -87% |

An 80% improvement from changing one function. Everything else remained constant. The metric selection dominated the architecture choice.

Kalman Filters: The Classics Still Work

For state estimation, I used a Kalman filter with a constant velocity model. State vector: [x, y, vx, vy]. The implementation is straightforward:

import numpy as np
from filterpy.kalman import KalmanFilter

def _init_kalman_filter(self, box: BoundingBox) -> KalmanFilter:
    """Initialise Kalman filter with constant velocity model.

    State: [x, y, vx, vy]
    Measurement: [x, y] (centre position only)
    """
    kf = KalmanFilter(dim_x=4, dim_z=2)

    # State transition: constant velocity
    kf.F = np.array([
        [1, 0, 1, 0],  # x' = x + vx
        [0, 1, 0, 1],  # y' = y + vy
        [0, 0, 1, 0],  # vx' = vx
        [0, 0, 0, 1]   # vy' = vy
    ])

    # Measurement: observe position only
    kf.H = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, 0]
    ])

    # Initial state from detection
    cx, cy = box.centre
    kf.x = np.array([[cx], [cy], [0], [0]])

    return kf

The intuition is simple: if you see a bird at position (100, 200) moving east at 5 pixels per frame, predict it’ll be at (105, 200) next frame. When the detection arrives and says (107, 201), the Kalman filter blends the prediction with the measurement, weighted by how confident it is in each. That’s it. No neural networks required.
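
In code, that blend is just the filter’s predict/update cycle. A self-contained toy run with filterpy, reusing the constant velocity model above (the numbers mirror the example in the previous paragraph):

import numpy as np
from filterpy.kalman import KalmanFilter

kf = KalmanFilter(dim_x=4, dim_z=2)
kf.F = np.array([[1, 0, 1, 0],
                 [0, 1, 0, 1],
                 [0, 0, 1, 0],
                 [0, 0, 0, 1]], dtype=float)
kf.H = np.array([[1, 0, 0, 0],
                 [0, 1, 0, 0]], dtype=float)
kf.x = np.array([[100.], [200.], [5.], [0.]])  # at (100, 200), moving east at 5 px/frame

kf.predict()                       # prediction: (105, 200)
kf.update(np.array([107., 201.]))  # measurement disagrees slightly
print(kf.x[:2].ravel())            # blended estimate, between prediction and measurement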

I considered more sophisticated alternatives: recurrent neural networks, transformers, learned motion models. None were necessary. Birds in short video clips move predictably. The constant velocity assumption held well enough.

For data association, the Hungarian algorithm provides optimal bipartite matching in O(n³). With typically <50 detections per frame, this runs in milliseconds:

import numpy as np
from scipy.optimize import linear_sum_assignment

def associate(detections, predictions):
    # Build cost matrix using DotD
    cost_matrix = np.zeros((len(predictions), len(detections)))
    for i, pred in enumerate(predictions):
        for j, det in enumerate(detections):
            similarity = compute_dotd(pred, det)
            cost_matrix[i, j] = 1 - similarity  # Convert to cost

    # Hungarian algorithm: optimal matching
    row_ind, col_ind = linear_sum_assignment(cost_matrix)

    return [(r, c) for r, c in zip(row_ind, col_ind)
            if cost_matrix[r, c] < 0.7]  # keep only pairs with DotD similarity > 0.3
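
A toy run of associate, with a minimal stand-in for the BoundingBox class used throughout (the class here is illustrative, just enough for compute_dotd):

from dataclasses import dataclass

@dataclass
class Box:  # stand-in: compute_dotd only needs .centre and .area
    centre: tuple
    area: float

predictions = [Box((100, 200), 16), Box((500, 300), 16)]
detections = [Box((102, 201), 16), Box((640, 900), 16)]

matches = associate(detections, predictions)
# Only the nearby pair is matched: (prediction 0, detection 0).
# The distant pair exceeds the 0.7 cost threshold (DotD similarity below 0.3) and stays unmatched.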

Classical methods from the 1950s and 1960s. Still competitive. The difference: careful metric selection to match the problem domain.

The Counterintuitive Finding: Camera Motion Compensation Doesn’t Help

UAV footage has significant camera ego-motion. The standard approach is Camera Motion Compensation (CMC): estimate a homography between frames using ORB features, warp the predicted bounding boxes, then match against detections.
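
For reference, this is roughly what that CMC step looks like; a sketch assuming OpenCV’s ORB features and RANSAC homography (function names and parameters are illustrative, not my exact implementation):

import cv2
import numpy as np

def estimate_camera_motion(prev_gray, curr_gray):
    """Estimate a frame-to-frame homography from ORB feature matches."""
    orb = cv2.ORB_create(nfeatures=2000)
    kp1, des1 = orb.detectAndCompute(prev_gray, None)
    kp2, des2 = orb.detectAndCompute(curr_gray, None)
    if des1 is None or des2 is None:
        return np.eye(3)  # no features: fall back to identity (no compensation)

    matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
    matches = sorted(matcher.match(des1, des2), key=lambda m: m.distance)[:500]
    if len(matches) < 4:
        return np.eye(3)

    src = np.float32([kp1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
    dst = np.float32([kp2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
    H, _ = cv2.findHomography(src, dst, cv2.RANSAC, 3.0)
    return H if H is not None else np.eye(3)

def warp_centre(cx, cy, H):
    """Warp a predicted track centre into the current frame's coordinates."""
    pt = np.array([[[cx, cy]]], dtype=np.float32)
    return cv2.perspectiveTransform(pt, H.astype(np.float32))[0, 0]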

I implemented this. Initial results looked promising. Then I ran proper ablations across multiple videos:

| Configuration | SO-HOTA |
| --- | --- |
| Without CMC | 0.199 ± 0.197 |
| With CMC | 0.206 ± 0.192 |

A 3.5% improvement. Within the noise margin.

Here’s why CMC doesn’t help at this scale: warping predicted boxes through an estimated homography introduces sub-pixel errors, from both the homography estimate and coordinate rounding. For normal-sized objects, those errors are negligible. For 4-pixel objects, they are comparable to the object dimensions themselves. The CMC correction introduces as much noise as it removes.
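
To put the scale argument in numbers (illustrative: a fixed half-pixel warp error compared against different object sizes):

error = 0.5         # a half-pixel error from the warp
print(error / 4)    # 0.125  -> 12.5% of a 4-pixel bird
print(error / 128)  # ~0.004 -> 0.4% of a 128-pixel object: negligible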

I disabled CMC in the final configuration. Computational savings with no accuracy loss.

Results and What They Mean

Final performance on the SMOT4SB validation set:

| Metric | Value |
| --- | --- |
| SO-HOTA | 19.9 ± 19.7 |
| Detection Accuracy (DetA) | 0.08 |
| Association Accuracy (AssA) | 0.82 |
| ID Switches | 0.2 ± 0.4 |
| Runtime | ~2.9 seconds per frame (CPU) |

For context: the challenge baseline achieved SO-HOTA = 9.9. My pipeline doubled that despite stricter constraints (CPU-only, no fine-tuning). The challenge winner reached 55.2 using GPU-accelerated inference and fine-tuned models.

The diagnostic insight is in the accuracy breakdown: Detection Accuracy = 0.08, Association Accuracy = 0.82. Once a bird is detected, tracking works extremely well. The bottleneck is detection, not tracking.

This 10× gap tells you where to focus future effort. Better small-object detectors would improve overall performance more than sophisticated tracking algorithms.

The high standard deviation (±19.7) reflects dataset heterogeneity, not measurement noise. Some videos have birds against clear sky (easy). Others have birds against complex backgrounds, in dense flocks, with heavy motion blur (hard). A single SO-HOTA number obscures this variance.

Configuration Details

For reproducibility, here are the final parameters:

| Parameter | Value | Rationale |
| --- | --- | --- |
| SAHI slice size | 640×640 | Matches YOLO training scale |
| SAHI overlap | 0.2 | Prevents edge detection loss |
| Confidence threshold | 0.05 | High recall for tiny objects |
| Track min_hits | 2 | Confirm after 2 detections |
| Track max_age | 30 | Tolerate missed detections |
| BBox expansion | 2.0× | For association robustness |
| EMA velocity alpha | 0.8 | Smooth erratic bird motion |
| CMC | Disabled | No measurable benefit at 4-pixel scale |
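
The min_hits and max_age rows correspond to a small track-lifecycle rule in the track manager. A sketch of that rule (class and field names are illustrative, not my exact implementation):

MIN_HITS = 2   # confirm a track after 2 matched detections
MAX_AGE = 30   # drop a track after 30 consecutive frames without a detection

class Track:
    def __init__(self, box, track_id):
        self.id = track_id
        self.box = box
        self.hits = 1                # matched detections so far
        self.time_since_update = 0   # frames since the last match

    @property
    def confirmed(self):
        return self.hits >= MIN_HITS

    def mark_matched(self, box):
        self.box = box
        self.hits += 1
        self.time_since_update = 0

    def mark_missed(self):
        self.time_since_update += 1

    @property
    def dead(self):
        return self.time_since_update > MAX_AGE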

The Lesson

Three principles emerged from this project:

1. Match your metrics to your scale. IoU is the default for bounding box matching. It fails below ~20 pixels. DotD provides smooth similarity at any scale. The 80% improvement came from one function change.

2. Question every enhancement. Camera motion compensation sounds obviously useful for drone footage. At 4-pixel scales, its correction introduces as much noise as it removes: the experiment showed a 3.5% improvement, within the noise margin. Not every technique that works in general works in your domain.

3. Diagnose before you optimise. The DetA vs AssA breakdown revealed that detection is the bottleneck, not tracking. A more sophisticated tracker wouldn’t help much. A better small-object detector would.

The constraints forced clarity. CPU-only execution meant I couldn’t hide inefficiency behind GPU parallelism. No fine-tuning meant I had to make pretrained models work. These limitations pushed toward simpler, more interpretable solutions.

Modern tracking research focuses on learned representations, attention mechanisms, and transformer architectures. These approaches excel on standard benchmarks with normal-sized objects. At the extreme lower boundary of detectability, classical methods with careful metric selection performed competitively.

When to Apply These Lessons

The patterns here generalise beyond bird tracking:

  • Satellite imagery: Ships, vehicles, and structures near the resolution limit face the same IoU collapse
  • Medical imaging: Tracking cells or particles in microscopy has identical scale challenges
  • Industrial inspection: Defect detection on production lines often involves sub-pixel features
  • Any domain where objects approach detector limits: The standard toolbox breaks down in predictable ways

The meta-lesson: when your problem operates at the boundaries of your tools’ capabilities, the choice of similarity metric matters more than the sophistication of your architecture. I spent days reading papers on transformer-based trackers and learned motion models. The 80% improvement came from swapping one distance function.

If your objects are small enough that IoU is failing you, try DotD. If your enhancement isn’t beating the noise margin, disable it. And if your DetA is an order of magnitude below your AssA, stop tuning your tracker and start improving your detector.

The answers aren’t always in the latest arXiv preprint. Sometimes they’re in a 1960 paper about missile guidance systems.


Code

The complete implementation is available as a Colab notebook:

SMOT4SB Tracking Pipeline (Colab)

Key dependencies:

  • ultralytics (YOLOv8)
  • sahi (Slicing Aided Hyper Inference)
  • filterpy (Kalman Filter)
  • scipy (Hungarian Algorithm)
