diff --git a/sql/sql_base.cc b/sql/sql_base.cc index 95abfa798bf65..795780f800b83 100644 --- a/sql/sql_base.cc +++ b/sql/sql_base.cc @@ -10145,11 +10145,15 @@ int TABLE::unlock_hlindexes() int TABLE::hlindexes_on_insert() { - DBUG_ASSERT(s->hlindexes() == (hlindex != NULL)); - if (hlindex && hlindex->in_use) - if (int err= mhnsw_insert(this, key_info + s->keys)) - return err; - return 0; + DBUG_ASSERT(s->hlindexes() == (hlindex != NULL)); + if (hlindex && hlindex->in_use) + { + if (hlindex->bulk_insert_active) + return mhnsw_bulk_insert_row(this, key_info + s->keys); + else + return mhnsw_insert(this, key_info + s->keys); + } + return 0; } int TABLE::hlindexes_on_update() @@ -10208,3 +10212,30 @@ int TABLE::hlindex_read_end() { return mhnsw_read_end(this); } + +int TABLE::hlindexes_bulk_insert_begin(ha_rows rows) +{ + if (s->hlindexes()) + { + if (!hlindex || !hlindex->in_use) + if (int err= open_hlindexes_for_write()) + return err; + + if (hlindex && hlindex->in_use) + { + hlindex->bulk_insert_active= true; + return mhnsw_bulk_insert_begin(this, key_info + s->keys, rows); + } + } + return 0; +} + +int TABLE::hlindexes_bulk_insert_end() +{ + if (hlindex && hlindex->in_use) + { + hlindex->bulk_insert_active= false; + return mhnsw_bulk_insert_end(this, key_info + s->keys); + } + return 0; +} diff --git a/sql/sql_table.cc b/sql/sql_table.cc index b07f16102a6bb..81a5cfbb07757 100644 --- a/sql/sql_table.cc +++ b/sql/sql_table.cc @@ -12616,6 +12616,7 @@ copy_data_between_tables(THD *thd, TABLE *from, TABLE *to, bool make_unversioned= from->versioned() && !to->versioned(); bool keep_versioned= from->versioned() && to->versioned(); bool bulk_insert_started= 0; + bool hlindex_bulk_started= 0; Field *to_row_start= NULL, *to_row_end= NULL, *from_row_end= NULL; MYSQL_TIME query_start; DBUG_ENTER("copy_data_between_tables"); @@ -12662,11 +12663,17 @@ copy_data_between_tables(THD *thd, TABLE *from, TABLE *to, from->file->info(HA_STATUS_VARIABLE); to->file->extra(HA_EXTRA_PREPARE_FOR_ALTER_TABLE); - if (!to->s->long_unique_table && !to->s->hlindexes()) + if (!to->s->long_unique_table) { - to->file->ha_start_bulk_insert(from->file->stats.records, - ignore ? 0 : HA_CREATE_UNIQUE_INDEX_BY_SORT); - bulk_insert_started= 1; + to->file->ha_start_bulk_insert(from->file->stats.records, + ignore ? 0 : HA_CREATE_UNIQUE_INDEX_BY_SORT); + bulk_insert_started= 1; + + if (to->s->hlindexes()) + { + to->hlindexes_bulk_insert_begin(from->file->stats.records); + hlindex_bulk_started= 1; + } } mysql_stage_set_work_estimated(thd->m_stage_progress_psi, from->file->stats.records); List_iterator it(alter_info->create_list); @@ -12999,6 +13006,14 @@ copy_data_between_tables(THD *thd, TABLE *from, TABLE *to, } bulk_insert_started= 0; + if (hlindex_bulk_started && to->hlindexes_bulk_insert_end() && error <= 0) + { + if (!thd->is_error()) + to->file->print_error(my_errno, MYF(0)); + error= 1; + } + hlindex_bulk_started=0; + if (error <= 0 && !to->s->hlindexes()) { Abort_on_warning_instant_set save_abort_on_warning(thd, false); diff --git a/sql/table.h b/sql/table.h index 0713341840127..de5ac196cc3ca 100644 --- a/sql/table.h +++ b/sql/table.h @@ -1632,6 +1632,7 @@ struct TABLE */ bool alias_name_used; /* true if table_name is alias */ bool get_fields_in_item_tree; /* Signal to fix_field */ + bool bulk_insert_active; /* mhnsw bulk_insert_started flag */ private: bool m_needs_reopen; bool created; /* For tmp tables. TRUE <=> tmp table was actually created.*/ @@ -1875,6 +1876,8 @@ struct TABLE int hlindexes_on_update(); int hlindexes_on_delete(const uchar *buf); int hlindexes_on_delete_all(bool truncate); + int hlindexes_bulk_insert_begin(ha_rows rows); + int hlindexes_bulk_insert_end(); int unlock_hlindexes(); void prepare_triggers_for_insert_stmt_or_event(); diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc index c480c36c7e7ad..f427f1d51834d 100644 --- a/sql/vector_mhnsw.cc +++ b/sql/vector_mhnsw.cc @@ -510,7 +510,7 @@ class MHNSW_Share : public Sql_alloc const uint M; metric_type metric; bool use_subdist; - + bool bulk_active; MHNSW_Share(TABLE *t) : tref_len(t->file->ref_length), gref_len(t->hlindex->file->ref_length), M(static_cast(t->s->key_info[t->s->keys].option_struct->M)), @@ -1012,6 +1012,8 @@ int FVectorNode::load_from_record(TABLE *graph) FVector *vec_ptr= FVector::align_ptr(tref() + tref_len()); memcpy(vec_ptr->data(), v->ptr(), v->length()); vec_ptr->postprocess(ctx->use_subdist, ctx->vec_len); + if (ctx->metric == COSINE) + vec_ptr->abs2= 0.5f; longlong layer= graph->field[FIELD_LAYER]->val_int(); if (layer > 100) // 10e30 nodes at M=2, more at larger M's @@ -1266,8 +1268,9 @@ static int update_second_degree_neighbors(MHNSW_param *p, FVectorNode *node) if (int err= select_neighbors(p, neigh, neighneighbors, node, max_neighbors)) return err; - if (int err= neigh->save(p->graph)) - return err; + if (!p->ctx->bulk_active) + if (int err= neigh->save(p->graph)) + return err; } return 0; } @@ -1504,6 +1507,193 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) } +struct MHNSW_Bulk_context : public Sql_alloc { + MHNSW_Share *ctx; + DYNAMIC_ARRAY nodes; + ha_rows estimated_rows; + uint8_t current_max_layer; +}; + +int mhnsw_bulk_insert_begin(TABLE *table, KEY *keyinfo, ha_rows rows) +{ + TABLE *graph= table->hlindex; + DBUG_ASSERT(graph); + DBUG_ASSERT(keyinfo->algorithm == HA_KEY_ALG_VECTOR); + DBUG_ASSERT(keyinfo->usable_key_parts == 1); + + MHNSW_Bulk_context *bulk= new (table->in_use->mem_root) MHNSW_Bulk_context(); + if (!bulk) + return HA_ERR_OUT_OF_MEM; + + bulk->estimated_rows= rows; + if (my_init_dynamic_array(PSI_INSTRUMENT_MEM, &bulk->nodes, sizeof(FVectorNode*), + rows + rows * 0.1, rows, MYF(0))) + { + delete bulk; + return HA_ERR_OUT_OF_MEM; + } + + int err= MHNSW_Share::acquire(&bulk->ctx, table, true); + if (err && err != HA_ERR_END_OF_FILE && err != HA_ERR_KEY_NOT_FOUND) + { + delete_dynamic(&bulk->nodes); + delete bulk; + return err; + } + + bulk->ctx->bulk_active= 1; + bulk->current_max_layer= 0; + table->hlindex->context= bulk; + return 0; +} + +int mhnsw_bulk_insert_row(TABLE *table, KEY *keyinfo) +{ + TABLE *graph= table->hlindex; + MHNSW_Bulk_context *bulk= (MHNSW_Bulk_context*)graph->context; + MHNSW_Share *ctx= bulk->ctx; + MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set); + + DBUG_ASSERT(graph); + DBUG_ASSERT(bulk); + DBUG_ASSERT(keyinfo->algorithm == HA_KEY_ALG_VECTOR); + DBUG_ASSERT(keyinfo->usable_key_parts == 1); + + Field *vec_field= keyinfo->key_part->field; + String buf, *res= vec_field->val_str(&buf); + + DBUG_ASSERT(vec_field->binary()); + DBUG_ASSERT(vec_field->cmp_type() == STRING_RESULT); + DBUG_ASSERT(res); // ER_INDEX_CANNOT_HAVE_NULL + DBUG_ASSERT(res->length() > 0 && res->length() % 4 == 0); + DBUG_ASSERT(table->file->ref_length <= graph->field[FIELD_TREF]->field_length); + + table->file->position(table->record[0]); + + if (ctx->byte_len == 0) + ctx->set_lengths(res->length()); + + if (ctx->byte_len != res->length()) + return my_errno= HA_ERR_CRASHED; + + const double NORMALIZATION_FACTOR= 1 / std::log(ctx->M); + double log= -std::log(my_rnd(&table->in_use->rand)) * NORMALIZATION_FACTOR; + uint8_t max_layer= bulk->current_max_layer; + uint8_t target_layer= std::min(static_cast(std::floor(log)), max_layer + 1); + + if (bulk->nodes.elements == 0) + target_layer= 0; + + if (target_layer > bulk->current_max_layer) + bulk->current_max_layer= target_layer; + + FVectorNode *node= new (ctx->alloc_node()) + FVectorNode(ctx, table->file->ref, target_layer, res->ptr()); + + if (insert_dynamic(&bulk->nodes, (uchar*)&node)) + return HA_ERR_OUT_OF_MEM; + + dbug_tmp_restore_column_map(&table->read_set, old_map); + return 0; +} + +int mhnsw_bulk_insert_end(TABLE *table, KEY *keyinfo) +{ + THD *thd= table->in_use; + TABLE *graph= table->hlindex; + MHNSW_Bulk_context *bulk= (MHNSW_Bulk_context*)graph->context; + + DBUG_ASSERT(graph); + DBUG_ASSERT(bulk); + + MHNSW_Share *ctx= bulk->ctx; + SCOPE_EXIT([ctx, bulk, table](){ + delete_dynamic(&bulk->nodes); + ctx->bulk_active= 0; + ctx->release(table); + table->hlindex->context= nullptr; + }); + + for (uint i= 0; i < bulk->nodes.elements; i++) + { + FVectorNode *target= *(FVectorNode**)dynamic_element(&bulk->nodes, i, FVectorNode**); + + if (!ctx->start) + { + ctx->start= target; + continue; + } + + MEM_ROOT_SAVEPOINT memroot_sv; + root_make_savepoint(thd->mem_root, &memroot_sv); + SCOPE_EXIT([memroot_sv](){ root_free_to_savepoint(&memroot_sv); }); + + const uint8_t max_layer= ctx->start->max_layer; + uint8_t target_layer= target->max_layer; + + MHNSW_param p(ctx, graph, max_layer); + p.acc.graph_size= 1; + + const size_t max_found= ctx->max_neighbors(0); + Neighborhood candidates; + candidates.init(thd->alloc(max_found + 7), max_found); + candidates.links[candidates.num++]= ctx->start; + + for (; p.layer > target_layer; p.layer--) + { + if (int err= search_layer(&p, target->vec, NEAREST, 1, &candidates, false)) + return err; + } + + for (; p.layer >= 0; p.layer--) + { + uint max_neighbors= ctx->max_neighbors(p.layer); + if (int err= search_layer(&p, target->vec, NEAREST, max_neighbors, + &candidates, true)) + return err; + if (int err= select_neighbors(&p, target, candidates, 0, max_neighbors)) + return err; + } + + ctx->add_to_stats(p.acc); + + if (target_layer > max_layer) + ctx->start= target; + + for (p.layer= target_layer; p.layer >= 0; p.layer--) + { + if (int err= update_second_degree_neighbors(&p, target)) + return err; + } + } + + graph->file->ha_start_bulk_insert(bulk->nodes.elements, 0); + + for (uint i= 0; i < bulk->nodes.elements; i++) + { + FVectorNode *node= *(FVectorNode**)dynamic_element(&bulk->nodes, i, FVectorNode**); + if (int err= node->save(graph)) + return err; + } + + if (int err= graph->file->ha_end_bulk_insert()) + return err; + + if (int err= graph->file->ha_rnd_init(0)) + return err; + SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); }); + + // fix neighbors grefs + for (uint i= 0; i < bulk->nodes.elements; i++) + { + FVectorNode *node= *(FVectorNode**)dynamic_element(&bulk->nodes, i, FVectorNode**); + if (int err= node->save(graph)) + return err; + } + + return 0; +} + struct Search_context: public Sql_alloc { Neighborhood found; diff --git a/sql/vector_mhnsw.h b/sql/vector_mhnsw.h index fbb61e14773f9..e6a8622e3e609 100644 --- a/sql/vector_mhnsw.h +++ b/sql/vector_mhnsw.h @@ -34,6 +34,9 @@ int mhnsw_invalidate(TABLE *table, const uchar *rec, KEY *keyinfo); int mhnsw_delete_all(TABLE *table, KEY *keyinfo, bool truncate); void mhnsw_free(TABLE_SHARE *share); Item_func_vec_distance::distance_kind mhnsw_uses_distance(const TABLE *table, KEY *keyinfo); +int mhnsw_bulk_insert_begin(TABLE *table, KEY *keyinfo, ha_rows rows); +int mhnsw_bulk_insert_end(TABLE *table, KEY *keyinfo); +int mhnsw_bulk_insert_row(TABLE *table, KEY *keyinfo); extern ha_create_table_option mhnsw_index_options[]; extern st_plugin_int *mhnsw_plugin;