Skip to content

Commit be91f3a

Browse files
ElektrikAkarclaude
andcommitted
fix: wrap dtw_distance/dtw_distance_missing to accept lists and numpy
The pointer-based refactor changed the C++ binding to require nb::ndarray (numpy arrays). All existing tests and user code pass Python lists. Added Python wrappers that auto-convert via np.asarray(). 114/114 Python tests pass. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4098c6c commit be91f3a

1 file changed

Lines changed: 30 additions & 3 deletions

File tree

python/dtwcpp/__init__.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""DTWC++ — Fast Dynamic Time Warping and Clustering."""
22

3+
import numpy as _np
4+
35
from dtwcpp._dtwcpp_core import (
46
# Enums
57
Method,
@@ -14,14 +16,14 @@
1416
Data,
1517
# Classes
1618
Problem,
17-
# DTW functions
18-
dtw_distance,
19+
# DTW functions (raw C++ bindings — require numpy arrays)
20+
dtw_distance as _dtw_distance_raw,
1921
ddtw_distance,
2022
wdtw_distance,
2123
adtw_distance,
2224
soft_dtw_distance,
2325
soft_dtw_gradient,
24-
dtw_distance_missing,
26+
dtw_distance_missing as _dtw_distance_missing_raw,
2527
# Algorithms
2628
fast_pam,
2729
fast_clara,
@@ -42,6 +44,31 @@
4244

4345
__version__ = "1.0.0"
4446

47+
48+
# Wrappers that accept both lists and numpy arrays (the C++ nb::ndarray
49+
# binding only accepts numpy arrays; these auto-convert for convenience).
50+
def dtw_distance(x, y, band=-1, metric="l1"):
51+
"""Compute DTW distance between two time series.
52+
53+
Accepts lists or numpy arrays. metric: 'l1' (default) or 'squared_euclidean'.
54+
band=-1 for full DTW, band>0 for Sakoe-Chiba banded DTW.
55+
"""
56+
return _dtw_distance_raw(
57+
_np.asarray(x, dtype=_np.float64),
58+
_np.asarray(y, dtype=_np.float64),
59+
band, metric)
60+
61+
62+
def dtw_distance_missing(x, y, band=-1, metric="l1"):
63+
"""DTW distance with missing data support (NaN = missing).
64+
65+
Accepts lists or numpy arrays. NaN values contribute zero cost.
66+
"""
67+
return _dtw_distance_missing_raw(
68+
_np.asarray(x, dtype=_np.float64),
69+
_np.asarray(y, dtype=_np.float64),
70+
band, metric)
71+
4572
# Pure-Python sklearn-compatible layer
4673
from dtwcpp._clustering import DTWClustering
4774

0 commit comments

Comments
 (0)