From 0065022b169bc31651cc1cf4909f4e9abafdceb8 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 7 May 2026 10:47:15 -0700 Subject: [PATCH] Fix overflow warning in nearest neighbors code During the unittest we were seeing a warning saying ``` tests/knn_test.py::BM25Test::test_rank_items_batch tests/knn_test.py::BM25Test::test_similar_items_filter tests/knn_test.py::TFIDFTest::test_rank_items_batch tests/knn_test.py::TFIDFTest::test_similar_items_filter tests/knn_test.py::CosineTest::test_rank_items_batch tests/knn_test.py::CosineTest::test_similar_items_filter /home/ben/code/implicit/implicit/utils.py:134: RuntimeWarning: overflow encountered in cast output_scores[i] = batch_scores[:N] ``` This is because the `_batch_call` was generating output in float32, but the KNN models were returning float64 results. Fix. --- implicit/nearest_neighbours.py | 8 +++++++- implicit/utils.py | 6 +++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/implicit/nearest_neighbours.py b/implicit/nearest_neighbours.py index a1f36c1e..29e50c2c 100644 --- a/implicit/nearest_neighbours.py +++ b/implicit/nearest_neighbours.py @@ -62,6 +62,7 @@ def recommend( userid, user_items=user_items, N=N, + score_dtype=np.float64, filter_already_liked_items=filter_already_liked_items, filter_items=filter_items, recalculate_user=recalculate_user, @@ -115,7 +116,12 @@ def similar_items( if not np.isscalar(itemid): return _batch_call( - self.similar_items, itemid, N=N, filter_items=filter_items, items=items + self.similar_items, + itemid, + N=N, + score_dtype=np.float64, + filter_items=filter_items, + items=items, ) if filter_items is not None and items is not None: diff --git a/implicit/utils.py b/implicit/utils.py index f61da18d..6b7ac5ae 100644 --- a/implicit/utils.py +++ b/implicit/utils.py @@ -103,11 +103,11 @@ def augment_inner_product_matrix(factors): return max_norm, np.append(factors, extra_dimension.reshape(norms.shape[0], 1), axis=1) -def _batch_call(func, ids, *args, N=10, **kwargs): +def _batch_call(func, ids, *args, N=10, id_dtype=np.int32, score_dtype=np.float32, **kwargs): # we're running in batch mode, just loop over each item and call the scalar version of the # function - output_ids = np.zeros((len(ids), N), dtype=np.int32) - output_scores = np.zeros((len(ids), N), dtype=np.float32) + output_ids = np.zeros((len(ids), N), dtype=id_dtype) + output_scores = np.zeros((len(ids), N), dtype=score_dtype) user_items = kwargs.pop("user_items") if "user_items" in kwargs else None item_users = kwargs.pop("item_users") if "item_users" in kwargs else None