diff --git a/integration/cypher_template_test.go b/integration/cypher_template_test.go index 204b29f..6a616c5 100644 --- a/integration/cypher_template_test.go +++ b/integration/cypher_template_test.go @@ -21,7 +21,6 @@ package integration import ( "context" "encoding/json" - "errors" "fmt" "os" "path/filepath" @@ -72,7 +71,7 @@ func TestCypherTemplates(t *testing.T) { templateFiles := loadCypherTemplateFiles(t) nodeKinds, edgeKinds := cypherTemplateKinds(templateFiles) - db, ctx := SetupDBWithKindsNoGraphCleanup(t, nodeKinds, edgeKinds) + db, ctx := SetupDBWithKindsNoGraphCleanup(t, 0, nodeKinds, edgeKinds) ClearGraph(t, db, ctx) for _, templateFile := range templateFiles { @@ -200,30 +199,24 @@ func runWithTemplateFixture(t *testing.T, ctx context.Context, db graph.Database t.Fatal("template cases must define an inline fixture") } - var ( - queryErrorObserved = false - err = db.WriteTransaction(ctx, func(tx graph.Transaction) error { - idMap, err := opengraph.WriteGraphTx(tx, tc.Fixture) - if err != nil { - return fmt.Errorf("creating fixture: %w", err) - } - - result := tx.Query(tc.Cypher, tc.Params) - defer result.Close() - assertion.checkResult(t, result, newAssertionContext(idMap)) - if assertion.expectQueryError { - queryErrorObserved = true - } + queryErrorObserved := false + session := &Session{DB: db, Ctx: ctx} + err := session.WithRollbackFixture(t, tc.Fixture, false, func(tx graph.Transaction, idMap opengraph.IDMap) error { + result := tx.Query(tc.Cypher, tc.Params) + defer result.Close() + assertion.checkResult(t, result, newAssertionContext(idMap)) + if assertion.expectQueryError { + queryErrorObserved = true + } - return errFixtureRollback - }) - ) + return nil + }) if assertion.expectQueryError && queryErrorObserved && err != nil { return } - if !errors.Is(err, errFixtureRollback) { + if err != nil { t.Fatalf("unexpected transaction error: %v", err) } } @@ -239,17 +232,11 @@ func runMetamorphicFamily(t *testing.T, ctx context.Context, db graph.Database, t.Fatal("metamorphic cases must define at least two queries") } - err := db.WriteTransaction(ctx, func(tx graph.Transaction) error { - idMap, err := opengraph.WriteGraphTx(tx, family.Fixture) - if err != nil { - return fmt.Errorf("creating fixture: %w", err) - } - - var ( - assertCtx = newAssertionContext(idMap) - baselineName string - baseline []string - ) + session := &Session{DB: db, Ctx: ctx} + err := session.WithRollbackFixture(t, family.Fixture, false, func(tx graph.Transaction, idMap opengraph.IDMap) error { + assertCtx := newAssertionContext(idMap) + var baselineName string + var baseline []string for _, query := range family.Queries { var ( @@ -282,10 +269,10 @@ func runMetamorphicFamily(t *testing.T, ctx context.Context, db graph.Database, t.Fatal("all metamorphic queries were skipped") } - return errFixtureRollback + return nil }) - if !errors.Is(err, errFixtureRollback) { + if err != nil { t.Fatalf("unexpected transaction error: %v", err) } } @@ -333,7 +320,8 @@ func comparisonModeSignature(t *testing.T, result queryResult, ctx assertionCont var signature []string switch mode { case "row_count": - signature = []string{fmt.Sprintf("%d", len(result.rows))} + row := fmt.Sprintf("%d", len(result.rows)) + signature = []string{row} case "scalar_values": signature = sortedSignatures(firstScalarSignatures(t, result)) case "ordered_scalar_values": diff --git a/integration/cypher_test.go b/integration/cypher_test.go index 085f81d..91231a9 100644 --- a/integration/cypher_test.go +++ b/integration/cypher_test.go @@ -21,7 +21,6 @@ package integration import ( "context" "encoding/json" - "errors" "fmt" "math" "os" @@ -257,9 +256,6 @@ func parseAssertion(t *testing.T, raw json.RawMessage) caseAssertion { } } -// errFixtureRollback is returned to unconditionally roll back inline fixture data. -var errFixtureRollback = errors.New("fixture rollback") - // runReadOnly executes a test case against the pre-loaded dataset. func runReadOnly(t *testing.T, ctx context.Context, db graph.Database, idMap opengraph.IDMap, tc testCase, assertion caseAssertion) { t.Helper() @@ -291,34 +287,24 @@ func runReadOnly(t *testing.T, ctx context.Context, db graph.Database, idMap ope func runWithFixture(t *testing.T, ctx context.Context, db graph.Database, tc testCase, assertion caseAssertion) { t.Helper() - var ( - queryErrorObserved = false - err = db.WriteTransaction(ctx, func(tx graph.Transaction) error { - if err := tx.Nodes().Delete(); err != nil { - return fmt.Errorf("clearing graph before fixture: %w", err) - } - - idMap, err := opengraph.WriteGraphTx(tx, tc.Fixture) - if err != nil { - return fmt.Errorf("creating fixture: %w", err) - } - - result := tx.Query(tc.Cypher, tc.Params) - defer result.Close() - assertion.checkResult(t, result, newAssertionContext(idMap)) - if assertion.expectQueryError { - queryErrorObserved = true - } + queryErrorObserved := false + session := &Session{DB: db, Ctx: ctx} + err := session.WithRollbackFixture(t, tc.Fixture, true, func(tx graph.Transaction, idMap opengraph.IDMap) error { + result := tx.Query(tc.Cypher, tc.Params) + defer result.Close() + assertion.checkResult(t, result, newAssertionContext(idMap)) + if assertion.expectQueryError { + queryErrorObserved = true + } - return errFixtureRollback - }) - ) + return nil + }) if assertion.expectQueryError && queryErrorObserved && err != nil { return } - if !errors.Is(err, errFixtureRollback) { + if err != nil { t.Fatalf("unexpected transaction error: %v", err) } } diff --git a/integration/harness.go b/integration/harness.go index 5075b56..9d667d3 100644 --- a/integration/harness.go +++ b/integration/harness.go @@ -18,6 +18,7 @@ package integration import ( "context" + "errors" "flag" "fmt" "net/url" @@ -28,20 +29,52 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/specterops/dawgs" + "github.com/specterops/dawgs/drivers/neo4j" "github.com/specterops/dawgs/drivers/pg" "github.com/specterops/dawgs/graph" "github.com/specterops/dawgs/opengraph" "github.com/specterops/dawgs/util/size" - - "github.com/specterops/dawgs/drivers/neo4j" ) +const ConnectionStringEnv = "CONNECTION_STRING" + var ( - localDatasetFlag = flag.String("local-dataset", "", "name of a local dataset to test (e.g. local/phantom)") + localDatasetFlag = flag.String("local-dataset", "", "name of a local dataset to test (e.g. local/phantom)") + errFixtureRollback = errors.New("fixture rollback") ) -// driverFromConnStr returns the dawgs driver name based on the connection string scheme. -func driverFromConnStr(connStr string) (string, error) { +type CleanupMode int + +const ( + CleanupGraph CleanupMode = iota + CloseOnly +) + +type Options struct { + RequireDriver string + SkipIfNoConnection bool + SkipIfDriverMismatch bool + ConnectionStringEnvVar string + GraphName string + GraphQueryMemoryLimit size.Size + Schema *graph.Schema + ExtraNodeKinds graph.Kinds + ExtraEdgeKinds graph.Kinds + Datasets []string + DatasetPath func(name string) string + CleanupMode CleanupMode +} + +type Session struct { + ConnectionString string + Driver string + DB graph.Database + PGPool *pgxpool.Pool + Ctx context.Context +} + +// DriverFromConnStr returns the dawgs driver name based on the connection string scheme. +func DriverFromConnectionString(connStr string) (string, error) { u, err := url.Parse(connStr) if err != nil { return "", fmt.Errorf("failed to parse connection string: %w", err) @@ -57,52 +90,44 @@ func driverFromConnStr(connStr string) (string, error) { } } -// SetupDB opens a database connection for the selected driver, asserts a schema -// derived from the given datasets, and registers cleanup. Returns the database -// and a background context. -func SetupDB(t *testing.T, datasets ...string) (graph.Database, context.Context) { - t.Helper() - - return setupDB(t, true, nil, nil, datasets...) -} - -// SetupDBWithKinds opens a database connection like SetupDB, then extends the -// asserted schema with additional node and edge kinds. -func SetupDBWithKinds(t *testing.T, extraNodeKinds, extraEdgeKinds graph.Kinds, datasets ...string) (graph.Database, context.Context) { - t.Helper() - - return setupDB(t, true, extraNodeKinds, extraEdgeKinds, datasets...) -} - -// SetupDBWithKindsNoGraphCleanup opens a database connection like SetupDBWithKinds -// but only closes the connection during cleanup. Use this for rollback-only tests -// that must not clear a shared database. -func SetupDBWithKindsNoGraphCleanup(t *testing.T, extraNodeKinds, extraEdgeKinds graph.Kinds, datasets ...string) (graph.Database, context.Context) { - t.Helper() - - return setupDB(t, false, extraNodeKinds, extraEdgeKinds, datasets...) -} - -func setupDB(t *testing.T, cleanupGraph bool, extraNodeKinds, extraEdgeKinds graph.Kinds, datasets ...string) (graph.Database, context.Context) { +func Open(t *testing.T, opts Options) *Session { t.Helper() - var ( - ctx = context.Background() - connStr = os.Getenv("CONNECTION_STRING") - ) + ctx := context.Background() + connEnv := opts.ConnectionStringEnvVar + if connEnv == "" { + connEnv = ConnectionStringEnv + } + connStr := os.Getenv(connEnv) if connStr == "" { - t.Skip("CONNECTION_STRING env var is not set") + if opts.SkipIfNoConnection { + t.Skipf("%s env var is not set", connEnv) + } + t.Fatalf("%s env var is not set", connEnv) } - driver, err := driverFromConnStr(connStr) + driver, err := DriverFromConnectionString(connStr) if err != nil { - t.Fatalf("Failed to detect driver: %v", err) + t.Fatalf("failed to detect driver: %v", err) + } + + if opts.RequireDriver != "" && driver != opts.RequireDriver { + if opts.SkipIfDriverMismatch { + t.Skipf("%s is not a %s connection string", connEnv, opts.RequireDriver) + } + t.Fatalf("%s is not a %s connection string", connEnv, opts.RequireDriver) } cfg := dawgs.Config{ - GraphQueryMemoryLimit: size.Gibibyte, ConnectionString: connStr, + GraphQueryMemoryLimit: opts.graphQueryMemoryLimit(), + } + + session := &Session{ + ConnectionString: connStr, + Driver: driver, + Ctx: ctx, } if driver == pg.DriverName { @@ -112,35 +137,30 @@ func setupDB(t *testing.T, cleanupGraph bool, extraNodeKinds, extraEdgeKinds gra } pool, err := pg.NewPool(poolCfg) if err != nil { - t.Fatalf("Failed to create PG pool: %v", err) + t.Fatalf("failed to create PG pool: %v", err) } cfg.Pool = pool + session.PGPool = pool } db, err := dawgs.Open(ctx, driver, cfg) if err != nil { - t.Fatalf("Failed to open database: %v", err) + t.Fatalf("failed to open database: %v", err) } + session.DB = db - nodeKinds, edgeKinds := collectKinds(t, datasets) - nodeKinds = nodeKinds.Add(extraNodeKinds...) - edgeKinds = edgeKinds.Add(extraEdgeKinds...) - - schema := graph.Schema{ - Graphs: []graph.Graph{{ - Name: "integration_test", - Nodes: nodeKinds, - Edges: edgeKinds, - }}, - DefaultGraph: graph.Graph{Name: "integration_test"}, + schema := opts.Schema + if schema == nil { + schema = buildSchema(t, opts) } - - if err := db.AssertSchema(ctx, schema); err != nil { - t.Fatalf("Failed to assert schema: %v", err) + if schema != nil { + if err := db.AssertSchema(ctx, *schema); err != nil { + t.Fatalf("failed to assert schema: %v", err) + } } t.Cleanup(func() { - if cleanupGraph { + if opts.CleanupMode != CloseOnly { _ = db.WriteTransaction(ctx, func(tx graph.Transaction) error { return tx.Nodes().Delete() }) @@ -148,11 +168,109 @@ func setupDB(t *testing.T, cleanupGraph bool, extraNodeKinds, extraEdgeKinds gra db.Close(ctx) }) - return db, ctx + return session +} + +func (s *Session) ClearGraph(t *testing.T) { + t.Helper() + + if err := s.DB.WriteTransaction(s.Ctx, func(tx graph.Transaction) error { + return tx.Nodes().Delete() + }); err != nil { + t.Fatalf("failed to clear graph: %v", err) + } +} + +func (s *Session) LoadDataset(t *testing.T, path string) opengraph.IDMap { + t.Helper() + + f, err := os.Open(path) + if err != nil { + t.Fatalf("failed to open dataset %q: %v", path, err) + } + defer f.Close() + + idMap, err := opengraph.Load(s.Ctx, s.DB, f) + if err != nil { + t.Fatalf("failed to load dataset %q: %v", path, err) + } + + return idMap +} + +func (s *Session) WithRollbackFixture(t *testing.T, fixture *opengraph.Graph, clearGraph bool, delegate func(tx graph.Transaction, idMap opengraph.IDMap) error) error { + t.Helper() + + return s.withRollback(t, func(tx graph.Transaction) error { + if clearGraph { + if err := tx.Nodes().Delete(); err != nil { + return fmt.Errorf("clearing graph before fixture: %w", err) + } + } + + idMap, err := opengraph.WriteGraphTx(tx, fixture) + if err != nil { + return fmt.Errorf("creating fixture: %w", err) + } + + if delegate != nil { + return delegate(tx, idMap) + } + + return nil + }) +} + +func (s *Session) WithRollback(t *testing.T, delegate func(tx graph.Transaction) error) error { + t.Helper() + return s.withRollback(t, delegate) +} + +func (s *Session) withRollback(t *testing.T, delegate func(tx graph.Transaction) error) error { + t.Helper() + + err := s.DB.WriteTransaction(s.Ctx, func(tx graph.Transaction) error { + if err := delegate(tx); err != nil { + return err + } + + return errFixtureRollback + }) + if errors.Is(err, errFixtureRollback) { + return nil + } + + return err +} + +func buildSchema(t *testing.T, opts Options) *graph.Schema { + t.Helper() + + nodeKinds, edgeKinds := collectKinds(t, opts.Datasets, opts.datasetPath()) + nodeKinds = nodeKinds.Add(opts.ExtraNodeKinds...) + edgeKinds = edgeKinds.Add(opts.ExtraEdgeKinds...) + + if len(nodeKinds) == 0 && len(edgeKinds) == 0 { + return nil + } + + graphName := opts.GraphName + if graphName == "" { + graphName = "integration_test" + } + + return &graph.Schema{ + Graphs: []graph.Graph{{ + Name: graphName, + Nodes: nodeKinds, + Edges: edgeKinds, + }}, + DefaultGraph: graph.Graph{Name: graphName}, + } } // collectKinds parses the given datasets and returns the union of all node and edge kinds. -func collectKinds(t *testing.T, datasets []string) (graph.Kinds, graph.Kinds) { +func collectKinds(t *testing.T, datasets []string, datasetPath func(name string) string) (graph.Kinds, graph.Kinds) { t.Helper() var nodeKinds, edgeKinds graph.Kinds @@ -177,6 +295,74 @@ func collectKinds(t *testing.T, datasets []string) (graph.Kinds, graph.Kinds) { return nodeKinds, edgeKinds } +func (s *Options) datasetPath() func(name string) string { + if s.DatasetPath != nil { + return s.DatasetPath + } + + return func(name string) string { + return "testdata/" + name + ".json" + } +} + +func (s Options) graphQueryMemoryLimit() size.Size { + if s.GraphQueryMemoryLimit == 0 { + return size.Gibibyte + } + + return s.GraphQueryMemoryLimit +} + +// SetupDB opens a database connection for the selected driver, asserts a schema +// derived from the given datasets, and registers cleanup. Returns the database +// and a background context. +func SetupDB(t *testing.T, datasets ...string) (graph.Database, context.Context) { + t.Helper() + + session := Open(t, Options{ + CleanupMode: 0, + Datasets: datasets, + ExtraNodeKinds: nil, + ExtraEdgeKinds: nil, + DatasetPath: datasetPath, + }) + + return session.DB, session.Ctx +} + +// SetupDBWithKinds opens a database connection like SetupDB, then extends the +// asserted schema with additional node and edge kinds. +func SetupDBWithKinds(t *testing.T, cleanupMode CleanupMode, extraNodeKinds, extraEdgeKinds graph.Kinds, datasets ...string) (graph.Database, context.Context) { + t.Helper() + + session := Open(t, Options{ + CleanupMode: cleanupMode, + Datasets: datasets, + ExtraNodeKinds: extraNodeKinds, + ExtraEdgeKinds: extraEdgeKinds, + DatasetPath: datasetPath, + }) + + return session.DB, session.Ctx +} + +// SetupDBWithKindsNoGraphCleanup opens a database connection like SetupDBWithKinds +// but only closes the connection during cleanup. Use this for rollback-only tests +// that must not clear a shared database. +func SetupDBWithKindsNoGraphCleanup(t *testing.T, cleanupMode CleanupMode, extraNodeKinds, extraEdgeKinds graph.Kinds, datasets ...string) (graph.Database, context.Context) { + t.Helper() + + session := Open(t, Options{ + CleanupMode: cleanupMode, + Datasets: datasets, + ExtraNodeKinds: extraNodeKinds, + ExtraEdgeKinds: extraEdgeKinds, + DatasetPath: datasetPath, + }) + + return session.DB, session.Ctx +} + // ClearGraph deletes all nodes (and cascading edges) from the database. func ClearGraph(t *testing.T, db graph.Database, ctx context.Context) { t.Helper() diff --git a/integration/pgsql_aggregate_traversal_plan_test.go b/integration/pgsql_aggregate_traversal_plan_test.go index 3706415..4fd25df 100644 --- a/integration/pgsql_aggregate_traversal_plan_test.go +++ b/integration/pgsql_aggregate_traversal_plan_test.go @@ -79,7 +79,7 @@ func TestPostgreSQLLiveAggregateTraversalCountPlanShape(t *testing.T) { t.Skip("CONNECTION_STRING env var is not set") } - driver, err := driverFromConnStr(connStr) + driver, err := DriverFromConnectionString(connStr) if err != nil { t.Fatalf("failed to detect driver: %v", err) } diff --git a/integration/pgsql_count_fast_path_test.go b/integration/pgsql_count_fast_path_test.go index a29a961..08cc355 100644 --- a/integration/pgsql_count_fast_path_test.go +++ b/integration/pgsql_count_fast_path_test.go @@ -33,7 +33,7 @@ func TestPostgreSQLCountStoreFastPathRequiresRelationshipEndpoints(t *testing.T) t.Skip("CONNECTION_STRING env var is not set") } - driver, err := driverFromConnStr(connStr) + driver, err := DriverFromConnectionString(connStr) if err != nil { t.Fatalf("failed to detect driver: %v", err) } @@ -44,7 +44,7 @@ func TestPostgreSQLCountStoreFastPathRequiresRelationshipEndpoints(t *testing.T) var ( nodeKind = graph.StringKind("CountFastPathNode") edgeKind = graph.StringKind("CountFastPathEdge") - db, ctx = SetupDBWithKinds(t, graph.Kinds{nodeKind}, graph.Kinds{edgeKind}) + db, ctx = SetupDBWithKinds(t, 0, graph.Kinds{nodeKind}, graph.Kinds{edgeKind}) ) if err := db.WriteTransaction(ctx, func(tx graph.Transaction) error { diff --git a/integration/pgsql_property_equality_test.go b/integration/pgsql_property_equality_test.go index 5db0293..662e8db 100644 --- a/integration/pgsql_property_equality_test.go +++ b/integration/pgsql_property_equality_test.go @@ -58,7 +58,7 @@ func TestPostgreSQLPropertyTextEqualityCompatibility(t *testing.T) { t.Skip("CONNECTION_STRING env var is not set") } - driver, err := driverFromConnStr(connStr) + driver, err := DriverFromConnectionString(connStr) if err != nil { t.Fatalf("failed to detect driver: %v", err) } @@ -70,7 +70,7 @@ func TestPostgreSQLPropertyTextEqualityCompatibility(t *testing.T) { userKind = graph.StringKind("User") groupKind = graph.StringKind("Group") memberOf = graph.StringKind("MemberOf") - db, ctx = SetupDBWithKinds(t, graph.Kinds{userKind, groupKind}, graph.Kinds{memberOf}) + db, ctx = SetupDBWithKinds(t, 0, graph.Kinds{userKind, groupKind}, graph.Kinds{memberOf}) boolTrue *graph.Relationship boolFalse *graph.Relationship stringTrue *graph.Relationship @@ -205,7 +205,7 @@ func TestPostgreSQLLiveObjectIDEqualityPlanUsesTextExpressionIndex(t *testing.T) t.Skip("CONNECTION_STRING env var is not set") } - driver, err := driverFromConnStr(connStr) + driver, err := DriverFromConnectionString(connStr) if err != nil { t.Fatalf("failed to detect driver: %v", err) } diff --git a/integration/pgsql_property_index_plan_test.go b/integration/pgsql_property_index_plan_test.go index 0921e21..fc96ba2 100644 --- a/integration/pgsql_property_index_plan_test.go +++ b/integration/pgsql_property_index_plan_test.go @@ -171,7 +171,7 @@ func setupIndexedPostgresDB(t *testing.T, graphName string, nodeIndexes []graph. t.Skip("CONNECTION_STRING env var is not set") } - driverName, err := driverFromConnStr(connStr) + driverName, err := DriverFromConnectionString(connStr) if err != nil { t.Fatalf("failed to detect driver: %v", err) }