diff --git a/GEMINI.md b/GEMINI.md index 5784027e..391a3b14 100644 --- a/GEMINI.md +++ b/GEMINI.md @@ -60,3 +60,16 @@ Any pull request modifying or extending the driver's features must include: - **Mock Server Tests**: Located in files such as `driver_with_mockserver_test.go`, `conn_with_mockserver_test.go`, and `stmt_with_mockserver_test.go`. Use these (or add new test files) to mock Spanner gRPC API responses (e.g. BeginTransaction, Commit, ExecuteSql) and verify that the driver translates options, tags, and states correctly. - **Emulator Tests**: Validate integration behavior against the Cloud Spanner Emulator (`integration_test.go` and examples). Make sure the test configurations can run locally with `auto_config_emulator=true`. - **Wrapper Tests**: If you modified `spannerlib`, ensure you trigger or run unit/integration tests for the respective wrappers (`python-spanner-lib-wrapper-unit-tests.yml`, `ruby-wrapper-tests.yml`, etc.). +- **Assertion Formatting**: When writing test assertions, strongly prefer using variable names `g` (got) and `w` (want) for comparison, and format error messages using the following aligned layout: + ```go + if g, w := actualValue, expectedValue; g != w { + t.Errorf("some message mismatch\nGot: %v\nWant: %v", g, w) + } + ``` + Note the two spaces after `Got:` to align the values visually. + +--- + +## 6. Code Style & Formatting + +- **Go Code Formatting**: All Go code must be formatted using the standard `gofmt -w -s .` formatter. Running this formatting command is required before submitting a pull request. diff --git a/conn.go b/conn.go index d7d9aafb..9dc1ecd5 100644 --- a/conn.go +++ b/conn.go @@ -19,7 +19,6 @@ import ( "database/sql" "database/sql/driver" "errors" - "fmt" "log/slog" "slices" "sync" @@ -705,43 +704,7 @@ func (c *conn) execDDL(ctx context.Context, statements ...spanner.Statement) (dr return (&executableDropDatabaseStatement{stmt}).execContext(ctx, c, nil, []driver.NamedValue{}) } - op, err := c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ - Database: c.database, - Statements: ddlStatements, - }) - if err != nil { - return nil, err - } - c.lastDDLOperationID = op.Name() - - if err := c.waitForDDLOperation(ctx, op.Name(), func(ctx context.Context) error { - return op.Wait(ctx) - }); err != nil { - if len(statements) > 1 { - be := &BatchError{ - Err: err, - BatchUpdateCounts: []int64{}, - } - metadata, err := op.Metadata() - if err != nil { - c.logger.WarnContext(ctx, fmt.Sprintf("Error getting metadata for UpdateDatabaseDdl: %v", err)) - } else if metadata != nil { - for _, ts := range metadata.CommitTimestamps { - if ts != nil { - be.BatchUpdateCounts = append(be.BatchUpdateCounts, int64(-1)) - } else { - break - } - } - } - return nil, be - } - return nil, err - } - mode := propertyDDLExecutionMode.GetValueOrDefault(c.state) - if mode == DDLExecutionModeAsync || mode == DDLExecutionModeAsyncWait { - return &result{operationID: op.Name()}, nil - } + return c.executeDDLWithDefaultSequenceKindRetry(ctx, statements, ddlStatements) } return driver.ResultNoRows, nil } diff --git a/connection_properties.go b/connection_properties.go index 1a4f5387..2fcf482a 100644 --- a/connection_properties.go +++ b/connection_properties.go @@ -175,6 +175,15 @@ var propertyReadOnlyStaleness = createConnectionProperty( connectionstate.ContextUser, connectionstate.ConvertReadOnlyStaleness, ) +var propertyDefaultSequenceKind = createConnectionProperty( + "default_sequence_kind", + "The default sequence kind to automatically set if a DDL statement fails due to missing sequence kind.", + "", + false, + nil, + connectionstate.ContextUser, + connectionstate.ConvertString, +) var propertyAutoPartitionMode = createConnectionProperty( "auto_partition_mode", diff --git a/default_sequence_kind.go b/default_sequence_kind.go new file mode 100644 index 00000000..60b55719 --- /dev/null +++ b/default_sequence_kind.go @@ -0,0 +1,151 @@ +package spannerdriver + +import ( + "context" + "database/sql/driver" + "fmt" + "regexp" + "strings" + + "cloud.google.com/go/spanner" + adminapi "cloud.google.com/go/spanner/admin/database/apiv1" + adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" +) + +var reMissingDefaultSequenceKind = regexp.MustCompile(`Please specify the sequence kind explicitly or set the database option\s+['\x60]?default_sequence_kind['\x60]?\.`) + +func isMissingDefaultSequenceKindError(err error) bool { + if err == nil { + return false + } + return reMissingDefaultSequenceKind.MatchString(err.Error()) +} + +func (c *conn) executeDDLWithDefaultSequenceKindRetry(ctx context.Context, originalStatements []spanner.Statement, ddlStatements []string) (driver.Result, error) { + op, err := c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ + Database: c.database, + Statements: ddlStatements, + }) + + var opRetry *adminapi.UpdateDatabaseDdlOperation + var restartIndex int + var retryErr error + + if err != nil { + // The RPC execution returned an error. + defaultSequenceKind := propertyDefaultSequenceKind.GetValueOrDefault(c.state) + if defaultSequenceKind != "" && isMissingDefaultSequenceKindError(err) { + if errAlter := c.setDefaultSequenceKind(ctx, defaultSequenceKind); errAlter == nil { + opRetry, retryErr = c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ + Database: c.database, + Statements: ddlStatements, + }) + } + } + } else { + c.lastDDLOperationID = op.Name() + err = c.waitForDDLOperation(ctx, op.Name(), func(ctx context.Context) error { + return op.Wait(ctx) + }) + if err != nil { + // The long-running operation returned an error. + defaultSequenceKind := propertyDefaultSequenceKind.GetValueOrDefault(c.state) + if defaultSequenceKind != "" && isMissingDefaultSequenceKindError(err) { + if errAlter := c.setDefaultSequenceKind(ctx, defaultSequenceKind); errAlter == nil { + restartIndex = getSuccessCount(op) + if restartIndex < len(ddlStatements) { + opRetry, retryErr = c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ + Database: c.database, + Statements: ddlStatements[restartIndex:], + }) + } + } + } + } + } + + // If a retry was successfully scheduled + if opRetry != nil && retryErr == nil { + c.lastDDLOperationID = opRetry.Name() + err = c.waitForDDLOperation(ctx, opRetry.Name(), func(ctx context.Context) error { + return opRetry.Wait(ctx) + }) + if err == nil { + mode := propertyDDLExecutionMode.GetValueOrDefault(c.state) + if mode == DDLExecutionModeAsync || mode == DDLExecutionModeAsyncWait { + return &result{operationID: opRetry.Name()}, nil + } + return driver.ResultNoRows, nil + } + } else if retryErr != nil { + err = retryErr + } + + if err != nil { + if len(originalStatements) > 1 { + be := &BatchError{ + Err: err, + BatchUpdateCounts: []int64{}, + } + successCount := getSuccessCount(op) + if opRetry != nil { + successCount = restartIndex + getSuccessCount(opRetry) + } + for i := 0; i < successCount; i++ { + be.BatchUpdateCounts = append(be.BatchUpdateCounts, int64(-1)) + } + return nil, be + } + return nil, err + } + + mode := propertyDDLExecutionMode.GetValueOrDefault(c.state) + if mode == DDLExecutionModeAsync || mode == DDLExecutionModeAsyncWait { + return &result{operationID: op.Name()}, nil + } + return driver.ResultNoRows, nil +} + +func (c *conn) setDefaultSequenceKind(ctx context.Context, defaultSequenceKind string) error { + dbID := c.databaseID() + var alterStatement string + if c.parser.Dialect == adminpb.DatabaseDialect_POSTGRESQL { + alterStatement = fmt.Sprintf(`ALTER DATABASE "%s" SET spanner.default_sequence_kind = '%s'`, strings.ReplaceAll(dbID, `"`, `""`), defaultSequenceKind) + } else { + alterStatement = fmt.Sprintf("ALTER DATABASE `%s` SET OPTIONS (default_sequence_kind = '%s')", strings.ReplaceAll(dbID, "`", "``"), defaultSequenceKind) + } + opAlter, errAlter := c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ + Database: c.database, + Statements: []string{alterStatement}, + }) + if errAlter != nil { + return errAlter + } + return c.waitForDDLOperation(ctx, opAlter.Name(), func(ctx context.Context) error { + return opAlter.Wait(ctx) + }) +} + +func (c *conn) databaseID() string { + parts := strings.Split(c.database, "/") + return parts[len(parts)-1] +} + +func getSuccessCount(op *adminapi.UpdateDatabaseDdlOperation) int { + if op == nil { + return 0 + } + metadata, err := op.Metadata() + if err != nil || metadata == nil { + return 0 + } + var count int + for _, ts := range metadata.CommitTimestamps { + if ts != nil { + count++ + } else { + break + } + } + return count +} diff --git a/driver_test.go b/driver_test.go index 83b803d4..e2826dcf 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1087,3 +1087,15 @@ func TestConnectionStateTypeInitialization(t *testing.T) { t.Errorf("ConnectionStateType mismatch. Got: %v, Want: %v", c.connectorConfig.ConnectionStateType, connectionstate.TypeTransactional) } } + +func TestIsMissingDefaultSequenceKindError(t *testing.T) { + err := fmt.Errorf("rpc error: code = InvalidArgument desc = The sequence kind of an identity column id is not specified. Please specify the sequence kind explicitly or set the database option `default_sequence_kind`.") + if !isMissingDefaultSequenceKindError(err) { + t.Errorf("isMissingDefaultSequenceKindError returned false for: %v", err) + } + + errNoQuotes := fmt.Errorf("rpc error: code = InvalidArgument desc = The sequence kind of an identity column id is not specified. Please specify the sequence kind explicitly or set the database option default_sequence_kind.") + if !isMissingDefaultSequenceKindError(errNoQuotes) { + t.Errorf("isMissingDefaultSequenceKindError returned false for: %v", errNoQuotes) + } +} diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 6d90afa6..d375b11d 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -20,6 +20,7 @@ import ( "database/sql/driver" "encoding/base64" "encoding/json" + "errors" "fmt" "math/big" "math/rand" @@ -42,6 +43,7 @@ import ( "github.com/googleapis/go-sql-spanner/parser" "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/api/option" + pbstatus "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -50,6 +52,7 @@ import ( "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" ) func TestPingContext(t *testing.T) { @@ -2498,6 +2501,317 @@ func TestDdlInTransaction(t *testing.T) { } } +func TestAutoDefaultSequenceKindAsyncMode(t *testing.T) { + t.Parallel() + + for _, dialect := range []databasepb.DatabaseDialect{ + databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, + databasepb.DatabaseDialect_POSTGRESQL, + } { + name := "GoogleSQL" + dsnParams := "default_sequence_kind=bit_reversed_positive;ddl_execution_mode=async" + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + name = "PostgreSQL" + dsnParams = "dialect=postgresql;default_sequence_kind=bit_reversed_positive;ddl_execution_mode=async" + } + t.Run(name, func(t *testing.T) { + db, server, teardown := setupTestDBConnectionWithParamsAndDialect(t, dsnParams, dialect) + defer teardown() + + anyResponse, _ := anypb.New(&emptypb.Empty{}) + opSuccess := &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Response{Response: anyResponse}, + Name: "op-success", + } + server.TestDatabaseAdmin.SetResps([]proto.Message{opSuccess}) + + query := "CREATE SEQUENCE my_seq" + _, err := db.ExecContext(context.Background(), query) + if err != nil { + t.Fatal(err) + } + + requests := server.TestDatabaseAdmin.Reqs() + if g, w := len(requests), 1; g != w { + t.Fatalf("requests count mismatch\nGot: %v\nWant: %v", g, w) + } + + // First request should be the original statement + req0, ok := requests[0].(*databasepb.UpdateDatabaseDdlRequest) + if !ok { + t.Fatalf("request 0 type mismatch, got %T", requests[0]) + } + if g, w := req0.Statements[0], query; g != w { + t.Errorf("request 0 statement mismatch\nGot: %s\nWant: %s", g, w) + } + }) + } +} + +func TestAutoDefaultSequenceKindAsyncModeSyncFailure(t *testing.T) { + t.Parallel() + + for _, dialect := range []databasepb.DatabaseDialect{ + databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, + databasepb.DatabaseDialect_POSTGRESQL, + } { + name := "GoogleSQL" + dsnParams := "default_sequence_kind=bit_reversed_positive;ddl_execution_mode=async" + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + name = "PostgreSQL" + dsnParams = "dialect=postgresql;default_sequence_kind=bit_reversed_positive;ddl_execution_mode=async" + } + t.Run(name, func(t *testing.T) { + db, server, teardown := setupTestDBConnectionWithParamsAndDialect(t, dsnParams, dialect) + defer teardown() + + anyResponse, _ := anypb.New(&emptypb.Empty{}) + opSuccess := &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Response{Response: anyResponse}, + Name: "op-success", + } + + // Mock synchronous DDL error on first call, followed by successful ALTER and then retry. + server.TestDatabaseAdmin.SetErrs([]error{ + gstatus.Error(codes.InvalidArgument, "Please specify the sequence kind explicitly or set the database option 'default_sequence_kind'."), + nil, // ALTER DATABASE succeeds + nil, // DDL retry succeeds + }) + server.TestDatabaseAdmin.SetResps([]proto.Message{ + opSuccess, // ALTER DATABASE op + opSuccess, // DDL retry op + }) + + query := "CREATE SEQUENCE my_seq" + _, err := db.ExecContext(context.Background(), query) + if err != nil { + t.Fatal(err) + } + + requests := server.TestDatabaseAdmin.Reqs() + if g, w := len(requests), 3; g != w { + t.Fatalf("requests count mismatch\nGot: %v\nWant: %v", g, w) + } + + // Verifies database was altered and retried + req0, ok := requests[0].(*databasepb.UpdateDatabaseDdlRequest) + if !ok { + t.Fatalf("request 0 type mismatch, got %T", requests[0]) + } + if g, w := req0.Statements[0], query; g != w { + t.Errorf("request 0 statement mismatch\nGot: %s\nWant: %s", g, w) + } + + req1, ok := requests[1].(*databasepb.UpdateDatabaseDdlRequest) + if !ok { + t.Fatalf("request 1 type mismatch, got %T", requests[1]) + } + var wantAlter string + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + wantAlter = `ALTER DATABASE "d" SET spanner.default_sequence_kind = 'bit_reversed_positive'` + } else { + wantAlter = "ALTER DATABASE `d` SET OPTIONS (default_sequence_kind = 'bit_reversed_positive')" + } + if g, w := req1.Statements[0], wantAlter; g != w { + t.Errorf("request 1 statement mismatch\nGot: %s\nWant: %s", g, w) + } + + req2, ok := requests[2].(*databasepb.UpdateDatabaseDdlRequest) + if !ok { + t.Fatalf("request 2 type mismatch, got %T", requests[2]) + } + if g, w := req2.Statements[0], query; g != w { + t.Errorf("request 2 statement mismatch\nGot: %s\nWant: %s", g, w) + } + }) + } +} + +func TestAutoDefaultSequenceKind(t *testing.T) { + t.Parallel() + + for _, dialect := range []databasepb.DatabaseDialect{ + databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, + databasepb.DatabaseDialect_POSTGRESQL, + } { + name := "GoogleSQL" + dsnParams := "default_sequence_kind=bit_reversed_positive" + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + name = "PostgreSQL" + dsnParams = "dialect=postgresql;default_sequence_kind=bit_reversed_positive" + } + t.Run(name, func(t *testing.T) { + db, server, teardown := setupTestDBConnectionWithParamsAndDialect(t, dsnParams, dialect) + defer teardown() + + opError := &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Error{ + Error: &pbstatus.Status{ + Code: int32(codes.InvalidArgument), + Message: "Please specify the sequence kind explicitly or set the database option 'default_sequence_kind'.", + }, + }, + Name: "op-error", + } + anyResponse, _ := anypb.New(&emptypb.Empty{}) + opSuccess := &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Response{Response: anyResponse}, + Name: "op-success", + } + + server.TestDatabaseAdmin.SetResps([]proto.Message{opError, opSuccess, opSuccess}) + + query := "CREATE SEQUENCE my_seq" + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + query = "CREATE SEQUENCE my_seq" // same for both dialects + } + _, err := db.ExecContext(context.Background(), query) + if err != nil { + t.Fatal(err) + } + + requests := server.TestDatabaseAdmin.Reqs() + if g, w := len(requests), 3; g != w { + t.Fatalf("requests count mismatch\nGot: %v\nWant: %v", g, w) + } + + // First request: original DDL + req0, ok := requests[0].(*databasepb.UpdateDatabaseDdlRequest) + if !ok { + t.Fatalf("request 0 type mismatch, got %T", requests[0]) + } + if g, w := req0.Statements[0], query; g != w { + t.Errorf("request 0 statement mismatch\nGot: %s\nWant: %s", g, w) + } + + // Second request: ALTER DATABASE statement setting default_sequence_kind + req1, ok := requests[1].(*databasepb.UpdateDatabaseDdlRequest) + if !ok { + t.Fatalf("request 1 type mismatch, got %T", requests[1]) + } + var wantAlter string + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + wantAlter = `ALTER DATABASE "d" SET spanner.default_sequence_kind = 'bit_reversed_positive'` + } else { + wantAlter = "ALTER DATABASE `d` SET OPTIONS (default_sequence_kind = 'bit_reversed_positive')" + } + if g, w := req1.Statements[0], wantAlter; g != w { + t.Errorf("request 1 statement mismatch\nGot: %s\nWant: %s", g, w) + } + + // Third request: Retried DDL statement + req2, ok := requests[2].(*databasepb.UpdateDatabaseDdlRequest) + if !ok { + t.Fatalf("request 2 type mismatch, got %T", requests[2]) + } + if g, w := req2.Statements[0], query; g != w { + t.Errorf("request 2 statement mismatch\nGot: %s\nWant: %s", g, w) + } + }) + } +} + +func TestAutoDefaultSequenceKindBatchFailure(t *testing.T) { + t.Parallel() + + for _, dialect := range []databasepb.DatabaseDialect{ + databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, + databasepb.DatabaseDialect_POSTGRESQL, + } { + name := "GoogleSQL" + dsnParams := "default_sequence_kind=bit_reversed_positive" + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + name = "PostgreSQL" + dsnParams = "dialect=postgresql;default_sequence_kind=bit_reversed_positive" + } + t.Run(name, func(t *testing.T) { + db, server, teardown := setupTestDBConnectionWithParamsAndDialect(t, dsnParams, dialect) + defer teardown() + + anyResponse, _ := anypb.New(&emptypb.Empty{}) + meta1, _ := anypb.New(&databasepb.UpdateDatabaseDdlMetadata{ + CommitTimestamps: []*timestamppb.Timestamp{{Seconds: time.Now().Unix(), Nanos: 0}}, + }) + opError1 := &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Error{ + Error: &pbstatus.Status{ + Code: int32(codes.InvalidArgument), + Message: "Please specify the sequence kind explicitly or set the database option 'default_sequence_kind'.", + }, + }, + Metadata: meta1, + Name: "op-error-1", + } + + opSuccessAlter := &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Response{Response: anyResponse}, + Name: "op-success-alter", + } + + meta2, _ := anypb.New(&databasepb.UpdateDatabaseDdlMetadata{ + CommitTimestamps: []*timestamppb.Timestamp{{Seconds: time.Now().Unix(), Nanos: 0}}, + }) + opError2 := &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Error{ + Error: &pbstatus.Status{ + Code: int32(codes.InvalidArgument), + Message: "Some other error on statement 3", + }, + }, + Metadata: meta2, + Name: "op-error-2", + } + + server.TestDatabaseAdmin.SetResps([]proto.Message{opError1, opSuccessAlter, opError2}) + + conn, err := db.Conn(context.Background()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + if _, err := conn.ExecContext(context.Background(), "START BATCH DDL"); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(context.Background(), "CREATE SEQUENCE my_seq1"); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(context.Background(), "CREATE SEQUENCE my_seq2"); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(context.Background(), "CREATE TABLE my_table (id INT64) PRIMARY KEY(id)"); err != nil { + t.Fatal(err) + } + + _, err = conn.ExecContext(context.Background(), "RUN BATCH") + if err == nil { + t.Fatal("expected batch error, got nil") + } + + var be *BatchError + if !errors.As(err, &be) { + t.Fatalf("expected BatchError, got: %v", err) + } + + if g, w := len(be.BatchUpdateCounts), 2; g != w { + t.Errorf("successful statements count mismatch\nGot: %v\nWant: %v", g, w) + } + for i, val := range be.BatchUpdateCounts { + if g, w := val, int64(-1); g != w { + t.Errorf("update count at index %d mismatch\nGot: %v\nWant: %v", i, g, w) + } + } + }) + } +} + func TestBegin(t *testing.T) { t.Parallel() diff --git a/integration_test.go b/integration_test.go index bcce85e1..a725edb1 100644 --- a/integration_test.go +++ b/integration_test.go @@ -2652,3 +2652,139 @@ func nullJsonOrStringArray(v []spanner.NullJSON) interface{} { } return res } + +func TestIntegration_AutoDefaultSequenceKind(t *testing.T) { + skipIfShort(t) + t.Parallel() + + for _, dialect := range []databasepb.DatabaseDialect{ + databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, + databasepb.DatabaseDialect_POSTGRESQL, + } { + name := "GoogleSQL" + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + name = "PostgreSQL" + } + t.Run(name, func(t *testing.T) { + ctx := context.Background() + dsn, cleanup, err := createTestDBWithDialect(ctx, dialect) + if err != nil { + t.Fatalf("failed to create test db: %v", err) + } + defer cleanup() + + dbNoParams, err := sql.Open("spanner", dsn) + if err != nil { + t.Fatal(err) + } + defer dbNoParams.Close() + + var createSeqStmt string + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + createSeqStmt = "CREATE TABLE test_seq (id serial PRIMARY KEY, value varchar)" + } else { + createSeqStmt = "CREATE TABLE test_seq (id INT64 AUTO_INCREMENT PRIMARY KEY, value STRING(MAX))" + } + + _, err = dbNoParams.ExecContext(ctx, createSeqStmt) + if err == nil { + t.Fatalf("expected error without default_sequence_kind parameter") + } + + dsnWithParams := fmt.Sprintf("%s;default_sequence_kind=bit_reversed_positive", dsn) + db, err := sql.Open("spanner", dsnWithParams) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.ExecContext(ctx, createSeqStmt) + if err != nil { + t.Fatalf("failed to execute CREATE statement: %v", err) + } + + insertStmt := "INSERT INTO test_seq (value) VALUES ('One')" + _, err = db.ExecContext(ctx, insertStmt) + if err != nil { + t.Fatalf("failed to insert data: %v", err) + } + }) + } +} + +func TestIntegration_AutoDefaultSequenceKindBatch(t *testing.T) { + skipIfShort(t) + t.Parallel() + + for _, dialect := range []databasepb.DatabaseDialect{ + databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, + databasepb.DatabaseDialect_POSTGRESQL, + } { + name := "GoogleSQL" + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + name = "PostgreSQL" + } + t.Run(name, func(t *testing.T) { + ctx := context.Background() + dsn, cleanup, err := createTestDBWithDialect(ctx, dialect) + if err != nil { + t.Fatalf("failed to create test db: %v", err) + } + defer cleanup() + + dsnWithParams := fmt.Sprintf("%s;default_sequence_kind=bit_reversed_positive", dsn) + db, err := sql.Open("spanner", dsnWithParams) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + var stmt1, stmt2 string + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + stmt1 = "CREATE TABLE testseq1 (id1 int8 PRIMARY KEY, value varchar)" + stmt2 = "CREATE TABLE testseq2 (id2 serial PRIMARY KEY, value varchar)" + } else { + stmt1 = "CREATE TABLE testseq1 (id1 INT64 PRIMARY KEY, value STRING(MAX))" + stmt2 = "CREATE TABLE testseq2 (id2 INT64 AUTO_INCREMENT PRIMARY KEY, value STRING(MAX))" + } + + if _, err := conn.ExecContext(ctx, "START BATCH DDL"); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, stmt1); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, stmt2); err != nil { + t.Fatal(err) + } + + if _, err := conn.ExecContext(ctx, "RUN BATCH"); err != nil { + t.Fatalf("RUN BATCH failed: %v", err) + } + + checkTable1 := "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'testseq1'" + checkTable2 := "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'testseq2'" + + var count1, count2 int + if err := conn.QueryRowContext(ctx, checkTable1).Scan(&count1); err != nil { + t.Fatalf("failed to query table 1: %v", err) + } + if g, w := count1, 1; g != w { + t.Errorf("table 1 not created, count mismatch\nGot: %v\nWant: %v", g, w) + } + + if err := conn.QueryRowContext(ctx, checkTable2).Scan(&count2); err != nil { + t.Fatalf("failed to query table 2: %v", err) + } + if g, w := count2, 1; g != w { + t.Errorf("table 2 not created, count mismatch\nGot: %v\nWant: %v", g, w) + } + }) + } +} diff --git a/testutil/inmem_database_admin_server.go b/testutil/inmem_database_admin_server.go index a202f390..ed4e6008 100644 --- a/testutil/inmem_database_admin_server.go +++ b/testutil/inmem_database_admin_server.go @@ -39,6 +39,7 @@ type InMemDatabaseAdminServer interface { Reqs() []proto.Message SetReqs([]proto.Message) SetErr(error) + SetErrs([]error) AddDdlResponse(key string, result *longrunningpb.Operation) } @@ -50,6 +51,8 @@ type inMemDatabaseAdminServer struct { reqs []proto.Message // If set, all calls return this error err error + // If set, calls will pop and return these errors in sequence + errs []error // responses to return if err == nil resps []proto.Message @@ -58,11 +61,15 @@ type inMemDatabaseAdminServer struct { // The key is calculated by concatenating all statements in the UpdateDatabaseDdlRequest into one string separated // by semicolons. ddlResults map[string]*longrunningpb.Operation + operations map[string]*longrunningpb.Operation } // NewInMemDatabaseAdminServer creates a new in-mem test server. func NewInMemDatabaseAdminServer() InMemDatabaseAdminServer { - res := &inMemDatabaseAdminServer{ddlResults: make(map[string]*longrunningpb.Operation)} + res := &inMemDatabaseAdminServer{ + ddlResults: make(map[string]*longrunningpb.Operation), + operations: make(map[string]*longrunningpb.Operation), + } return res } @@ -70,6 +77,9 @@ func (s *inMemDatabaseAdminServer) GetOperation(ctx context.Context, req *longru if s.err != nil { return nil, s.err } + if op, ok := s.operations[req.Name]; ok { + return op, nil + } if len(s.resps) > 0 { return s.resps[0].(*longrunningpb.Operation), nil } @@ -101,7 +111,11 @@ func (s *inMemDatabaseAdminServer) CreateDatabase(ctx context.Context, req *data if s.err != nil { return nil, s.err } - return s.resps[0].(*longrunningpb.Operation), nil + resp := s.popOperation() + if resp != nil { + s.operations[resp.Name] = resp + } + return resp, nil } func (s *inMemDatabaseAdminServer) DropDatabase(ctx context.Context, req *databasepb.DropDatabaseRequest) (*emptypb.Empty, error) { @@ -122,14 +136,23 @@ func (s *inMemDatabaseAdminServer) UpdateDatabaseDdl(ctx context.Context, req *d return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) } s.reqs = append(s.reqs, req) + if err := s.popError(); err != nil { + return nil, err + } if s.err != nil { return nil, s.err } key := toKey(req) - if resp, ok := s.ddlResults[key]; ok { - return resp, nil + var resp *longrunningpb.Operation + if r, ok := s.ddlResults[key]; ok { + resp = r + } else { + resp = s.popOperation() } - return s.resps[0].(*longrunningpb.Operation), nil + if resp != nil { + s.operations[resp.Name] = resp + } + return resp, nil } func toKey(req *databasepb.UpdateDatabaseDdlRequest) string { @@ -170,3 +193,34 @@ func (s *inMemDatabaseAdminServer) SetErr(err error) { func (s *inMemDatabaseAdminServer) AddDdlResponse(key string, result *longrunningpb.Operation) { s.ddlResults[key] = result } + +func (s *inMemDatabaseAdminServer) popOperation() *longrunningpb.Operation { + if len(s.resps) == 0 { + return nil + } + op, ok := s.resps[0].(*longrunningpb.Operation) + if !ok { + return nil + } + if len(s.resps) > 1 { + s.resps = s.resps[1:] + } + return op +} + +func (s *inMemDatabaseAdminServer) SetErrs(errs []error) { + s.errs = errs +} + +func (s *inMemDatabaseAdminServer) popError() error { + if len(s.errs) == 0 { + return nil + } + err := s.errs[0] + if len(s.errs) > 1 { + s.errs = s.errs[1:] + } else { + s.errs = nil + } + return err +}