diff --git a/lib/pgvector.rb b/lib/pgvector.rb index 0b2ca72..1e4665d 100644 --- a/lib/pgvector.rb +++ b/lib/pgvector.rb @@ -9,6 +9,8 @@ module Pgvector autoload :PG, "pgvector/pg" def self.encode(data) + return nil if data.nil? + if data.is_a?(Vector) || data.is_a?(HalfVector) || data.is_a?(SparseVector) data.to_s else @@ -17,6 +19,8 @@ def self.encode(data) end def self.decode(string) + return nil if string.nil? + if string[0] == "[" Vector.from_text(string).to_a elsif string[0] == "{" diff --git a/lib/sequel/plugins/pgvector.rb b/lib/sequel/plugins/pgvector.rb index 160e418..6095ca2 100644 --- a/lib/sequel/plugins/pgvector.rb +++ b/lib/sequel/plugins/pgvector.rb @@ -76,8 +76,11 @@ def []=(k, v) def [](k) if self.class.vector_columns.key?(k.to_sym) + value = super + return nil if value.nil? + # to_s needed for JRuby - ::Pgvector.decode(super.to_s) + ::Pgvector.decode(value.to_s) else super end diff --git a/test/pgvector_test.rb b/test/pgvector_test.rb index 5c02c98..183ee98 100644 --- a/test/pgvector_test.rb +++ b/test/pgvector_test.rb @@ -24,4 +24,12 @@ def test_decode_vector def test_decode_sparse_vector assert_equal [1, 0, 2, 0, 3, 0], Pgvector.decode("{1:1.0,3:2.0,5:3.0}/6").to_a end + + def test_encode_nil + assert_nil Pgvector.encode(nil) + end + + def test_decode_nil + assert_nil Pgvector.decode(nil) + end end diff --git a/test/sequel_test.rb b/test/sequel_test.rb index e5ca02e..9853faa 100644 --- a/test/sequel_test.rb +++ b/test/sequel_test.rb @@ -153,6 +153,19 @@ def test_instance_sparsevec_euclidean assert_equal [1, Math.sqrt(3)], results.map { |r| r[:neighbor_distance] } end + def test_nil_embedding + item = Item.create(id: 4) + item.refresh + assert_nil item.embedding + end + + def test_set_nil_embedding + item = Item.create(id: 4, embedding: [1, 1, 1]) + item.update(embedding: nil) + item.refresh + assert_nil item.embedding + end + def test_model_dataset create_items sampled_item = Item.order(Sequel.function(:random)).first