diff --git a/datafaker/interactive/__init__.py b/datafaker/interactive/__init__.py index c279720f..0b97feed 100644 --- a/datafaker/interactive/__init__.py +++ b/datafaker/interactive/__init__.py @@ -7,6 +7,7 @@ from sqlalchemy import MetaData +from datafaker.interactive.base import DbCmd from datafaker.interactive.generators import GeneratorCmd, try_setting_generator from datafaker.interactive.missingness import MissingnessCmd from datafaker.interactive.table import TableCmd @@ -25,10 +26,15 @@ def update_config_tables( - src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping, + parquet_dir: Path | None, ) -> Mapping[str, Any]: """Ask the user to specify what should happen to each table.""" - with TableCmd(src_dsn, src_schema, metadata, config) as tc: + settings = DbCmd.Settings(src_dsn, src_schema, config, metadata, parquet_dir) + with TableCmd(settings) as tc: tc.cmdloop() return tc.config @@ -38,6 +44,7 @@ def update_missingness( src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any], + parquet_dir: Path | None, ) -> Mapping[str, Any]: """ Ask the user to update the missingness information in ``config.yaml``. @@ -49,16 +56,14 @@ def update_missingness( :param config: The starting configuration, :return: The updated configuration. """ - with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: + settings = DbCmd.Settings(src_dsn, src_schema, config, metadata, parquet_dir) + with MissingnessCmd(settings) as mc: mc.cmdloop() return mc.config def update_config_generators( - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping[str, Any], + settings: DbCmd.Settings, spec_path: Path | None, ) -> Mapping[str, Any]: """ @@ -68,14 +73,11 @@ def update_config_generators( Column name (or space-separated list of column names), Generator name required, Second choice generator name, Third choice generator name, etcetera. - :param src_dsn: Address of the source database - :param src_schema: Name of the source database schema to read from - :param metadata: SQLAlchemy representation of the source database - :param config: Existing configuration (will be destructively updated) + :param settings: Source database settings. :param spec_path: The path of the CSV file containing the specification :return: Updated configuration. """ - with GeneratorCmd(src_dsn, src_schema, metadata, config) as gc: + with GeneratorCmd(settings) as gc: if spec_path is None: gc.cmdloop() return gc.config diff --git a/datafaker/interactive/base.py b/datafaker/interactive/base.py index a052c049..5ddfb255 100644 --- a/datafaker/interactive/base.py +++ b/datafaker/interactive/base.py @@ -4,6 +4,7 @@ from collections.abc import Mapping, MutableMapping, Sequence from dataclasses import dataclass from enum import Enum +from pathlib import Path from types import TracebackType from typing import Any, Optional, Type @@ -121,22 +122,38 @@ def make_table_entry( :return: The table entry or None if this table should not be interacted with. """ + @dataclass + class Settings: + """Settings for the source database.""" + + dsn: str + schema: str | None + config: MutableMapping[str, Any] + metadata: MetaData + parquet_dir: Path | None + def __init__( self, - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping[str, Any], + settings: Settings, ): - """Initialise a DbCmd.""" + """ + Initialise a DbCmd. + + :param src_dsn: The database connection string for the source database. + :param src_schema: The name of the schema name for the source database. + :param metadata: The metadata for the source database. + :param config: The ``config.xml`` object. + :param parquet_dir: The directory where parquet files are stored that + are to be considered part of the source database (only for DuckDB). + """ super().__init__() - self.config: MutableMapping[str, Any] = config - self.metadata = metadata + self.config: MutableMapping[str, Any] = settings.config + self.metadata = settings.metadata self._table_entries: list[TableEntry] = [] - tables_config: MutableMapping = config.get("tables", {}) + tables_config: MutableMapping = self.config.get("tables", {}) if not isinstance(tables_config, MutableMapping): tables_config = {} - for name in metadata.tables.keys(): + for name in self.metadata.tables.keys(): table_config = tables_config.get(name, {}) if not isinstance(table_config, MutableMapping): table_config = {} @@ -144,7 +161,11 @@ def __init__( if entry is not None: self._table_entries.append(entry) self.table_index = 0 - self.engine = create_db_engine(src_dsn, schema_name=src_schema) + self.engine = create_db_engine( + settings.dsn, + schema_name=settings.schema, + parquet_dir=settings.parquet_dir, + ) @property def sync_engine(self) -> Engine: diff --git a/datafaker/interactive/generators.py b/datafaker/interactive/generators.py index 176bd260..ecc8403f 100644 --- a/datafaker/interactive/generators.py +++ b/datafaker/interactive/generators.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Optional, cast import sqlalchemy -from sqlalchemy import Column, MetaData +from sqlalchemy import Column from datafaker.generators import everything_factory from datafaker.generators.base import Generator, PredefinedGenerator @@ -147,10 +147,7 @@ def make_table_entry( def __init__( self, - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping[str, Any], + settings: DbCmd.Settings, ) -> None: """ Initialise a ``GeneratorCmd``. @@ -160,7 +157,7 @@ def __init__( :param metadata: SQLAlchemy metadata for the source database :param config: Configuration loaded from ``config.yaml`` """ - super().__init__(src_dsn, src_schema, metadata, config) + super().__init__(settings) self.generators: list[Generator] | None = None self.generator_index = 0 self.generators_valid_columns: Optional[tuple[int, list[str]]] = None diff --git a/datafaker/interactive/missingness.py b/datafaker/interactive/missingness.py index b79fbdc4..bf56d4c0 100644 --- a/datafaker/interactive/missingness.py +++ b/datafaker/interactive/missingness.py @@ -1,11 +1,9 @@ """Missingness configuration shell.""" import re -from collections.abc import Iterable, Mapping, MutableMapping +from collections.abc import Iterable, Mapping from dataclasses import dataclass from typing import cast -from sqlalchemy import MetaData - from datafaker.interactive.base import DbCmd, TableEntry @@ -139,20 +137,14 @@ def make_table_entry( def __init__( self, - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping, + settings: DbCmd.Settings, ): """ Initialise a MissingnessCmd. - :param src_dsn: connection string for the source database. - :param src_schema: schema name for the source database. - :param metadata: SQLAlchemy metadata for the source database. - :param config: Configuration from the ``config.yaml`` file. + :param settings: source database settings. """ - super().__init__(src_dsn, src_schema, metadata, config) + super().__init__(settings) self.set_prompt() @property diff --git a/datafaker/interactive/table.py b/datafaker/interactive/table.py index 407f98ed..3bd045f8 100644 --- a/datafaker/interactive/table.py +++ b/datafaker/interactive/table.py @@ -1,10 +1,9 @@ """Table configuration command shell.""" -from collections.abc import Mapping, MutableMapping +from collections.abc import Mapping from dataclasses import dataclass from typing import Any, cast import sqlalchemy -from sqlalchemy import MetaData from datafaker.interactive.base import ( TYPE_LETTER, @@ -92,15 +91,12 @@ def make_table_entry( def __init__( self, - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping[str, Any], + settings: DbCmd.Settings, *args: Any, **kwargs: Any, ) -> None: """Initialise a TableCmd.""" - super().__init__(src_dsn, src_schema, metadata, config, *args, **kwargs) + super().__init__(settings, *args, **kwargs) self.set_prompt() @property diff --git a/datafaker/main.py b/datafaker/main.py index 0269781a..ff768c61 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -28,6 +28,7 @@ update_config_tables, update_missingness, ) +from datafaker.interactive.base import DbCmd from datafaker.make import ( make_src_stats, make_table_generators, @@ -36,7 +37,6 @@ ) from datafaker.remove import remove_db_data, remove_db_tables, remove_db_vocab from datafaker.settings import ( - Settings, SettingsError, get_destination_dsn, get_destination_schema, @@ -87,19 +87,8 @@ def _check_file_non_existence(file_path: Path) -> None: sys.exit(1) -def _require_src_db_dsn(settings: Settings) -> str: - """Return the source DB DSN. - - Check that source DB details have been set. Exit with error message if not. - """ - if (src_dsn := settings.src_dsn) is None: - logger.error("Missing source database connection details.") - sys.exit(1) - return src_dsn - - def load_metadata_config( - orm_file_name: str, config: dict | None = None + orm_file_path: Path, config: dict | None = None ) -> dict[str, Any]: """ Load the ``orm.yaml`` file, returning a dict representation. @@ -110,7 +99,7 @@ def load_metadata_config( :return: A dict representing the ``orm.yaml`` file, with the tables the ``config`` says to ignore removed. """ - with open(orm_file_name, encoding="utf-8") as orm_fh: + with orm_file_path.open(encoding="utf-8") as orm_fh: meta_dict = yaml.load(orm_fh, yaml.Loader) if not isinstance(meta_dict, dict): return {} @@ -123,7 +112,7 @@ def load_metadata_config( return meta_dict -def load_metadata(orm_file_name: str, config: dict | None = None) -> MetaData: +def load_metadata(orm_file_path: Path, config: dict | None = None) -> MetaData: """ Load metadata from ``orm.yaml``. @@ -131,15 +120,15 @@ def load_metadata(orm_file_name: str, config: dict | None = None) -> MetaData: :param config: Used to exclude tables that are marked as ``ignore: true``. :return: SQLAlchemy MetaData object representing the database described by the loaded file. """ - meta_dict = load_metadata_config(orm_file_name, config) + meta_dict = load_metadata_config(orm_file_path, config) return dict_to_metadata(meta_dict, None) def load_metadata_for_output( - orm_file_name: str, config: dict | None = None + orm_file_path: Path, config: dict | None = None ) -> MetaData: """Load metadata excluding any foreign keys pointing to ignored tables.""" - meta_dict = load_metadata_config(orm_file_name, config) + meta_dict = load_metadata_config(orm_file_path, config) return dict_to_metadata(meta_dict, config) @@ -153,12 +142,20 @@ def main( @app.command() def create_data( - orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), + orm_file: Path = Option( + ORM_FILENAME, + help="The name of the ORM yaml file", + dir_okay=False, + ), df_file: str = Option( DF_FILENAME, help="The name of the generators file. Must be in the current working directory.", + dir_okay=False, + ), + config_file: Optional[Path] = Option( + CONFIG_FILENAME, + help="The configuration file", ), - config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), num_passes: int = Option(1, help="Number of passes (rows or stories) to make"), ) -> None: """Populate the schema in the target directory with synthetic data. @@ -210,8 +207,16 @@ def create_data( @app.command() def create_vocab( - orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), - config_file: str = Option(CONFIG_FILENAME, help="The configuration file"), + orm_file: Path = Option( + ORM_FILENAME, + help="The name of the ORM yaml file", + dir_okay=False, + ), + config_file: Path = Option( + CONFIG_FILENAME, + help="The configuration file", + dir_okay=False, + ), ) -> None: """Import vocabulary data into the target database. @@ -229,8 +234,16 @@ def create_vocab( @app.command() def create_tables( - orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), - config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), + orm_file: Path = Option( + ORM_FILENAME, + help="The name of the ORM yaml file", + dir_okay=False, + ), + config_file: Optional[Path] = Option( + CONFIG_FILENAME, + help="The configuration file", + dir_okay=False, + ), ) -> None: """Create schema from the ORM YAML file. @@ -249,16 +262,29 @@ def create_tables( @app.command() def create_generators( - orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), - df_file: str = Option(DF_FILENAME, help="Path to write Python generators to."), - config_file: str = Option(CONFIG_FILENAME, help="The configuration file"), - stats_file: Optional[str] = Option( + orm_file: Path = Option( + ORM_FILENAME, + help="The name of the ORM yaml file", + dir_okay=False, + ), + df_file: Path = Option( + DF_FILENAME, + help="Path to write Python generators to.", + dir_okay=False, + ), + config_file: Path = Option( + CONFIG_FILENAME, + help="The configuration file", + dir_okay=False, + ), + stats_file: Optional[Path] = Option( None, help=( "Statistics file (output of make-stats); default is src-stats.yaml if the " "config file references SRC_STATS, or None otherwise." ), show_default=False, + dir_okay=False, ), force: bool = Option( False, "--force", "-f", help="Overwrite any existing Python generators file." @@ -274,13 +300,12 @@ def create_generators( """ logger.debug("Making %s.", df_file) - df_file_path = Path(df_file) if not force: - _check_file_non_existence(df_file_path) + _check_file_non_existence(df_file) generator_config = read_config_file(config_file) if config_file is not None else {} if stats_file is None and generators_require_stats(generator_config): - stats_file = STATS_FILENAME + stats_file = Path(STATS_FILENAME) orm_metadata = load_metadata_for_output(orm_file, generator_config) result: str = make_table_generators( orm_metadata, @@ -290,15 +315,23 @@ def create_generators( stats_file, ) - df_file_path.write_text(result, encoding="utf-8") + df_file.write_text(result, encoding="utf-8") logger.debug("%s created.", df_file) @app.command() def make_vocab( - orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), - config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), + orm_file: Path = Option( + ORM_FILENAME, + help="The name of the ORM yaml file", + dir_okay=False, + ), + config_file: Optional[Path] = Option( + CONFIG_FILENAME, + help="The configuration file", + dir_okay=False, + ), force: bool = Option( False, "--force/--no-force", @@ -328,35 +361,39 @@ def make_vocab( @app.command() def make_stats( - config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - stats_file: str = Option(STATS_FILENAME), + orm_file: Path = Option( + ORM_FILENAME, + help="The name of the ORM yaml file", + dir_okay=False, + ), + config_file: Optional[Path] = Option( + CONFIG_FILENAME, + help="The configuration file", + dir_okay=False, + ), + stats_file: Path = Option(STATS_FILENAME), force: bool = Option( False, "--force", "-f", help="Overwrite any existing vocabulary file." ), ) -> None: - """Compute summary statistics from the source database. - - Writes the statistics to a YAML file. - - Example: - $ datafaker make_stats --config-file=example_config.yaml - """ + """Compute summary statistics from the source database.""" logger.debug("Creating %s.", stats_file) - stats_file_path = Path(stats_file) if not force: - _check_file_non_existence(stats_file_path) + _check_file_non_existence(stats_file) config = read_config_file(config_file) if config_file is not None else {} + meta_dict = load_metadata_config(orm_file, config) src_stats = asyncio.get_event_loop().run_until_complete( make_src_stats( get_source_dsn(), config, get_source_schema(), + parquet_dir=meta_dict.get("parquet-dir", None), ) ) - stats_file_path.write_text(yaml.dump(src_stats), encoding="utf-8") + stats_file.write_text(yaml.dump(src_stats), encoding="utf-8") logger.debug("%s created.", stats_file) @@ -399,73 +436,93 @@ def make_tables( @app.command() def configure_tables( - config_file: str = Option( - CONFIG_FILENAME, help="Path to write the configuration file to" + config_file: Path = Option( + CONFIG_FILENAME, + help="Path to write the configuration file to", + dir_okay=False, + ), + orm_file: Path = Option( + ORM_FILENAME, + help="The name of the ORM yaml file", + dir_okay=False, ), - orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), ) -> None: """Interactively set tables to ignored, vocabulary or primary private.""" logger.debug("Configuring tables in %s.", config_file) - config_file_path = Path(config_file) config = {} - if config_file_path.exists(): + if config_file.exists(): config = yaml.load( - config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader + config_file.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader ) # we don't pass config here so that no tables are ignored - metadata = load_metadata(orm_file) + meta_dict = load_metadata_config(orm_file) + metadata = dict_to_metadata(meta_dict, None) config_updated = update_config_tables( get_source_dsn(), get_source_schema(), metadata, config, + Path(meta_dict["parquet-dir"]) if "parquet-dir" in meta_dict else None, ) if config_updated is None: logger.debug("Cancelled") return content = yaml.dump(config_updated) - config_file_path.write_text(content, encoding="utf-8") + config_file.write_text(content, encoding="utf-8") logger.debug("Tables configured in %s.", config_file) @app.command() def configure_missing( - config_file: str = Option( - CONFIG_FILENAME, help="Path to write the configuration file to" + config_file: Path = Option( + CONFIG_FILENAME, + help="Path to write the configuration file to", + dir_okay=False, + ), + orm_file: Path = Option( + ORM_FILENAME, + help="The name of the ORM yaml file", + dir_okay=False, ), - orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), ) -> None: """Interactively set the missingness of the generated data.""" logger.debug("Configuring missingness in %s.", config_file) - config_file_path = Path(config_file) config: dict[str, Any] = {} - if config_file_path.exists(): + if config_file.exists(): config_any = yaml.load( - config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader + config_file.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader ) if isinstance(config_any, dict): config = config_any - metadata = load_metadata(orm_file, config) + meta_dict = load_metadata_config(orm_file, config) + metadata = dict_to_metadata(meta_dict, None) config_updated = update_missingness( get_source_dsn(), get_source_schema(), metadata, config, + Path(meta_dict["parquet-dir"]) if "parquet-dir" in meta_dict else None, ) if config_updated is None: logger.debug("Cancelled") return content = yaml.dump(config_updated) - config_file_path.write_text(content, encoding="utf-8") + config_file.write_text(content, encoding="utf-8") logger.debug("Generators missingness in %s.", config_file) @app.command() def configure_generators( - config_file: str = Option( - CONFIG_FILENAME, help="Path of the configuration file to alter" + config_file: Path = Option( + CONFIG_FILENAME, + help="Path of the configuration file to alter", + dir_okay=False, + ), + orm_file: Path = Option( + ORM_FILENAME, + help="The name of the ORM yaml file", + dir_okay=False, ), - orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), spec: Path = Option( None, help=( @@ -476,25 +533,28 @@ def configure_generators( ) -> None: """Interactively set generators for column data.""" logger.debug("Configuring generators in %s.", config_file) - config_file_path = Path(config_file) config = {} - if config_file_path.exists(): + if config_file.exists(): config = yaml.load( - config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader + config_file.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader ) - metadata = load_metadata(orm_file) + meta_dict = load_metadata_config(orm_file) + metadata = dict_to_metadata(meta_dict, None) config_updated = update_config_generators( - get_source_dsn(), - get_source_schema(), - metadata, - config, + DbCmd.Settings( + get_source_dsn(), + get_source_schema(), + config, + metadata, + meta_dict.get("parquet-dir", None), + ), spec_path=spec, ) if config_updated is None: logger.debug("Cancelled") return content = yaml.dump(config_updated) - config_file_path.write_text(content, encoding="utf-8") + config_file.write_text(content, encoding="utf-8") logger.debug("Generators configured in %s.", config_file) @@ -538,12 +598,12 @@ def _dump_csv_to_stdout( def _get_writer( parquet: bool, - output: str | None, + output: Path | None, metadata: MetaData, dsn: str, schema_name: str | None, ) -> TableWriter: - if parquet or output and output.endswith(ParquetTableWriter.EXTENSION): + if parquet or output and output.suffix == ParquetTableWriter.EXTENSION: return get_parquet_table_writer(metadata, dsn, schema_name) return CsvTableWriter(metadata, dsn, schema_name) @@ -569,21 +629,29 @@ def _dump_tables_to_directory( @app.command() def dump_data( - config_file: Optional[str] = Option( - CONFIG_FILENAME, help="Path of the configuration file to use" + config_file: Optional[Path] = Option( + CONFIG_FILENAME, + help="Path of the configuration file to use", + dir_okay=False, + ), + orm_file: Path = Option( + ORM_FILENAME, + help="The name of the ORM yaml file", + dir_okay=False, ), - orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), table: list[str] = Option( default=[], help="The tables to dump (default is all non-ignored, non-vocabulary tables)", ), - output: str + output: Path | None = Option( None, help=( "Output CSV or Parquet file name," " directory to write into or - to output to the console" ), + file_okay=True, + dir_okay=True, ), parquet: bool = Option( False, @@ -640,8 +708,16 @@ def validate_config( @app.command() def remove_data( - orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), - config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), + orm_file: Path = Option( + ORM_FILENAME, + help="The name of the ORM yaml file", + dir_okay=False, + ), + config_file: Optional[Path] = Option( + CONFIG_FILENAME, + help="The configuration file", + dir_okay=False, + ), yes: bool = Option( False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first" ), @@ -659,8 +735,16 @@ def remove_data( @app.command() def remove_vocab( - orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), - config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), + orm_file: Path = Option( + ORM_FILENAME, + help="The name of the ORM yaml file", + dir_okay=False, + ), + config_file: Optional[Path] = Option( + CONFIG_FILENAME, + help="The configuration file", + dir_okay=False, + ), yes: bool = Option( False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first" ), @@ -679,8 +763,16 @@ def remove_vocab( @app.command() def remove_tables( - orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), - config_file: str = Option(CONFIG_FILENAME, help="The configuration file"), + orm_file: Path = Option( + ORM_FILENAME, + help="The name of the ORM yaml file", + dir_okay=False, + ), + config_file: Path = Option( + CONFIG_FILENAME, + help="The configuration file", + dir_okay=False, + ), # pylint: disable=redefined-builtin all: bool = Option( False, @@ -722,8 +814,16 @@ class TableType(str, Enum): @app.command() def list_tables( - orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), - config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), + orm_file: Path = Option( + ORM_FILENAME, + help="The name of the ORM yaml file", + dir_okay=False, + ), + config_file: Optional[Path] = Option( + CONFIG_FILENAME, + help="The configuration file", + dir_okay=False, + ), tables: TableType = Option(TableType.GENERATED, help="Which tables to list"), ) -> None: """List the names of tables described in the metadata file.""" diff --git a/datafaker/make.py b/datafaker/make.py index 6cc4742a..a8dd58d0 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -2,11 +2,12 @@ import asyncio import decimal import inspect +from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass, field from datetime import datetime from pathlib import Path from types import TracebackType -from typing import Any, Callable, Final, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Final, Optional, Tuple, Type, Union import pandas as pd import snsql @@ -582,9 +583,9 @@ def make_vocabulary_tables( def make_table_generators( # pylint: disable=too-many-locals metadata: MetaData, config: Mapping, - orm_filename: str, - config_filename: str, - src_stats_filename: Optional[str], + orm_filename: Path, + config_filename: Path, + src_stats_filename: Optional[Path], ) -> str: """ Create datafaker generator classes. @@ -704,7 +705,7 @@ def make_tables_file( metadata = MetaData() metadata.reflect(engine) - meta_dict = metadata_to_dict(metadata, schema_name, engine) + meta_dict = metadata_to_dict(metadata, schema_name, engine, parquet_dir) if parquet_dir is not None: extra_meta = get_parquet_orm(parquet_dir) @@ -800,7 +801,10 @@ def fix_types(dics: list[dict]) -> list[dict]: async def make_src_stats( - dsn: str, config: Mapping, schema_name: Optional[str] = None + dsn: str, + config: Mapping, + schema_name: Optional[str] = None, + parquet_dir: Optional[Path] = None, ) -> dict[str, dict[str, Any]]: """ Run the src-stats queries specified by the configuration. @@ -815,7 +819,12 @@ async def make_src_stats( :return: The dictionary of src-stats. """ use_asyncio = config.get("use-asyncio", False) - engine = create_db_engine(dsn, schema_name=schema_name, use_asyncio=use_asyncio) + engine = create_db_engine( + dsn, + schema_name=schema_name, + use_asyncio=use_asyncio, + parquet_dir=parquet_dir, + ) async with DbConnection(engine) as db_conn: return await make_src_stats_connection(config, db_conn) diff --git a/datafaker/serialize_metadata.py b/datafaker/serialize_metadata.py index 5c689124..62bc01c7 100644 --- a/datafaker/serialize_metadata.py +++ b/datafaker/serialize_metadata.py @@ -1,6 +1,7 @@ """Convert between a Python dict describing a database schema and a SQLAlchemy MetaData.""" import typing from functools import partial +from pathlib import Path import parsy from sqlalchemy import Column, Dialect, Engine, ForeignKey, MetaData, Table @@ -282,14 +283,17 @@ def dict_to_table( def metadata_to_dict( - meta: MetaData, schema_name: str | None, engine: Engine + meta: MetaData, + schema_name: str | None, + engine: Engine, + parquet_dir: Path | None, ) -> dict[str, typing.Any]: """ Convert a metadata object into a Python dict. The output will be ready for output to ``orm.yaml``. """ - return { + d = { "tables": { str(table.name): table_to_dict(table, engine.dialect) for table in meta.tables.values() @@ -297,6 +301,9 @@ def metadata_to_dict( "dsn": str(engine.url), "schema": schema_name, } + if parquet_dir is not None: + d["parquet-dir"] = str(parquet_dir) + return d def should_ignore_fk(tables_dict: dict[str, TableT], fk: str) -> bool: diff --git a/datafaker/utils.py b/datafaker/utils.py index ae36ad03..9a1ffbcb 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -5,6 +5,7 @@ import io import json import logging +import os import random import re import string @@ -67,7 +68,7 @@ def iterable(cls) -> Iterable[T]: return (x for x in e) -def read_config_file(path: str) -> dict: +def read_config_file(path: Path) -> dict: """Read a config file, warning if it is invalid. Args: @@ -76,7 +77,7 @@ def read_config_file(path: str) -> dict: Returns: The config file as a dictionary. """ - with open(path, "r", encoding="utf8") as f: + with path.open("r", encoding="utf8") as f: config = yaml.safe_load(f) if not isinstance(config, dict): @@ -188,6 +189,7 @@ def create_db_engine( db_dsn: str, schema_name: Optional[str] = None, use_asyncio: bool = False, + parquet_dir: Optional[Path] = None, **kwargs: Any, ) -> MaybeAsyncEngine: """Create a SQLAlchemy Engine.""" @@ -197,12 +199,22 @@ def create_db_engine( else: engine = create_engine(db_dsn, **kwargs) + settings = {} if schema_name is not None: + settings["search_path"] = schema_name + if parquet_dir is not None: + joined = ",".join(_find_parquet_directories(parquet_dir)) + # double up single quotes + dj = joined.replace("'", "''") + # enclose in single quotes + settings["file_search_path"] = f"'{dj}'" + + if settings: event_engine = get_sync_engine(engine) @event.listens_for(event_engine, "connect", insert=True) def connect(dbapi_connection: DBAPIConnection, _: Any) -> None: - set_search_path(dbapi_connection, schema_name) + set_db_settings(dbapi_connection, settings) return engine @@ -236,7 +248,24 @@ def create_db_engine_dst( return create_db_engine(db_dsn, schema_name, use_asyncio) -def set_search_path(connection: DBAPIConnection, schema: str) -> None: +def _find_parquet_directories(parquet_dir: Path) -> list[str]: + """Find all the directories under ``parquet_dir`` that contain parquet files.""" + return [ + path + for path, _, filenames in os.walk(parquet_dir) + if _names_include_parquet(Path(path), filenames) + ] + + +def _names_include_parquet(path: Path, file_names: Iterable[str]) -> bool: + for fn in file_names: + entry = path / fn + if entry.is_file() and entry.suffix in {".parquet", ".parq"}: + return True + return False + + +def set_db_settings(connection: DBAPIConnection, settings: Mapping[str, str]) -> None: """Set the SEARCH_PATH for a PostgreSQL connection.""" # https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#remote-schema-table-introspection-and-postgresql-search-path existing_autocommit = connection.autocommit @@ -244,7 +273,8 @@ def set_search_path(connection: DBAPIConnection, schema: str) -> None: cursor = connection.cursor() # Parametrised queries don't work with asyncpg, hence the f-string. - cursor.execute(f"SET search_path TO {schema};") + sql = "".join(f"SET {k} TO {v};" for k, v in settings.items()) + cursor.execute(sql) cursor.close() connection.autocommit = existing_autocommit @@ -787,25 +817,26 @@ def generators_require_stats(config: Mapping) -> bool: stats_required = False for where, call in (ois | sgs | table_calls).items(): for n, arg in enumerate(call.get("args", [])): - try: - names = ( - node.id - for node in ast.walk(ast.parse(arg)) - if isinstance(node, ast.Name) - ) - if any(name == "SRC_STATS" for name in names): - stats_required = True - except SyntaxError as e: - errors.append( - ( - "Syntax error in argument %d of %s: %s\n%s%s", - n + 1, - where, - e.msg, - arg, - underline_error(e), + if isinstance(arg, str): + try: + names = ( + node.id + for node in ast.walk(ast.parse(arg)) + if isinstance(node, ast.Name) + ) + if any(name == "SRC_STATS" for name in names): + stats_required = True + except SyntaxError as e: + errors.append( + ( + "Syntax error in argument %d of %s: %s\n%s%s", + n + 1, + where, + e.msg, + arg, + underline_error(e), + ) ) - ) for k, arg in call.get("kwargs", {}).items(): if isinstance(arg, str): try: diff --git a/docs/source/duckdb.rst b/docs/source/duckdb.rst index 962177e5..69ddc724 100644 --- a/docs/source/duckdb.rst +++ b/docs/source/duckdb.rst @@ -140,11 +140,14 @@ for Mac or Linux, or on Windows use: set SRC_DSN=duckdb:///:memory: set DST_DSN=duckdb:///./fake.db -Now generate the ``orm.yaml``: +Now generate the ``orm.yaml``. Here we are assuming the parquet files are +in a directory named ``inputdir``. This directory can be ``.`` but it is better +to put these files in a different directory so that we can be sure that the faked +output files don't get mixed up with these source files: .. code-block:: shell - datafaker make-tables --parquet-dir . + datafaker make-tables --parquet-dir inputdir ... and edit the ``orm.yaml`` file as detailed above. diff --git a/tests/examples/duckdb/signature_model.parquet b/tests/examples/duckdb/signature_model.parquet index 03b00d82..3771900d 100644 Binary files a/tests/examples/duckdb/signature_model.parquet and b/tests/examples/duckdb/signature_model.parquet differ diff --git a/tests/examples/orm.yaml b/tests/examples/orm.yaml new file mode 100644 index 00000000..fb8938ef --- /dev/null +++ b/tests/examples/orm.yaml @@ -0,0 +1,60 @@ +tables: + manufacturer: + columns: + founded: + type: TIMESTAMP WITH TIME ZONE + id: + primary: true + type: INTEGER + name: + type: TEXT + model: + columns: + id: + primary: true + type: INTEGER + introduced: + type: TIMESTAMP WITH TIME ZONE + manufacturer_id: + foreign_keys: + - manufacturer.id + type: INTEGER + name: + type: TEXT + player: + columns: + family_name: + type: TEXT + given_name: + type: TEXT + id: + primary: true + type: INTEGER + signature_model: + columns: + based_on: + foreign_keys: + - model.id + type: INTEGER + id: + primary: true + type: INTEGER + name: + type: VARCHAR(20) + player_id: + foreign_keys: + - player.id + type: INTEGER + string: + columns: + frequency: + type: DOUBLE PRECISION + id: + primary: true + type: INTEGER + model_id: + foreign_keys: + - model.id + type: INTEGER + position: + type: INTEGER diff --git a/tests/test_create.py b/tests/test_create.py index face6894..3ecdce39 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -47,7 +47,7 @@ def test_create_vocab(self) -> None: } self.set_configuration(config) meta_dict = metadata_to_dict( - self.metadata, self.schema_name, self.sync_engine + self.metadata, self.schema_name, self.sync_engine, None ) create_db_tables(self.metadata) create_db_vocab(self.metadata, meta_dict, config, Path("./tests/examples")) diff --git a/tests/test_dump.py b/tests/test_dump.py index 7663bc71..d6812185 100644 --- a/tests/test_dump.py +++ b/tests/test_dump.py @@ -6,6 +6,7 @@ from pathlib import Path import pandas as pd +import yaml from typer.testing import CliRunner from datafaker.dump import CsvTableWriter, get_parquet_table_writer @@ -89,6 +90,12 @@ class EndToEndParquetTestCase(DatafakerTestCase): database_type = TestDuckDb examples_dir = Path("tests/examples/duckdb") + def set_working_dir(self) -> None: + """Change to our working directory.""" + working_dir = tempfile.mkdtemp() + shutil.move(self.parquet_dir / "config.yaml", Path(working_dir) / "config.yaml") + os.chdir(working_dir) + def setUp(self) -> None: """Set up the files in a temporary directory.""" super().setUp() @@ -97,13 +104,42 @@ def setUp(self) -> None: for fname in os.listdir(self.examples_dir): shutil.copy(self.examples_dir / fname, self.parquet_dir / fname) self.start_dir = os.getcwd() - os.chdir(self.parquet_dir) + self.set_working_dir() def tearDown(self) -> None: """Return to the start directory.""" os.chdir(self.start_dir) return super().tearDown() + def make_orm_yaml(self, runner: CliRunner) -> None: + """Make the orm.yaml file, if necessary.""" + runner.invoke( + app, + [ + "make-tables", + "--parquet-dir", + str(self.parquet_dir), + "--orm-file", + "orm_auto.yaml", + ], + ) + # Fix up the orm.yaml; the dates might not have types set + with Path("orm_auto.yaml").open(encoding="utf-8") as orm_fh: + orm = yaml.load(orm_fh, yaml.SafeLoader) + t = orm["tables"] + t["manufacturer.parquet"]["columns"]["founded"]["type"] = "DATETIME" + t["model.parquet"]["columns"]["introduced"]["type"] = "DATETIME" + t["signature_model.parquet"]["columns"]["player_id"]["type"] = "INTEGER" + t["signature_model.parquet"]["columns"]["player_id"]["foreign_keys"] = [ + "player.parquet.id" + ] + t["signature_model.parquet"]["columns"]["based_on"]["type"] = "INTEGER" + t["signature_model.parquet"]["columns"]["based_on"]["foreign_keys"] = [ + "model.parquet.id" + ] + with Path("orm.yaml").open("w", encoding="utf-8") as out_fh: + yaml.dump(orm, out_fh, yaml.SafeDumper) + def test_end_to_end_parquet(self) -> None: """ Test that parquet with an orm.yaml works. @@ -122,6 +158,8 @@ def test_end_to_end_parquet(self) -> None: }, ) + self.make_orm_yaml(runner) + # Configure with the spec file result = runner.invoke( app, ["configure-generators", "--spec", str(self.parquet_dir / "spec.csv")] @@ -173,3 +211,19 @@ def test_end_to_end_parquet(self) -> None: self.assertLessEqual(v, i + 1) # Check that many of the possible keys have been used self.assertLess(num_passes / 3, len(player_ids)) + + +class CurrentDirEndToEndParquetTestCase(EndToEndParquetTestCase): + """ + Read in parquet, make some generators, output parquet. + + Do it from the parquet directory. + """ + + def set_working_dir(self) -> None: + """Change to our working directory.""" + os.chdir(self.parquet_dir) + + def make_orm_yaml(self, _runner: CliRunner) -> None: + """Make the orm.yaml file, if necessary.""" + # not necessary, we already have an orm.yaml diff --git a/tests/test_functional.py b/tests/test_functional.py index e8068fb5..eac67081 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -213,6 +213,7 @@ def test_workflow_maximal_args(self) -> None: completed_process = self.invoke( "--verbose", "make-stats", + f"--orm-file={self.alt_orm_file_path}", f"--stats-file={self.stats_file_path}", f"--config-file={self.config_file_path}", "--force", @@ -498,6 +499,7 @@ def test_unique_constraint_fail(self) -> None: ) self.invoke( "make-stats", + f"--orm-file={self.alt_orm_file_path}", f"--stats-file={self.stats_file_path}", f"--config-file={self.config_file_path}", "--force", diff --git a/tests/test_interactive_generators.py b/tests/test_interactive_generators.py index 807a6950..7ef2dbc3 100644 --- a/tests/test_interactive_generators.py +++ b/tests/test_interactive_generators.py @@ -7,6 +7,7 @@ from sqlalchemy import Connection, MetaData, select from datafaker.generators.choice import ChoiceGeneratorFactory +from datafaker.interactive.base import DbCmd from datafaker.interactive.generators import GeneratorCmd from tests.utils import ( GeneratesDBTestCase, @@ -39,7 +40,9 @@ class ConfigureGeneratorsTests(RequiresDBTestCase): def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: """Get the command we are using for this test case.""" - return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) + return TestGeneratorCmd( + DbCmd.Settings(self.dsn, self.schema_name, config, self.metadata, None) + ) def test_null_configuration(self) -> None: """Test that the tables having null configuration does not break.""" @@ -594,7 +597,9 @@ def setUp(self) -> None: ChoiceGeneratorFactory.SUPPRESS_COUNT = 5 def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: - return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) + return TestGeneratorCmd( + DbCmd.Settings(self.dsn, self.schema_name, config, self.metadata, None) + ) def _propose(self, gc: TestGeneratorCmd) -> dict[str, tuple[int, str, list[str]]]: gc.reset() @@ -756,7 +761,9 @@ class GeneratorTests(GeneratesDBTestCase): def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: """We are using configure-generators.""" - return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) + return TestGeneratorCmd( + DbCmd.Settings(self.dsn, self.schema_name, config, self.metadata, None) + ) def test_set_null(self) -> None: """Test that we can sample real missingness and reproduce it.""" diff --git a/tests/test_interactive_generators_partitioned.py b/tests/test_interactive_generators_partitioned.py index d5319165..fd26fd69 100644 --- a/tests/test_interactive_generators_partitioned.py +++ b/tests/test_interactive_generators_partitioned.py @@ -7,6 +7,7 @@ from sqlalchemy import Connection, MetaData, insert, select from datafaker.generators import NullPartitionedNormalGeneratorFactory +from datafaker.interactive.base import DbCmd from tests.test_interactive_generators import TestGeneratorCmd from tests.utils import GeneratesDBTestCase @@ -140,7 +141,9 @@ def setUp(self) -> None: def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: """Get the configure-generators object as our command.""" - return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) + return TestGeneratorCmd( + DbCmd.Settings(self.dsn, self.schema_name, config, self.metadata, None) + ) def _propose(self, gc: TestGeneratorCmd) -> dict[str, tuple[int, str, list[str]]]: gc.reset() diff --git a/tests/test_interactive_missingness.py b/tests/test_interactive_missingness.py index 7a63ea52..e8b47cbe 100644 --- a/tests/test_interactive_missingness.py +++ b/tests/test_interactive_missingness.py @@ -6,6 +6,7 @@ from sqlalchemy import select from datafaker.interactive import MissingnessCmd +from datafaker.interactive.base import DbCmd from tests.utils import GeneratesDBTestCase, RequiresDBTestCase, TestDbCmdMixin @@ -22,7 +23,9 @@ class ConfigureMissingnessTests(RequiresDBTestCase): def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: """We are using configure-missingness.""" - return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) + return TestMissingnessCmd( + DbCmd.Settings(self.dsn, self.schema_name, config, self.metadata, None) + ) def test_set_missingness_to_sampled(self) -> None: """Test that we can set one table to sampled missingness.""" @@ -74,7 +77,9 @@ class ConfigureMissingnessTestsWithGeneration(GeneratesDBTestCase): schema_name = "public" def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: - return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) + return TestMissingnessCmd( + DbCmd.Settings(self.dsn, self.schema_name, config, self.metadata, None) + ) def test_create_with_missingness(self) -> None: """Test that we can sample real missingness and reproduce it.""" diff --git a/tests/test_interactive_table.py b/tests/test_interactive_table.py index 1ab9f7ef..34b5da5d 100644 --- a/tests/test_interactive_table.py +++ b/tests/test_interactive_table.py @@ -6,6 +6,7 @@ from sqlalchemy import select from datafaker.interactive import TableCmd +from datafaker.interactive.base import DbCmd from datafaker.serialize_metadata import dict_to_metadata from tests.utils import RequiresDBTestCase, TestDbCmdMixin @@ -18,7 +19,15 @@ class ConfigureTablesTests(RequiresDBTestCase): """Testing configure-tables.""" def _get_cmd(self, config: MutableMapping[str, Any]) -> TestTableCmd: - return TestTableCmd(self.dsn, self.schema_name, self.metadata, config) + return TestTableCmd( + DbCmd.Settings( + self.dsn, + self.schema_name, + config, + self.metadata, + None, + ) + ) class ConfigureTablesSrcTests(ConfigureTablesTests): @@ -384,7 +393,9 @@ def test_sanity_checks_warnings_only(self) -> None: }, }, } - with TestTableCmd(self.dsn, self.schema_name, self.metadata, config) as tc: + with TestTableCmd( + DbCmd.Settings(self.dsn, self.schema_name, config, self.metadata, None) + ) as tc: tc.do_next("manufacturer") tc.do_vocabulary("") tc.reset() @@ -426,7 +437,9 @@ def test_sanity_checks_errors_only(self) -> None: }, }, } - with TestTableCmd(self.dsn, self.schema_name, self.metadata, config) as tc: + with TestTableCmd( + DbCmd.Settings(self.dsn, self.schema_name, config, self.metadata, None) + ) as tc: tc.do_next("signature_model") tc.do_empty("") tc.reset() @@ -525,10 +538,13 @@ def test_repeated_field_does_not_throw_exception(self) -> None: Select with repeated fields (#70). """ with TestTableCmd( - src_dsn=self.dsn, - src_schema=self.schema_name, - metadata=self.metadata, - config={}, + DbCmd.Settings( + self.dsn, + self.schema_name, + config={}, + metadata=self.metadata, + parquet_dir=None, + ), print_tables=True, ) as tc: tc.reset() diff --git a/tests/test_main.py b/tests/test_main.py index 1171af87..50d86027 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,6 +1,7 @@ """Tests for the main module.""" import os from pathlib import Path +from typing import Any from unittest.mock import MagicMock, call, patch import yaml @@ -14,38 +15,17 @@ runner = CliRunner(mix_stderr=False) -class TestCLI(DatafakerTestCase): +class TestCliGeneratorOutput(DatafakerTestCase): """Tests for the command-line interface.""" - @patch("datafaker.main.read_config_file") - @patch("datafaker.main.dict_to_metadata") - @patch("datafaker.main.load_metadata_config") - @patch("datafaker.main.create_db_vocab") - def test_create_vocab( - self, - mock_create: MagicMock, - mock_mdict: MagicMock, - mock_meta: MagicMock, - mock_config: MagicMock, - ) -> None: - """Test the create-vocab sub-command.""" - result = runner.invoke( - app, - [ - "create-vocab", - ], - catch_exceptions=False, - ) - - mock_create.assert_called_once_with( - mock_meta.return_value, mock_mdict.return_value, mock_config.return_value - ) - self.assertSuccess(result) + use_temporary_cwd = True + example_conf = "example_config.yaml" + copy_files = [example_conf, "orm.yaml"] + copy_from_directory = Path("examples") @patch("datafaker.main.read_config_file") @patch("datafaker.main.load_metadata_for_output") @patch("datafaker.settings.get_settings") - @patch("datafaker.main.Path") @patch("datafaker.main.make_table_generators") @patch("datafaker.main.generators_require_stats") # pylint: disable=too-many-positional-arguments,too-many-arguments @@ -53,14 +33,12 @@ def test_create_generators( self, mock_require_stats: MagicMock, mock_make: MagicMock, - mock_path: MagicMock, mock_settings: MagicMock, mock_load_meta: MagicMock, mock_config: MagicMock, ) -> None: """Test the create-generators sub-command.""" mock_require_stats.return_value = False - mock_path.return_value.exists.return_value = False mock_make.return_value = "some text" mock_settings.return_value.src_postges_dsn = "" @@ -68,26 +46,26 @@ def test_create_generators( app, [ "create-generators", + "--config-file", + self.example_conf, ], catch_exceptions=False, ) + self.assertSuccess(result) mock_make.assert_called_once_with( mock_load_meta.return_value, mock_config.return_value, - "orm.yaml", - "config.yaml", + Path("orm.yaml"), + Path(self.example_conf), None, ) - mock_path.return_value.write_text.assert_called_once_with( - "some text", encoding="utf-8" - ) - self.assertSuccess(result) + with Path("df.py").open(encoding="utf-8") as dfh: + self.assertEqual(dfh.read(), "some text") @patch("datafaker.main.read_config_file") @patch("datafaker.main.load_metadata_for_output") @patch("datafaker.settings.get_settings") - @patch("datafaker.main.Path") @patch("datafaker.main.make_table_generators") @patch("datafaker.main.generators_require_stats") # pylint: disable=too-many-positional-arguments,too-many-arguments @@ -95,14 +73,12 @@ def test_create_generators_uses_default_stats_file_if_necessary( self, mock_require_stats: MagicMock, mock_make: MagicMock, - mock_path: MagicMock, mock_settings: MagicMock, mock_load_meta: MagicMock, mock_config: MagicMock, ) -> None: """Test the create-generators sub-command.""" mock_require_stats.return_value = True - mock_path.return_value.exists.return_value = False mock_make.return_value = "some text" mock_settings.return_value.src_postges_dsn = "" @@ -110,6 +86,8 @@ def test_create_generators_uses_default_stats_file_if_necessary( app, [ "create-generators", + "--config-file", + self.example_conf, ], catch_exceptions=False, ) @@ -117,54 +95,58 @@ def test_create_generators_uses_default_stats_file_if_necessary( mock_make.assert_called_once_with( mock_load_meta.return_value, mock_config.return_value, - "orm.yaml", - "config.yaml", - "src-stats.yaml", - ) - mock_path.return_value.write_text.assert_called_once_with( - "some text", encoding="utf-8" + Path("orm.yaml"), + Path(self.example_conf), + Path("src-stats.yaml"), ) self.assertSuccess(result) + with Path("df.py").open(encoding="utf-8") as dfh: + self.assertEqual(dfh.read(), "some text") - @patch("datafaker.main.Path") @patch("datafaker.main.logger") def test_create_generators_errors_if_file_exists( - self, mock_logger: MagicMock, mock_path: MagicMock + self, + mock_logger: MagicMock, ) -> None: """Test the create-generators sub-command doesn't overwrite.""" + df_path = Path("df.py") - mock_path.return_value.exists.return_value = True - mock_path.return_value.__str__.return_value = "df.py" + with df_path.open(mode="w", encoding="utf-8") as dfh: + dfh.write("already exists!\n") result = runner.invoke( app, [ "create-generators", + "--config-file", + self.example_conf, ], catch_exceptions=False, ) mock_logger.error.assert_called_once_with( - "%s should not already exist. Exiting...", mock_path.return_value + "%s should not already exist. Exiting...", + df_path, ) self.assertEqual(1, result.exit_code) + +class TestCLI(DatafakerTestCase): + """Tests for the command-line interface.""" + @patch("datafaker.main.read_config_file") @patch("datafaker.main.load_metadata_for_output") @patch("datafaker.settings.get_settings") - @patch("datafaker.main.Path") @patch("datafaker.main.make_table_generators") # pylint: disable=too-many-positional-arguments,too-many-arguments def test_create_generators_with_force_enabled( self, mock_make: MagicMock, - mock_path: MagicMock, mock_settings: MagicMock, mock_load_meta: MagicMock, mock_config: MagicMock, ) -> None: """Tests the create-generators sub-commands overwrite files when instructed.""" - mock_path.return_value.exists.return_value = True mock_make.return_value = "make result" mock_settings.return_value.src_postges_dsn = "" @@ -178,20 +160,15 @@ def test_create_generators_with_force_enabled( ], ) + self.assertSuccess(result) mock_make.assert_called_once_with( mock_load_meta.return_value, mock_config.return_value, - "orm.yaml", - "config.yaml", + Path("orm.yaml"), + Path("config.yaml"), None, ) - mock_path.return_value.write_text.assert_called_once_with( - "make result", encoding="utf-8" - ) - self.assertSuccess(result) - mock_make.reset_mock() - mock_path.reset_mock() @patch("datafaker.main.create_db_tables") @patch("datafaker.main.read_config_file") @@ -364,19 +341,134 @@ def test_make_tables_with_force_enabled( mock_make_tables.reset_mock() mock_path.reset_mock() - @patch("datafaker.main.Path") + def test_validate_config(self) -> None: + """Test the validate-config sub-command.""" + result = runner.invoke( + app, + ["validate-config", "tests/examples/example_config.yaml"], + catch_exceptions=False, + ) + + self.assertSuccess(result) + + def test_validate_config_invalid(self) -> None: + """Test the validate-config sub-command.""" + result = runner.invoke( + app, + ["validate-config", "tests/examples/invalid_config.yaml"], + catch_exceptions=False, + ) + + self.assertEqual(1, result.exit_code) + + @patch("datafaker.main.remove_db_data") + @patch("datafaker.main.read_config_file") + @patch("datafaker.main.load_metadata_for_output") + def test_remove_data( + self, + mock_meta: MagicMock, + mock_config: MagicMock, + mock_remove: MagicMock, + ) -> None: + """Test the remove-data command.""" + result = runner.invoke( + app, + ["remove-data", "--yes"], + catch_exceptions=False, + ) + self.assertEqual(0, result.exit_code) + mock_remove.assert_called_once_with( + mock_meta.return_value, mock_config.return_value + ) + + @patch("datafaker.main.read_config_file") + @patch("datafaker.main.remove_db_vocab") + @patch("datafaker.main.load_metadata_config") + @patch("datafaker.main.dict_to_metadata") + def test_remove_vocab( + self, + mock_d2m: MagicMock, + mock_load_metadata: MagicMock, + mock_remove: MagicMock, + mock_read_config: MagicMock, + ) -> None: + """Test the remove-vocab command.""" + result = runner.invoke( + app, + ["remove-vocab", "--yes"], + catch_exceptions=False, + ) + self.assertEqual(0, result.exit_code) + mock_read_config.assert_called_once_with(Path("config.yaml")) + mock_remove.assert_called_once_with( + mock_d2m.return_value, + mock_load_metadata.return_value, + mock_read_config.return_value, + ) + + @patch("datafaker.main.remove_db_tables") + @patch("datafaker.main.load_metadata_for_output") + @patch("datafaker.main.read_config_file") + def test_remove_tables( + self, _: MagicMock, mock_meta: MagicMock, mock_remove: MagicMock + ) -> None: + """Test the remove-tables command.""" + result = runner.invoke( + app, + ["remove-tables", "--yes"], + catch_exceptions=False, + ) + self.assertEqual(0, result.exit_code) + mock_remove.assert_called_once_with(mock_meta.return_value) + + +class TestCliOutput(DatafakerTestCase): + """Test CLI commands that output files.""" + + use_temporary_cwd = True + example_conf = "example_config.yaml" + copy_files = [example_conf, "orm.yaml"] + copy_from_directory = Path("examples") + + def load_yaml(self, file_name: str | Path) -> Any: + """Load the YAML and return it as a dict.""" + with open(file_name, "r", encoding="utf-8") as fh: + return yaml.load(fh, yaml.SafeLoader) + + @patch("datafaker.main.read_config_file") + @patch("datafaker.main.dict_to_metadata") + @patch("datafaker.main.load_metadata_config") + @patch("datafaker.main.create_db_vocab") + def test_create_vocab( + self, + mock_create: MagicMock, + mock_mdict: MagicMock, + mock_meta: MagicMock, + mock_config: MagicMock, + ) -> None: + """Test the create-vocab sub-command.""" + result = runner.invoke( + app, + [ + "create-vocab", + ], + catch_exceptions=False, + ) + + mock_create.assert_called_once_with( + mock_meta.return_value, mock_mdict.return_value, mock_config.return_value + ) + self.assertSuccess(result) + @patch("datafaker.main.make_src_stats") @patch("datafaker.settings.get_settings") def test_make_stats( self, mock_get_settings: MagicMock, mock_make: MagicMock, - mock_path: MagicMock, ) -> None: """Test the make-stats sub-command.""" - example_conf_path = "tests/examples/example_config.yaml" output_path = Path("make_stats_output.yaml") - mock_path.return_value.exists.return_value = False mock_make.return_value = {"a": 1} mock_get_settings.return_value = get_test_settings() result = runner.invoke( @@ -384,47 +476,45 @@ def test_make_stats( [ "make-stats", f"--stats-file={output_path}", - f"--config-file={example_conf_path}", + f"--config-file={self.example_conf}", ], catch_exceptions=False, ) self.assertSuccess(result) - with open(example_conf_path, "r", encoding="utf8") as f: - config = yaml.safe_load(f) - mock_make.assert_called_once_with(get_test_settings().src_dsn, config, None) - mock_path.return_value.write_text.assert_called_once_with( - "a: 1\n", encoding="utf-8" + config = self.load_yaml(self.example_conf) + mock_make.assert_called_once_with( + get_test_settings().src_dsn, config, None, parquet_dir=None ) + output = self.load_yaml(output_path) + self.assertDictEqual(output, {"a": 1}) - @patch("datafaker.main.Path") @patch("datafaker.main.logger") def test_make_stats_errors_if_file_exists( - self, mock_logger: MagicMock, mock_path: MagicMock + self, + mock_logger: MagicMock, ) -> None: """Test the make-stats sub-command when the stats file already exists.""" - mock_path.return_value.exists.return_value = True - example_conf_path = "tests/examples/example_config.yaml" output_path = "make_stats_output.yaml" - mock_path.return_value.__str__.return_value = output_path + with open(output_path, "w", encoding="utf-8") as fh: + fh.write("some content\n") result = runner.invoke( app, [ "make-stats", f"--stats-file={output_path}", - f"--config-file={example_conf_path}", + f"--config-file={self.example_conf}", ], catch_exceptions=False, ) mock_logger.error.assert_called_once_with( - "%s should not already exist. Exiting...", mock_path.return_value + "%s should not already exist. Exiting...", Path(output_path) ) self.assertEqual(1, result.exit_code) @patch.dict(os.environ, {"SRC_SCHEMA": "myschema"}, clear=True) def test_make_stats_errors_if_no_src_dsn(self) -> None: """Test the make-stats sub-command with missing settings.""" - example_conf_path = "tests/examples/example_config.yaml" self.assertRaises( SettingsError, @@ -432,27 +522,23 @@ def test_make_stats_errors_if_no_src_dsn(self) -> None: app, [ "make-stats", - f"--config-file={example_conf_path}", - "--stats-file=tests/examples/does-not-exist.yaml", + f"--config-file={self.example_conf}", + "--stats-file=does-not-exist.yaml", ], catch_exceptions=False, ) - @patch("datafaker.main.Path") @patch("datafaker.main.make_src_stats") @patch("datafaker.settings.get_settings") def test_make_stats_with_force_enabled( self, mock_get_settings: MagicMock, mock_make_src_stats: MagicMock, - mock_path: MagicMock, ) -> None: """Tests that the make-stats command overwrite files when instructed.""" - test_config_file: str = "tests/examples/example_config.yaml" - with open(test_config_file, "r", encoding="utf8") as f: + with open(self.example_conf, "r", encoding="utf8") as f: config_file_content: dict = yaml.safe_load(f) - mock_path.return_value.exists.return_value = True test_settings: Settings = get_test_settings() mock_get_settings.return_value = test_settings make_test_output: dict = {"some_stat": 0} @@ -465,7 +551,7 @@ def test_make_stats_with_force_enabled( [ "make-stats", "--stats-file=stats_file.yaml", - f"--config-file={test_config_file}", + f"--config-file={self.example_conf}", force_option, ], ) @@ -474,91 +560,10 @@ def test_make_stats_with_force_enabled( test_settings.src_dsn, config_file_content, test_settings.src_schema, + parquet_dir=None, ) - mock_path.return_value.write_text.assert_called_once_with( - "some_stat: 0\n", encoding="utf-8" + mock_make_src_stats.reset_mock() + self.assertDictEqual( + self.load_yaml("stats_file.yaml"), {"some_stat": 0} ) self.assertSuccess(result) - - mock_make_src_stats.reset_mock() - mock_path.reset_mock() - - def test_validate_config(self) -> None: - """Test the validate-config sub-command.""" - result = runner.invoke( - app, - ["validate-config", "tests/examples/example_config.yaml"], - catch_exceptions=False, - ) - - self.assertSuccess(result) - - def test_validate_config_invalid(self) -> None: - """Test the validate-config sub-command.""" - result = runner.invoke( - app, - ["validate-config", "tests/examples/invalid_config.yaml"], - catch_exceptions=False, - ) - - self.assertEqual(1, result.exit_code) - - @patch("datafaker.main.remove_db_data") - @patch("datafaker.main.read_config_file") - @patch("datafaker.main.load_metadata_for_output") - def test_remove_data( - self, - mock_meta: MagicMock, - mock_config: MagicMock, - mock_remove: MagicMock, - ) -> None: - """Test the remove-data command.""" - result = runner.invoke( - app, - ["remove-data", "--yes"], - catch_exceptions=False, - ) - self.assertEqual(0, result.exit_code) - mock_remove.assert_called_once_with( - mock_meta.return_value, mock_config.return_value - ) - - @patch("datafaker.main.read_config_file") - @patch("datafaker.main.remove_db_vocab") - @patch("datafaker.main.load_metadata_config") - @patch("datafaker.main.dict_to_metadata") - def test_remove_vocab( - self, - mock_d2m: MagicMock, - mock_load_metadata: MagicMock, - mock_remove: MagicMock, - mock_read_config: MagicMock, - ) -> None: - """Test the remove-vocab command.""" - result = runner.invoke( - app, - ["remove-vocab", "--yes"], - catch_exceptions=False, - ) - self.assertEqual(0, result.exit_code) - mock_read_config.assert_called_once_with("config.yaml") - mock_remove.assert_called_once_with( - mock_d2m.return_value, - mock_load_metadata.return_value, - mock_read_config.return_value, - ) - - @patch("datafaker.main.remove_db_tables") - @patch("datafaker.main.load_metadata_for_output") - @patch("datafaker.main.read_config_file") - def test_remove_tables( - self, _: MagicMock, mock_meta: MagicMock, mock_remove: MagicMock - ) -> None: - """Test the remove-tables command.""" - result = runner.invoke( - app, - ["remove-tables", "--yes"], - catch_exceptions=False, - ) - self.assertEqual(0, result.exit_code) - mock_remove.assert_called_once_with(mock_meta.return_value) diff --git a/tests/test_make.py b/tests/test_make.py index b890e4fa..601aae47 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -1,16 +1,19 @@ """Tests for the main module.""" import asyncio import os +import tempfile from pathlib import Path +from typing import Any from unittest.mock import MagicMock, patch +import pandas as pd import yaml from sqlalchemy import BigInteger, Column, String, select from sqlalchemy.dialects.mysql.types import INTEGER from sqlalchemy.dialects.postgresql import UUID from datafaker.make import _get_provider_for_column, make_src_stats -from tests.utils import GeneratesDBTestCase, RequiresDBTestCase +from tests.utils import DatafakerTestCase, GeneratesDBTestCase, RequiresDBTestCase class TestMakeGenerators(GeneratesDBTestCase): @@ -227,3 +230,55 @@ def test_make_stats_empty_result(self, mock_logger: MagicMock) -> None: debug_template = "src-stats query %s returned no results" mock_logger.debug.assert_any_call(debug_template, query_name1) mock_logger.debug.assert_any_call(debug_template, query_name2) + + +class TestMakeStatsParquet(DatafakerTestCase): + """ + Output to the database should not have access to parquet files. + + Otherwise there is a risk of leakage of source data. + """ + + parquet_name = "fruit.parquet" + + def setUp(self) -> None: + """Go to the directory where there are parquet files.""" + super().setUp() + self.parquet_dir = Path(tempfile.mkdtemp("parq")) + self.write_parquet() + + def write_parquet(self) -> None: + """Write a parquet file into the current directory.""" + fruit: dict[str, list[Any]] = { + "id": [1, 2, 3], + "one": ["lemon", "orange", "lime"], + "two": ["grape", "fig", "melon"], + } + pd.DataFrame.from_dict(fruit).to_parquet( + Path(self.parquet_dir) / self.parquet_name + ) + + def test_make_stats_parquet(self) -> None: + """Test that make stats can access parquet if we want it to.""" + src_stats = asyncio.get_event_loop().run_until_complete( + make_src_stats( + "duckdb:///:memory:", + { + "src-stats": [ + {"name": "one_query", "query": "SELECT one FROM fruit.parquet"}, + {"name": "two_query", "query": "SELECT two FROM fruit.parquet"}, + ] + }, + parquet_dir=self.parquet_dir, + ) + ) + self.assertIn("one_query", src_stats) + self.assertSetEqual( + {v.get("one") for v in src_stats["one_query"]["results"]}, + {"lemon", "orange", "lime"}, + ) + self.assertIn("two_query", src_stats) + self.assertSetEqual( + {v.get("two") for v in src_stats["two_query"]["results"]}, + {"grape", "fig", "melon"}, + ) diff --git a/tests/test_noninteractive_generators.py b/tests/test_noninteractive_generators.py index 93431147..d3d18c02 100644 --- a/tests/test_noninteractive_generators.py +++ b/tests/test_noninteractive_generators.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, Mock, patch from datafaker.interactive import update_config_generators +from datafaker.interactive.base import DbCmd from tests.utils import RequiresDBTestCase @@ -49,7 +50,14 @@ def test_non_interactive_configure_generators( config: MutableMapping[str, Any] = {} spec_csv = Mock(return_value="mock spec.csv file") update_config_generators( - self.dsn, self.schema_name, self.metadata, config, spec_csv + DbCmd.Settings( + self.dsn, + self.schema_name, + config, + self.metadata, + None, + ), + spec_csv, ) row_gens = { f"{table}{sorted(rg['columns_assigned'])}": rg["name"] @@ -87,7 +95,14 @@ def test_non_interactive_configure_null_partitioned( config: MutableMapping[str, Any] = {} spec_csv = Mock(return_value="mock spec.csv file") update_config_generators( - self.dsn, self.schema_name, self.metadata, config, spec_csv + DbCmd.Settings( + self.dsn, + self.schema_name, + config, + self.metadata, + None, + ), + spec_csv, ) row_gens = { f"{table}{sorted(rg['columns_assigned'])}": rg @@ -152,7 +167,14 @@ def test_non_interactive_configure_null_partitioned_where_existing_merges( } spec_csv = Mock(return_value="mock spec.csv file") update_config_generators( - self.dsn, self.schema_name, self.metadata, config, spec_csv + DbCmd.Settings( + self.dsn, + self.schema_name, + config, + self.metadata, + None, + ), + spec_csv, ) row_gens: Mapping[str, Any] = { f"{table}{sorted(rg['columns_assigned'])}": rg diff --git a/tests/test_remove.py b/tests/test_remove.py index f0213b76..abbcefb5 100644 --- a/tests/test_remove.py +++ b/tests/test_remove.py @@ -54,7 +54,12 @@ def test_remove_vocab(self, mock_get_settings: MagicMock) -> None: src_dsn=self.dsn, dst_dsn=self.dsn, ) - meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.sync_engine) + meta_dict = metadata_to_dict( + self.metadata, + self.schema_name, + self.sync_engine, + parquet_dir=None, + ) config = { "tables": { "manufacturer": {"vocabulary_table": True}, diff --git a/tests/test_utils.py b/tests/test_utils.py index ac82d124..9f0ab9b0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -111,7 +111,7 @@ class TestReadConfig(DatafakerTestCase): def test_warns_of_invalid_config(self) -> None: """Test that we get a warning if the config is invalid.""" with patch("datafaker.utils.logger") as mock_logger: - read_config_file("tests/examples/invalid_config.yaml") + read_config_file(Path("tests/examples/invalid_config.yaml")) mock_logger.error.assert_called_with( "The config file is invalid: %s", "'a' is not of type 'integer'" ) diff --git a/tests/utils.py b/tests/utils.py index c5fef7f8..791ffedd 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,11 +5,13 @@ import re import shutil import string +import sys import time import traceback from abc import ABC, abstractmethod from collections.abc import MutableSequence, Sequence from functools import lru_cache +from importlib import resources from pathlib import Path from subprocess import run from tempfile import mkdtemp, mkstemp @@ -234,9 +236,29 @@ class DatafakerTestCase(TestCase): examples_dir = Path("tests/examples") dump_file_path: str | None = None database_name: str | None = None + use_temporary_cwd = False + copy_files: list[str] = [] + copy_from_directory: Path = Path(".") def setUp(self) -> None: + super().setUp() settings.get_settings.cache_clear() + if self.use_temporary_cwd: + self.start_dir = os.getcwd() + self.working_dir = mkdtemp("test") + from_dir = resources.files(sys.modules["tests"]) / str( + self.copy_from_directory + ) + for cf in self.copy_files: + with resources.as_file(from_dir / cf) as cff: + shutil.copy(cff, self.working_dir) + os.chdir(self.working_dir) + + def tearDown(self) -> None: + if self.use_temporary_cwd: + os.chdir(self.start_dir) + shutil.rmtree(self.working_dir) + super().tearDown() def assertReturnCode( # pylint: disable=invalid-name self, result: Any, expected_code: int @@ -448,9 +470,9 @@ def create_generators(self, config: Mapping[str, Any]) -> None: datafaker_content = make_table_generators( self.metadata, config, - self.orm_file_path, - self.config_file_path, - self.stats_file_path, + Path(self.orm_file_path), + Path(self.config_file_path), + Path(self.stats_file_path), ) (generators_fd, self.generators_file_path) = mkstemp(".py", "dfgen_", text=True) with os.fdopen(generators_fd, "w", encoding="utf-8") as datafaker_fh: