diff --git a/changelog.md b/changelog.md index f9e9eeca..c2b95988 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,11 @@ +Upcoming (TBD) +============== + +Features +--------- +* Append dot to database name when completing on a table name. + + 1.71.0 (2026/05/01) ============== diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py index f623a38c..c1dc0757 100644 --- a/mycli/packages/completion_engine.py +++ b/mycli/packages/completion_engine.py @@ -302,7 +302,7 @@ def _emit_relation_name(ctx: SuggestContext) -> list[Suggestion]: schema = _parent_name(ctx) if schema: return [{'type': rel_type, 'schema': schema}] - return [{'type': 'schema'}, {'type': rel_type, 'schema': []}] + return [{'type': 'database'}, {'type': rel_type, 'schema': []}] def _emit_on(ctx: SuggestContext) -> list[Suggestion]: diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index 8fe96a68..3ea58595 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -10,6 +10,7 @@ from prompt_toolkit.completion.base import Document from pygments.lexers._mysql_builtins import MYSQL_DATATYPES, MYSQL_FUNCTIONS, MYSQL_KEYWORDS import rapidfuzz +import sqlparse from mycli.packages.completion_engine import is_inside_quotes, suggest_type from mycli.packages.filepaths import complete_path, parse_path, suggest_path @@ -1419,6 +1420,18 @@ def get_completions( last_for_len = last_word(word_before_cursor, include="most_punctuations") text_for_len = last_for_len.lower() last_for_len_paths = last_word(word_before_cursor, include='alphanum_underscore') + if statements := sqlparse.split(document.text): + total_len = 0 + relevant_statement = '' + for statement in statements: + total_len = total_len + len(statement) + if document.cursor_position <= total_len: + relevant_statement = statement + break + if not relevant_statement: + relevant_statement = statements[-1] + else: + relevant_statement = document.text if smart_completion is None: smart_completion = self.smart_completion @@ -1436,7 +1449,9 @@ def get_completions( return (Completion(x[0], -len(text_for_len)) for x in matches) completions: list[tuple[str, int, int]] = [] - suggestions = suggest_type(document.text, document.text_before_cursor) + suggestions = suggest_type(relevant_statement, document.text_before_cursor) + database_is_qualifier = any(suggestion['type'] in ('table', 'view') for suggestion in suggestions) + database_is_qualifier = database_is_qualifier and not relevant_statement.lstrip().lower().startswith('show ') rigid_sort = False length_based_on_path = False @@ -1530,7 +1545,7 @@ def get_completions( # then only return tables that have one or more of the given columns. # If no columns are given (or able to be parsed), return all tables # as usual. - columns = extract_columns_from_select(document.text) + columns = extract_columns_from_select(relevant_statement) if columns: tables = self.populate_schema_objects(suggestion["schema"], "tables", columns) else: @@ -1538,7 +1553,7 @@ def get_completions( if suggestion.get("join"): # For JOINs, suggest FK-related tables first (lower rank = higher priority) - current_tables = extract_tables(document.text) + current_tables = extract_tables(relevant_statement) fk_map = self.dbmetadata["foreign_keys"].get(self.dbname, {}).get("tables", {}) fk_related: set[str] = set() for tbl_schema, tbl, _alias in current_tables: @@ -1602,6 +1617,8 @@ def get_completions( self.databases, text_before_cursor=document.text_before_cursor, ) + if database_is_qualifier: + dbs_m = ((f'{db}.', fuzziness) for db, fuzziness in dbs_m) completions.extend([(*x, rank) for x in dbs_m]) elif suggestion["type"] == "keyword": diff --git a/test/pytests/test_completion_engine.py b/test/pytests/test_completion_engine.py index b17b218b..ce67b98a 100644 --- a/test/pytests/test_completion_engine.py +++ b/test/pytests/test_completion_engine.py @@ -595,7 +595,7 @@ def test_emit_relation_name_with_schema_parent(): def test_emit_relation_name_without_schema_parent(): context = _build_suggest_context('view', '', None, '', empty_identifier()) - assert _emit_relation_name(context) == [{'type': 'schema'}, {'type': 'view', 'schema': []}] + assert _emit_relation_name(context) == [{'type': 'database'}, {'type': 'view', 'schema': []}] @pytest.mark.xfail @@ -925,9 +925,9 @@ def test_suggest_based_on_last_token_lparen_in_function_call_suggests_columns(): ('database', 'drop database ', 'drop database ', [{'type': 'database'}]), ('template', 'create database foo with template ', 'create database foo with template ', [{'type': 'database'}]), ('collate', 'collate ', 'collate ', [{'type': 'collation'}]), - ('table', 'drop table ', 'drop table ', [{'type': 'schema'}, {'type': 'table', 'schema': []}]), - ('view', 'drop view ', 'drop view ', [{'type': 'schema'}, {'type': 'view', 'schema': []}]), - ('function', 'drop function ', 'drop function ', [{'type': 'schema'}, {'type': 'function', 'schema': []}]), + ('table', 'drop table ', 'drop table ', [{'type': 'database'}, {'type': 'table', 'schema': []}]), + ('view', 'drop view ', 'drop view ', [{'type': 'database'}, {'type': 'view', 'schema': []}]), + ('function', 'drop function ', 'drop function ', [{'type': 'database'}, {'type': 'function', 'schema': []}]), ], ) def test_suggest_based_on_last_token_direct_keyword_branches(token, text_before_cursor, full_text, expected): diff --git a/test/pytests/test_smart_completion_public_schema_only.py b/test/pytests/test_smart_completion_public_schema_only.py index 4b1b5a0d..c76205ed 100644 --- a/test/pytests/test_smart_completion_public_schema_only.py +++ b/test/pytests/test_smart_completion_public_schema_only.py @@ -219,8 +219,8 @@ def test_table_completion(completer, complete_event): Completion(text="time_zone_name", start_position=0), Completion(text="time_zone_transition", start_position=0), Completion(text="time_zone_transition_type", start_position=0), - Completion(text="test", start_position=0), - Completion(text="`test 2`", start_position=0), + Completion(text="test.", start_position=0), + Completion(text="`test 2`.", start_position=0), ] @@ -238,8 +238,8 @@ def test_select_filtered_table_completion(completer, complete_event): Completion(text="time_zone_name", start_position=0), Completion(text="time_zone_transition", start_position=0), Completion(text="time_zone_transition_type", start_position=0), - Completion(text="test", start_position=0), - Completion(text="`test 2`", start_position=0), + Completion(text="test.", start_position=0), + Completion(text="`test 2`.", start_position=0), ] @@ -257,8 +257,8 @@ def test_sub_select_filtered_table_completion(completer, complete_event): Completion(text="time_zone_name", start_position=0), Completion(text="time_zone_transition", start_position=0), Completion(text="time_zone_transition_type", start_position=0), - Completion(text="test", start_position=0), - Completion(text="`test 2`", start_position=0), + Completion(text="test.", start_position=0), + Completion(text="`test 2`.", start_position=0), ] @@ -512,8 +512,8 @@ def test_table_names_after_from(completer, complete_event): Completion(text="time_zone_name", start_position=0), Completion(text="time_zone_transition", start_position=0), Completion(text="time_zone_transition_type", start_position=0), - Completion(text="test", start_position=0), - Completion(text="`test 2`", start_position=0), + Completion(text="test.", start_position=0), + Completion(text="`test 2`.", start_position=0), ] @@ -530,6 +530,51 @@ def test_table_names_leading_partial(completer, complete_event): ] +def test_database_completion_for_table_appends_dot(completer, complete_event): + text = 'SELECT * FROM te' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert Completion(text='test.', start_position=-2) in result + + +@pytest.mark.parametrize( + 'text', + [ + 'SHOW TABLES FROM te', + 'SHOW FULL TABLES FROM te', + 'SHOW COLUMNS FROM users FROM te', + ], +) +def test_show_database_completion_does_not_append_dot(completer, complete_event, text): + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert Completion(text='test', start_position=-2) in result + assert Completion(text='test.', start_position=-2) not in result + + +@pytest.mark.parametrize( + 'text', + [ + 'SHOW TABLES FROM test; SELECT * FROM te', + ], +) +def test_database_completion_for_table_ignores_previous_statement(completer, complete_event, text): + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert Completion(text='test.', start_position=-2) in result + assert Completion(text='test', start_position=-2) not in result + + +# todo: USE works interactively but not in the test suite +# this is also covered by test_show_database_completion_does_not_append_dot() +@pytest.mark.xfail +def test_database_completion_after_use_does_not_append_dot(completer, complete_event): + text = 'USE tes' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert Completion(text='test', start_position=-2) in result + + def test_table_names_inter_partial(completer, complete_event): text = "SELECT * FROM time_leap" position = len("SELECT * FROM time_leap") @@ -586,8 +631,8 @@ def test_grant_on_suggets_tables_and_schemata(completer, complete_event): position = len(text) result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == [ - Completion(text="test", start_position=0), - Completion(text="`test 2`", start_position=0), + Completion(text="test.", start_position=0), + Completion(text="`test 2`.", start_position=0), Completion(text='users', start_position=0), Completion(text='orders', start_position=0), Completion(text='`select`', start_position=0), @@ -947,8 +992,8 @@ def test_backticked_table_completion_not_required(completer, complete_event): position = len(text) result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == [ - Completion(text='`test`', start_position=-2), - Completion(text='`test 2`', start_position=-2), + Completion(text='`test`.', start_position=-2), + Completion(text='`test 2`.', start_position=-2), Completion(text='`time_zone`', start_position=-2), Completion(text='`time_zone_name`', start_position=-2), Completion(text='`time_zone_transition`', start_position=-2),