From 9323799d6a92dc03441370f8f52f866fadb3b4f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 11 Jun 2026 14:59:09 +0200 Subject: [PATCH 1/4] feat(driver): support default_sequence_kind and auto-setting on DDL failure --- conn.go | 71 ++++++++++++++++++++ connection_properties.go | 9 +++ driver_with_mockserver_test.go | 87 +++++++++++++++++++++++++ testutil/inmem_database_admin_server.go | 41 ++++++++++-- 4 files changed, 203 insertions(+), 5 deletions(-) diff --git a/conn.go b/conn.go index d7d9aafb..242bce29 100644 --- a/conn.go +++ b/conn.go @@ -21,7 +21,9 @@ import ( "errors" "fmt" "log/slog" + "regexp" "slices" + "strings" "sync" "time" @@ -717,6 +719,61 @@ func (c *conn) execDDL(ctx context.Context, statements ...spanner.Statement) (dr if err := c.waitForDDLOperation(ctx, op.Name(), func(ctx context.Context) error { return op.Wait(ctx) }); err != nil { + defaultSequenceKind := propertyDefaultSequenceKind.GetValueOrDefault(c.state) + if defaultSequenceKind != "" && isMissingDefaultSequenceKindError(err) { + dbID := c.databaseID() + var alterStatement string + if c.parser.Dialect == adminpb.DatabaseDialect_POSTGRESQL { + alterStatement = fmt.Sprintf(`ALTER DATABASE "%s" SET spanner.default_sequence_kind = '%s'`, strings.ReplaceAll(dbID, `"`, `""`), defaultSequenceKind) + } else { + alterStatement = fmt.Sprintf("ALTER DATABASE `%s` SET OPTIONS (default_sequence_kind = '%s')", strings.ReplaceAll(dbID, "`", "``"), defaultSequenceKind) + } + opAlter, errAlter := c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ + Database: c.database, + Statements: []string{alterStatement}, + }) + if errAlter == nil { + errAlter = c.waitForDDLOperation(ctx, opAlter.Name(), func(ctx context.Context) error { + return opAlter.Wait(ctx) + }) + } + if errAlter == nil { + var restartIndex int + metadata, errMetadata := op.Metadata() + if errMetadata == nil && metadata != nil { + for _, ts := range metadata.CommitTimestamps { + if ts != nil { + restartIndex++ + } else { + break + } + } + } + if restartIndex < len(ddlStatements) { + retryStatements := ddlStatements[restartIndex:] + opRetry, errRetry := c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ + Database: c.database, + Statements: retryStatements, + }) + if errRetry == nil { + c.lastDDLOperationID = opRetry.Name() + if errRetry = c.waitForDDLOperation(ctx, opRetry.Name(), func(ctx context.Context) error { + return opRetry.Wait(ctx) + }); errRetry == nil { + mode := propertyDDLExecutionMode.GetValueOrDefault(c.state) + if mode == DDLExecutionModeAsync || mode == DDLExecutionModeAsyncWait { + return &result{operationID: opRetry.Name()}, nil + } + return driver.ResultNoRows, nil + } + } + if errRetry != nil { + err = errRetry + } + } + } + } + if len(statements) > 1 { be := &BatchError{ Err: err, @@ -1878,3 +1935,17 @@ func execAsPartitionedDML(ctx context.Context, c *spanner.Client, statement span queryOptions.ExcludeTxnFromChangeStreams = options.TransactionOptions.ExcludeTxnFromChangeStreams return c.PartitionedUpdateWithOptions(ctx, statement, queryOptions) } + +var reMissingDefaultSequenceKind = regexp.MustCompile(`.*Please specify the sequence kind explicitly or set the database option\s+['\x60]default_sequence_kind['\x60]\.`) + +func isMissingDefaultSequenceKindError(err error) bool { + if err == nil { + return false + } + return reMissingDefaultSequenceKind.MatchString(err.Error()) +} + +func (c *conn) databaseID() string { + parts := strings.Split(c.database, "/") + return parts[len(parts)-1] +} diff --git a/connection_properties.go b/connection_properties.go index 1a4f5387..2fcf482a 100644 --- a/connection_properties.go +++ b/connection_properties.go @@ -175,6 +175,15 @@ var propertyReadOnlyStaleness = createConnectionProperty( connectionstate.ContextUser, connectionstate.ConvertReadOnlyStaleness, ) +var propertyDefaultSequenceKind = createConnectionProperty( + "default_sequence_kind", + "The default sequence kind to automatically set if a DDL statement fails due to missing sequence kind.", + "", + false, + nil, + connectionstate.ContextUser, + connectionstate.ConvertString, +) var propertyAutoPartitionMode = createConnectionProperty( "auto_partition_mode", diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 6d90afa6..39763c49 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -42,6 +42,7 @@ import ( "github.com/googleapis/go-sql-spanner/parser" "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/api/option" + pbstatus "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -2498,6 +2499,92 @@ func TestDdlInTransaction(t *testing.T) { } } +func TestAutoDefaultSequenceKind(t *testing.T) { + t.Parallel() + + for _, dialect := range []databasepb.DatabaseDialect{ + databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, + databasepb.DatabaseDialect_POSTGRESQL, + } { + name := "GoogleSQL" + dsnParams := "default_sequence_kind=bit_reversed_positive" + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + name = "PostgreSQL" + dsnParams = "dialect=postgresql;default_sequence_kind=bit_reversed_positive" + } + t.Run(name, func(t *testing.T) { + db, server, teardown := setupTestDBConnectionWithParamsAndDialect(t, dsnParams, dialect) + defer teardown() + + opError := &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Error{ + Error: &pbstatus.Status{ + Code: int32(codes.InvalidArgument), + Message: "Please specify the sequence kind explicitly or set the database option 'default_sequence_kind'.", + }, + }, + Name: "op-error", + } + anyResponse, _ := anypb.New(&emptypb.Empty{}) + opSuccess := &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Response{Response: anyResponse}, + Name: "op-success", + } + + server.TestDatabaseAdmin.SetResps([]proto.Message{opError, opSuccess, opSuccess}) + + query := "CREATE SEQUENCE my_seq" + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + query = "CREATE SEQUENCE my_seq" // same for both dialects + } + _, err := db.ExecContext(context.Background(), query) + if err != nil { + t.Fatal(err) + } + + requests := server.TestDatabaseAdmin.Reqs() + if g, w := len(requests), 3; g != w { + t.Fatalf("requests count mismatch\nGot: %v\nWant: %v", g, w) + } + + // First request: original DDL + req0, ok := requests[0].(*databasepb.UpdateDatabaseDdlRequest) + if !ok { + t.Fatalf("request 0 type mismatch, got %T", requests[0]) + } + if g, w := req0.Statements[0], query; g != w { + t.Errorf("request 0 statement mismatch\nGot: %s\nWant: %s", g, w) + } + + // Second request: ALTER DATABASE statement setting default_sequence_kind + req1, ok := requests[1].(*databasepb.UpdateDatabaseDdlRequest) + if !ok { + t.Fatalf("request 1 type mismatch, got %T", requests[1]) + } + var wantAlter string + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + wantAlter = `ALTER DATABASE "d" SET spanner.default_sequence_kind = 'bit_reversed_positive'` + } else { + wantAlter = "ALTER DATABASE `d` SET OPTIONS (default_sequence_kind = 'bit_reversed_positive')" + } + if g, w := req1.Statements[0], wantAlter; g != w { + t.Errorf("request 1 statement mismatch\nGot: %s\nWant: %s", g, w) + } + + // Third request: Retried DDL statement + req2, ok := requests[2].(*databasepb.UpdateDatabaseDdlRequest) + if !ok { + t.Fatalf("request 2 type mismatch, got %T", requests[2]) + } + if g, w := req2.Statements[0], query; g != w { + t.Errorf("request 2 statement mismatch\nGot: %s\nWant: %s", g, w) + } + }) + } +} + func TestBegin(t *testing.T) { t.Parallel() diff --git a/testutil/inmem_database_admin_server.go b/testutil/inmem_database_admin_server.go index a202f390..25d8e1ac 100644 --- a/testutil/inmem_database_admin_server.go +++ b/testutil/inmem_database_admin_server.go @@ -58,11 +58,15 @@ type inMemDatabaseAdminServer struct { // The key is calculated by concatenating all statements in the UpdateDatabaseDdlRequest into one string separated // by semicolons. ddlResults map[string]*longrunningpb.Operation + operations map[string]*longrunningpb.Operation } // NewInMemDatabaseAdminServer creates a new in-mem test server. func NewInMemDatabaseAdminServer() InMemDatabaseAdminServer { - res := &inMemDatabaseAdminServer{ddlResults: make(map[string]*longrunningpb.Operation)} + res := &inMemDatabaseAdminServer{ + ddlResults: make(map[string]*longrunningpb.Operation), + operations: make(map[string]*longrunningpb.Operation), + } return res } @@ -70,6 +74,9 @@ func (s *inMemDatabaseAdminServer) GetOperation(ctx context.Context, req *longru if s.err != nil { return nil, s.err } + if op, ok := s.operations[req.Name]; ok { + return op, nil + } if len(s.resps) > 0 { return s.resps[0].(*longrunningpb.Operation), nil } @@ -101,7 +108,11 @@ func (s *inMemDatabaseAdminServer) CreateDatabase(ctx context.Context, req *data if s.err != nil { return nil, s.err } - return s.resps[0].(*longrunningpb.Operation), nil + resp := s.popOperation() + if resp != nil { + s.operations[resp.Name] = resp + } + return resp, nil } func (s *inMemDatabaseAdminServer) DropDatabase(ctx context.Context, req *databasepb.DropDatabaseRequest) (*emptypb.Empty, error) { @@ -126,10 +137,16 @@ func (s *inMemDatabaseAdminServer) UpdateDatabaseDdl(ctx context.Context, req *d return nil, s.err } key := toKey(req) - if resp, ok := s.ddlResults[key]; ok { - return resp, nil + var resp *longrunningpb.Operation + if r, ok := s.ddlResults[key]; ok { + resp = r + } else { + resp = s.popOperation() } - return s.resps[0].(*longrunningpb.Operation), nil + if resp != nil { + s.operations[resp.Name] = resp + } + return resp, nil } func toKey(req *databasepb.UpdateDatabaseDdlRequest) string { @@ -170,3 +187,17 @@ func (s *inMemDatabaseAdminServer) SetErr(err error) { func (s *inMemDatabaseAdminServer) AddDdlResponse(key string, result *longrunningpb.Operation) { s.ddlResults[key] = result } + +func (s *inMemDatabaseAdminServer) popOperation() *longrunningpb.Operation { + if len(s.resps) == 0 { + return nil + } + op, ok := s.resps[0].(*longrunningpb.Operation) + if !ok { + return nil + } + if len(s.resps) > 1 { + s.resps = s.resps[1:] + } + return op +} From 5d3472f98c6afd1bdb6e3e56089b77b99cb8d6fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 12 Jun 2026 16:32:23 +0200 Subject: [PATCH 2/4] chore: address review comments and add tests --- conn.go | 42 +++++++--- driver_with_mockserver_test.go | 99 ++++++++++++++++++++++++ integration_test.go | 136 +++++++++++++++++++++++++++++++++ 3 files changed, 265 insertions(+), 12 deletions(-) diff --git a/conn.go b/conn.go index 242bce29..c9a91a71 100644 --- a/conn.go +++ b/conn.go @@ -720,6 +720,8 @@ func (c *conn) execDDL(ctx context.Context, statements ...spanner.Statement) (dr return op.Wait(ctx) }); err != nil { defaultSequenceKind := propertyDefaultSequenceKind.GetValueOrDefault(c.state) + var opRetry *adminapi.UpdateDatabaseDdlOperation + var restartIndex int if defaultSequenceKind != "" && isMissingDefaultSequenceKindError(err) { dbID := c.databaseID() var alterStatement string @@ -738,7 +740,6 @@ func (c *conn) execDDL(ctx context.Context, statements ...spanner.Statement) (dr }) } if errAlter == nil { - var restartIndex int metadata, errMetadata := op.Metadata() if errMetadata == nil && metadata != nil { for _, ts := range metadata.CommitTimestamps { @@ -751,7 +752,8 @@ func (c *conn) execDDL(ctx context.Context, statements ...spanner.Statement) (dr } if restartIndex < len(ddlStatements) { retryStatements := ddlStatements[restartIndex:] - opRetry, errRetry := c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ + var errRetry error + opRetry, errRetry = c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ Database: c.database, Statements: retryStatements, }) @@ -779,17 +781,33 @@ func (c *conn) execDDL(ctx context.Context, statements ...spanner.Statement) (dr Err: err, BatchUpdateCounts: []int64{}, } - metadata, err := op.Metadata() - if err != nil { - c.logger.WarnContext(ctx, fmt.Sprintf("Error getting metadata for UpdateDatabaseDdl: %v", err)) - } else if metadata != nil { - for _, ts := range metadata.CommitTimestamps { - if ts != nil { - be.BatchUpdateCounts = append(be.BatchUpdateCounts, int64(-1)) - } else { - break + var successCount int + if opRetry != nil { + successCount = restartIndex + metadataRetry, errMetadataRetry := opRetry.Metadata() + if errMetadataRetry == nil && metadataRetry != nil { + for _, ts := range metadataRetry.CommitTimestamps { + if ts != nil { + successCount++ + } else { + break + } } } + } else { + metadata, errMetadata := op.Metadata() + if errMetadata == nil && metadata != nil { + for _, ts := range metadata.CommitTimestamps { + if ts != nil { + successCount++ + } else { + break + } + } + } + } + for i := 0; i < successCount; i++ { + be.BatchUpdateCounts = append(be.BatchUpdateCounts, int64(-1)) } return nil, be } @@ -1936,7 +1954,7 @@ func execAsPartitionedDML(ctx context.Context, c *spanner.Client, statement span return c.PartitionedUpdateWithOptions(ctx, statement, queryOptions) } -var reMissingDefaultSequenceKind = regexp.MustCompile(`.*Please specify the sequence kind explicitly or set the database option\s+['\x60]default_sequence_kind['\x60]\.`) +var reMissingDefaultSequenceKind = regexp.MustCompile(`Please specify the sequence kind explicitly or set the database option\s+['\x60]default_sequence_kind['\x60]\.`) func isMissingDefaultSequenceKindError(err error) bool { if err == nil { diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 39763c49..e2726464 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "database/sql/driver" + "errors" "encoding/base64" "encoding/json" "fmt" @@ -51,6 +52,7 @@ import ( "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" ) func TestPingContext(t *testing.T) { @@ -2585,6 +2587,103 @@ func TestAutoDefaultSequenceKind(t *testing.T) { } } +func TestAutoDefaultSequenceKindBatchFailure(t *testing.T) { + t.Parallel() + + for _, dialect := range []databasepb.DatabaseDialect{ + databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, + databasepb.DatabaseDialect_POSTGRESQL, + } { + name := "GoogleSQL" + dsnParams := "default_sequence_kind=bit_reversed_positive" + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + name = "PostgreSQL" + dsnParams = "dialect=postgresql;default_sequence_kind=bit_reversed_positive" + } + t.Run(name, func(t *testing.T) { + db, server, teardown := setupTestDBConnectionWithParamsAndDialect(t, dsnParams, dialect) + defer teardown() + + anyResponse, _ := anypb.New(&emptypb.Empty{}) + meta1, _ := anypb.New(&databasepb.UpdateDatabaseDdlMetadata{ + CommitTimestamps: []*timestamppb.Timestamp{{Seconds: time.Now().Unix(), Nanos: 0}}, + }) + opError1 := &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Error{ + Error: &pbstatus.Status{ + Code: int32(codes.InvalidArgument), + Message: "Please specify the sequence kind explicitly or set the database option 'default_sequence_kind'.", + }, + }, + Metadata: meta1, + Name: "op-error-1", + } + + opSuccessAlter := &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Response{Response: anyResponse}, + Name: "op-success-alter", + } + + meta2, _ := anypb.New(&databasepb.UpdateDatabaseDdlMetadata{ + CommitTimestamps: []*timestamppb.Timestamp{{Seconds: time.Now().Unix(), Nanos: 0}}, + }) + opError2 := &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Error{ + Error: &pbstatus.Status{ + Code: int32(codes.InvalidArgument), + Message: "Some other error on statement 3", + }, + }, + Metadata: meta2, + Name: "op-error-2", + } + + server.TestDatabaseAdmin.SetResps([]proto.Message{opError1, opSuccessAlter, opError2}) + + conn, err := db.Conn(context.Background()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + if _, err := conn.ExecContext(context.Background(), "START BATCH DDL"); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(context.Background(), "CREATE SEQUENCE my_seq1"); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(context.Background(), "CREATE SEQUENCE my_seq2"); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(context.Background(), "CREATE TABLE my_table (id INT64) PRIMARY KEY(id)"); err != nil { + t.Fatal(err) + } + + _, err = conn.ExecContext(context.Background(), "RUN BATCH") + if err == nil { + t.Fatal("expected batch error, got nil") + } + + var be *BatchError + if !errors.As(err, &be) { + t.Fatalf("expected BatchError, got: %v", err) + } + + if got, want := len(be.BatchUpdateCounts), 2; got != want { + t.Errorf("successful statements count mismatch, got %d, want %d", got, want) + } + for i, val := range be.BatchUpdateCounts { + if val != -1 { + t.Errorf("update count at index %d mismatch, got %d, want -1", i, val) + } + } + }) + } +} + func TestBegin(t *testing.T) { t.Parallel() diff --git a/integration_test.go b/integration_test.go index bcce85e1..5715b9d7 100644 --- a/integration_test.go +++ b/integration_test.go @@ -2652,3 +2652,139 @@ func nullJsonOrStringArray(v []spanner.NullJSON) interface{} { } return res } + +func TestIntegration_AutoDefaultSequenceKind(t *testing.T) { + skipIfShort(t) + t.Parallel() + + for _, dialect := range []databasepb.DatabaseDialect{ + databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, + databasepb.DatabaseDialect_POSTGRESQL, + } { + name := "GoogleSQL" + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + name = "PostgreSQL" + } + t.Run(name, func(t *testing.T) { + ctx := context.Background() + dsn, cleanup, err := createTestDBWithDialect(ctx, dialect) + if err != nil { + t.Fatalf("failed to create test db: %v", err) + } + defer cleanup() + + dbNoParams, err := sql.Open("spanner", dsn) + if err != nil { + t.Fatal(err) + } + defer dbNoParams.Close() + + var createSeqStmt string + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + createSeqStmt = "CREATE TABLE test_seq (id serial PRIMARY KEY, value varchar)" + } else { + createSeqStmt = "CREATE TABLE test_seq (id INT64 AUTO_INCREMENT PRIMARY KEY, value STRING(MAX))" + } + + _, err = dbNoParams.ExecContext(ctx, createSeqStmt) + if err == nil { + t.Fatalf("expected error without default_sequence_kind parameter") + } + + dsnWithParams := fmt.Sprintf("%s;default_sequence_kind=bit_reversed_positive", dsn) + db, err := sql.Open("spanner", dsnWithParams) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.ExecContext(ctx, createSeqStmt) + if err != nil { + t.Fatalf("failed to execute CREATE statement: %v", err) + } + + insertStmt := "INSERT INTO test_seq (value) VALUES ('One')" + _, err = db.ExecContext(ctx, insertStmt) + if err != nil { + t.Fatalf("failed to insert data: %v", err) + } + }) + } +} + +func TestIntegration_AutoDefaultSequenceKindBatch(t *testing.T) { + skipIfShort(t) + t.Parallel() + + for _, dialect := range []databasepb.DatabaseDialect{ + databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, + databasepb.DatabaseDialect_POSTGRESQL, + } { + name := "GoogleSQL" + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + name = "PostgreSQL" + } + t.Run(name, func(t *testing.T) { + ctx := context.Background() + dsn, cleanup, err := createTestDBWithDialect(ctx, dialect) + if err != nil { + t.Fatalf("failed to create test db: %v", err) + } + defer cleanup() + + dsnWithParams := fmt.Sprintf("%s;default_sequence_kind=bit_reversed_positive", dsn) + db, err := sql.Open("spanner", dsnWithParams) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + var stmt1, stmt2 string + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + stmt1 = "CREATE TABLE testseq1 (id1 int8 PRIMARY KEY, value varchar)" + stmt2 = "CREATE TABLE testseq2 (id2 serial PRIMARY KEY, value varchar)" + } else { + stmt1 = "CREATE TABLE testseq1 (id1 INT64 PRIMARY KEY, value STRING(MAX))" + stmt2 = "CREATE TABLE testseq2 (id2 INT64 AUTO_INCREMENT PRIMARY KEY, value STRING(MAX))" + } + + if _, err := conn.ExecContext(ctx, "START BATCH DDL"); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, stmt1); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, stmt2); err != nil { + t.Fatal(err) + } + + if _, err := conn.ExecContext(ctx, "RUN BATCH"); err != nil { + t.Fatalf("RUN BATCH failed: %v", err) + } + + checkTable1 := "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'testseq1'" + checkTable2 := "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'testseq2'" + + var count1, count2 int + if err := conn.QueryRowContext(ctx, checkTable1).Scan(&count1); err != nil { + t.Fatalf("failed to query table 1: %v", err) + } + if count1 != 1 { + t.Errorf("table 1 not created, count = %d", count1) + } + + if err := conn.QueryRowContext(ctx, checkTable2).Scan(&count2); err != nil { + t.Fatalf("failed to query table 2: %v", err) + } + if count2 != 1 { + t.Errorf("table 2 not created, count = %d", count2) + } + }) + } +} From bb7450bdf2a6102ba99a2379f7b96850a6257a1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 12 Jun 2026 17:16:55 +0200 Subject: [PATCH 3/4] chore: cleanup and add more tests --- GEMINI.md | 6 + conn.go | 128 +------------------- default_sequence_kind.go | 151 ++++++++++++++++++++++++ driver_test.go | 12 ++ driver_with_mockserver_test.go | 130 +++++++++++++++++++- testutil/inmem_database_admin_server.go | 23 ++++ 6 files changed, 322 insertions(+), 128 deletions(-) create mode 100644 default_sequence_kind.go diff --git a/GEMINI.md b/GEMINI.md index a13d25a6..391a3b14 100644 --- a/GEMINI.md +++ b/GEMINI.md @@ -67,3 +67,9 @@ Any pull request modifying or extending the driver's features must include: } ``` Note the two spaces after `Got:` to align the values visually. + +--- + +## 6. Code Style & Formatting + +- **Go Code Formatting**: All Go code must be formatted using the standard `gofmt -w -s .` formatter. Running this formatting command is required before submitting a pull request. diff --git a/conn.go b/conn.go index c9a91a71..9dc1ecd5 100644 --- a/conn.go +++ b/conn.go @@ -19,11 +19,8 @@ import ( "database/sql" "database/sql/driver" "errors" - "fmt" "log/slog" - "regexp" "slices" - "strings" "sync" "time" @@ -707,116 +704,7 @@ func (c *conn) execDDL(ctx context.Context, statements ...spanner.Statement) (dr return (&executableDropDatabaseStatement{stmt}).execContext(ctx, c, nil, []driver.NamedValue{}) } - op, err := c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ - Database: c.database, - Statements: ddlStatements, - }) - if err != nil { - return nil, err - } - c.lastDDLOperationID = op.Name() - - if err := c.waitForDDLOperation(ctx, op.Name(), func(ctx context.Context) error { - return op.Wait(ctx) - }); err != nil { - defaultSequenceKind := propertyDefaultSequenceKind.GetValueOrDefault(c.state) - var opRetry *adminapi.UpdateDatabaseDdlOperation - var restartIndex int - if defaultSequenceKind != "" && isMissingDefaultSequenceKindError(err) { - dbID := c.databaseID() - var alterStatement string - if c.parser.Dialect == adminpb.DatabaseDialect_POSTGRESQL { - alterStatement = fmt.Sprintf(`ALTER DATABASE "%s" SET spanner.default_sequence_kind = '%s'`, strings.ReplaceAll(dbID, `"`, `""`), defaultSequenceKind) - } else { - alterStatement = fmt.Sprintf("ALTER DATABASE `%s` SET OPTIONS (default_sequence_kind = '%s')", strings.ReplaceAll(dbID, "`", "``"), defaultSequenceKind) - } - opAlter, errAlter := c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ - Database: c.database, - Statements: []string{alterStatement}, - }) - if errAlter == nil { - errAlter = c.waitForDDLOperation(ctx, opAlter.Name(), func(ctx context.Context) error { - return opAlter.Wait(ctx) - }) - } - if errAlter == nil { - metadata, errMetadata := op.Metadata() - if errMetadata == nil && metadata != nil { - for _, ts := range metadata.CommitTimestamps { - if ts != nil { - restartIndex++ - } else { - break - } - } - } - if restartIndex < len(ddlStatements) { - retryStatements := ddlStatements[restartIndex:] - var errRetry error - opRetry, errRetry = c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ - Database: c.database, - Statements: retryStatements, - }) - if errRetry == nil { - c.lastDDLOperationID = opRetry.Name() - if errRetry = c.waitForDDLOperation(ctx, opRetry.Name(), func(ctx context.Context) error { - return opRetry.Wait(ctx) - }); errRetry == nil { - mode := propertyDDLExecutionMode.GetValueOrDefault(c.state) - if mode == DDLExecutionModeAsync || mode == DDLExecutionModeAsyncWait { - return &result{operationID: opRetry.Name()}, nil - } - return driver.ResultNoRows, nil - } - } - if errRetry != nil { - err = errRetry - } - } - } - } - - if len(statements) > 1 { - be := &BatchError{ - Err: err, - BatchUpdateCounts: []int64{}, - } - var successCount int - if opRetry != nil { - successCount = restartIndex - metadataRetry, errMetadataRetry := opRetry.Metadata() - if errMetadataRetry == nil && metadataRetry != nil { - for _, ts := range metadataRetry.CommitTimestamps { - if ts != nil { - successCount++ - } else { - break - } - } - } - } else { - metadata, errMetadata := op.Metadata() - if errMetadata == nil && metadata != nil { - for _, ts := range metadata.CommitTimestamps { - if ts != nil { - successCount++ - } else { - break - } - } - } - } - for i := 0; i < successCount; i++ { - be.BatchUpdateCounts = append(be.BatchUpdateCounts, int64(-1)) - } - return nil, be - } - return nil, err - } - mode := propertyDDLExecutionMode.GetValueOrDefault(c.state) - if mode == DDLExecutionModeAsync || mode == DDLExecutionModeAsyncWait { - return &result{operationID: op.Name()}, nil - } + return c.executeDDLWithDefaultSequenceKindRetry(ctx, statements, ddlStatements) } return driver.ResultNoRows, nil } @@ -1953,17 +1841,3 @@ func execAsPartitionedDML(ctx context.Context, c *spanner.Client, statement span queryOptions.ExcludeTxnFromChangeStreams = options.TransactionOptions.ExcludeTxnFromChangeStreams return c.PartitionedUpdateWithOptions(ctx, statement, queryOptions) } - -var reMissingDefaultSequenceKind = regexp.MustCompile(`Please specify the sequence kind explicitly or set the database option\s+['\x60]default_sequence_kind['\x60]\.`) - -func isMissingDefaultSequenceKindError(err error) bool { - if err == nil { - return false - } - return reMissingDefaultSequenceKind.MatchString(err.Error()) -} - -func (c *conn) databaseID() string { - parts := strings.Split(c.database, "/") - return parts[len(parts)-1] -} diff --git a/default_sequence_kind.go b/default_sequence_kind.go new file mode 100644 index 00000000..cf625c86 --- /dev/null +++ b/default_sequence_kind.go @@ -0,0 +1,151 @@ +package spannerdriver + +import ( + "context" + "database/sql/driver" + "fmt" + "regexp" + "strings" + + "cloud.google.com/go/spanner" + adminapi "cloud.google.com/go/spanner/admin/database/apiv1" + adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" +) + +var reMissingDefaultSequenceKind = regexp.MustCompile(`Please specify the sequence kind explicitly or set the database option\s+['\x60]?default_sequence_kind['\x60]?\.`) + +func isMissingDefaultSequenceKindError(err error) bool { + if err == nil { + return false + } + return reMissingDefaultSequenceKind.MatchString(err.Error()) +} + +func (c *conn) executeDDLWithDefaultSequenceKindRetry(ctx context.Context, originalStatements []spanner.Statement, ddlStatements []string) (driver.Result, error) { + op, err := c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ + Database: c.database, + Statements: ddlStatements, + }) + + var opRetry *adminapi.UpdateDatabaseDdlOperation + var restartIndex int + var retryErr error + + if err != nil { + // The RPC execution returned an error. + defaultSequenceKind := propertyDefaultSequenceKind.GetValueOrDefault(c.state) + if defaultSequenceKind != "" && isMissingDefaultSequenceKindError(err) { + if errAlter := c.setDefaultSequenceKind(ctx, defaultSequenceKind); errAlter == nil { + opRetry, retryErr = c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ + Database: c.database, + Statements: ddlStatements, + }) + } + } + } else { + c.lastDDLOperationID = op.Name() + err = c.waitForDDLOperation(ctx, op.Name(), func(ctx context.Context) error { + return op.Wait(ctx) + }) + if err != nil { + // The long-running operation returned an error. + defaultSequenceKind := propertyDefaultSequenceKind.GetValueOrDefault(c.state) + if defaultSequenceKind != "" && isMissingDefaultSequenceKindError(err) { + if errAlter := c.setDefaultSequenceKind(ctx, defaultSequenceKind); errAlter == nil { + restartIndex = getSuccessCount(op) + if restartIndex < len(ddlStatements) { + opRetry, retryErr = c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ + Database: c.database, + Statements: ddlStatements[restartIndex:], + }) + } + } + } + } + } + + // If a retry was successfully scheduled + if opRetry != nil && retryErr == nil { + c.lastDDLOperationID = opRetry.Name() + err = c.waitForDDLOperation(ctx, opRetry.Name(), func(ctx context.Context) error { + return opRetry.Wait(ctx) + }) + if err == nil { + mode := propertyDDLExecutionMode.GetValueOrDefault(c.state) + if mode == DDLExecutionModeAsync || mode == DDLExecutionModeAsyncWait { + return &result{operationID: opRetry.Name()}, nil + } + return driver.ResultNoRows, nil + } + } else if retryErr != nil { + err = retryErr + } + + if err != nil { + if len(originalStatements) > 1 { + be := &BatchError{ + Err: err, + BatchUpdateCounts: []int64{}, + } + successCount := restartIndex + if opRetry != nil { + successCount += getSuccessCount(opRetry) + } + for i := 0; i < successCount; i++ { + be.BatchUpdateCounts = append(be.BatchUpdateCounts, int64(-1)) + } + return nil, be + } + return nil, err + } + + mode := propertyDDLExecutionMode.GetValueOrDefault(c.state) + if mode == DDLExecutionModeAsync || mode == DDLExecutionModeAsyncWait { + return &result{operationID: op.Name()}, nil + } + return driver.ResultNoRows, nil +} + +func (c *conn) setDefaultSequenceKind(ctx context.Context, defaultSequenceKind string) error { + dbID := c.databaseID() + var alterStatement string + if c.parser.Dialect == adminpb.DatabaseDialect_POSTGRESQL { + alterStatement = fmt.Sprintf(`ALTER DATABASE "%s" SET spanner.default_sequence_kind = '%s'`, strings.ReplaceAll(dbID, `"`, `""`), defaultSequenceKind) + } else { + alterStatement = fmt.Sprintf("ALTER DATABASE `%s` SET OPTIONS (default_sequence_kind = '%s')", strings.ReplaceAll(dbID, "`", "``"), defaultSequenceKind) + } + opAlter, errAlter := c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ + Database: c.database, + Statements: []string{alterStatement}, + }) + if errAlter != nil { + return errAlter + } + return c.waitForDDLOperation(ctx, opAlter.Name(), func(ctx context.Context) error { + return opAlter.Wait(ctx) + }) +} + +func (c *conn) databaseID() string { + parts := strings.Split(c.database, "/") + return parts[len(parts)-1] +} + +func getSuccessCount(op *adminapi.UpdateDatabaseDdlOperation) int { + if op == nil { + return 0 + } + metadata, err := op.Metadata() + if err != nil || metadata == nil { + return 0 + } + var count int + for _, ts := range metadata.CommitTimestamps { + if ts != nil { + count++ + } else { + break + } + } + return count +} diff --git a/driver_test.go b/driver_test.go index 83b803d4..e2826dcf 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1087,3 +1087,15 @@ func TestConnectionStateTypeInitialization(t *testing.T) { t.Errorf("ConnectionStateType mismatch. Got: %v, Want: %v", c.connectorConfig.ConnectionStateType, connectionstate.TypeTransactional) } } + +func TestIsMissingDefaultSequenceKindError(t *testing.T) { + err := fmt.Errorf("rpc error: code = InvalidArgument desc = The sequence kind of an identity column id is not specified. Please specify the sequence kind explicitly or set the database option `default_sequence_kind`.") + if !isMissingDefaultSequenceKindError(err) { + t.Errorf("isMissingDefaultSequenceKindError returned false for: %v", err) + } + + errNoQuotes := fmt.Errorf("rpc error: code = InvalidArgument desc = The sequence kind of an identity column id is not specified. Please specify the sequence kind explicitly or set the database option default_sequence_kind.") + if !isMissingDefaultSequenceKindError(errNoQuotes) { + t.Errorf("isMissingDefaultSequenceKindError returned false for: %v", errNoQuotes) + } +} diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 87bc3aaa..d375b11d 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -18,9 +18,9 @@ import ( "context" "database/sql" "database/sql/driver" - "errors" "encoding/base64" "encoding/json" + "errors" "fmt" "math/big" "math/rand" @@ -2501,6 +2501,134 @@ func TestDdlInTransaction(t *testing.T) { } } +func TestAutoDefaultSequenceKindAsyncMode(t *testing.T) { + t.Parallel() + + for _, dialect := range []databasepb.DatabaseDialect{ + databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, + databasepb.DatabaseDialect_POSTGRESQL, + } { + name := "GoogleSQL" + dsnParams := "default_sequence_kind=bit_reversed_positive;ddl_execution_mode=async" + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + name = "PostgreSQL" + dsnParams = "dialect=postgresql;default_sequence_kind=bit_reversed_positive;ddl_execution_mode=async" + } + t.Run(name, func(t *testing.T) { + db, server, teardown := setupTestDBConnectionWithParamsAndDialect(t, dsnParams, dialect) + defer teardown() + + anyResponse, _ := anypb.New(&emptypb.Empty{}) + opSuccess := &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Response{Response: anyResponse}, + Name: "op-success", + } + server.TestDatabaseAdmin.SetResps([]proto.Message{opSuccess}) + + query := "CREATE SEQUENCE my_seq" + _, err := db.ExecContext(context.Background(), query) + if err != nil { + t.Fatal(err) + } + + requests := server.TestDatabaseAdmin.Reqs() + if g, w := len(requests), 1; g != w { + t.Fatalf("requests count mismatch\nGot: %v\nWant: %v", g, w) + } + + // First request should be the original statement + req0, ok := requests[0].(*databasepb.UpdateDatabaseDdlRequest) + if !ok { + t.Fatalf("request 0 type mismatch, got %T", requests[0]) + } + if g, w := req0.Statements[0], query; g != w { + t.Errorf("request 0 statement mismatch\nGot: %s\nWant: %s", g, w) + } + }) + } +} + +func TestAutoDefaultSequenceKindAsyncModeSyncFailure(t *testing.T) { + t.Parallel() + + for _, dialect := range []databasepb.DatabaseDialect{ + databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, + databasepb.DatabaseDialect_POSTGRESQL, + } { + name := "GoogleSQL" + dsnParams := "default_sequence_kind=bit_reversed_positive;ddl_execution_mode=async" + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + name = "PostgreSQL" + dsnParams = "dialect=postgresql;default_sequence_kind=bit_reversed_positive;ddl_execution_mode=async" + } + t.Run(name, func(t *testing.T) { + db, server, teardown := setupTestDBConnectionWithParamsAndDialect(t, dsnParams, dialect) + defer teardown() + + anyResponse, _ := anypb.New(&emptypb.Empty{}) + opSuccess := &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Response{Response: anyResponse}, + Name: "op-success", + } + + // Mock synchronous DDL error on first call, followed by successful ALTER and then retry. + server.TestDatabaseAdmin.SetErrs([]error{ + gstatus.Error(codes.InvalidArgument, "Please specify the sequence kind explicitly or set the database option 'default_sequence_kind'."), + nil, // ALTER DATABASE succeeds + nil, // DDL retry succeeds + }) + server.TestDatabaseAdmin.SetResps([]proto.Message{ + opSuccess, // ALTER DATABASE op + opSuccess, // DDL retry op + }) + + query := "CREATE SEQUENCE my_seq" + _, err := db.ExecContext(context.Background(), query) + if err != nil { + t.Fatal(err) + } + + requests := server.TestDatabaseAdmin.Reqs() + if g, w := len(requests), 3; g != w { + t.Fatalf("requests count mismatch\nGot: %v\nWant: %v", g, w) + } + + // Verifies database was altered and retried + req0, ok := requests[0].(*databasepb.UpdateDatabaseDdlRequest) + if !ok { + t.Fatalf("request 0 type mismatch, got %T", requests[0]) + } + if g, w := req0.Statements[0], query; g != w { + t.Errorf("request 0 statement mismatch\nGot: %s\nWant: %s", g, w) + } + + req1, ok := requests[1].(*databasepb.UpdateDatabaseDdlRequest) + if !ok { + t.Fatalf("request 1 type mismatch, got %T", requests[1]) + } + var wantAlter string + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + wantAlter = `ALTER DATABASE "d" SET spanner.default_sequence_kind = 'bit_reversed_positive'` + } else { + wantAlter = "ALTER DATABASE `d` SET OPTIONS (default_sequence_kind = 'bit_reversed_positive')" + } + if g, w := req1.Statements[0], wantAlter; g != w { + t.Errorf("request 1 statement mismatch\nGot: %s\nWant: %s", g, w) + } + + req2, ok := requests[2].(*databasepb.UpdateDatabaseDdlRequest) + if !ok { + t.Fatalf("request 2 type mismatch, got %T", requests[2]) + } + if g, w := req2.Statements[0], query; g != w { + t.Errorf("request 2 statement mismatch\nGot: %s\nWant: %s", g, w) + } + }) + } +} + func TestAutoDefaultSequenceKind(t *testing.T) { t.Parallel() diff --git a/testutil/inmem_database_admin_server.go b/testutil/inmem_database_admin_server.go index 25d8e1ac..ed4e6008 100644 --- a/testutil/inmem_database_admin_server.go +++ b/testutil/inmem_database_admin_server.go @@ -39,6 +39,7 @@ type InMemDatabaseAdminServer interface { Reqs() []proto.Message SetReqs([]proto.Message) SetErr(error) + SetErrs([]error) AddDdlResponse(key string, result *longrunningpb.Operation) } @@ -50,6 +51,8 @@ type inMemDatabaseAdminServer struct { reqs []proto.Message // If set, all calls return this error err error + // If set, calls will pop and return these errors in sequence + errs []error // responses to return if err == nil resps []proto.Message @@ -133,6 +136,9 @@ func (s *inMemDatabaseAdminServer) UpdateDatabaseDdl(ctx context.Context, req *d return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) } s.reqs = append(s.reqs, req) + if err := s.popError(); err != nil { + return nil, err + } if s.err != nil { return nil, s.err } @@ -201,3 +207,20 @@ func (s *inMemDatabaseAdminServer) popOperation() *longrunningpb.Operation { } return op } + +func (s *inMemDatabaseAdminServer) SetErrs(errs []error) { + s.errs = errs +} + +func (s *inMemDatabaseAdminServer) popError() error { + if len(s.errs) == 0 { + return nil + } + err := s.errs[0] + if len(s.errs) > 1 { + s.errs = s.errs[1:] + } else { + s.errs = nil + } + return err +} From 0b2a05e620e28f7df9e8d55f6b45346a513aee50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 12 Jun 2026 17:40:55 +0200 Subject: [PATCH 4/4] fix: return correct statement count on error --- default_sequence_kind.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/default_sequence_kind.go b/default_sequence_kind.go index cf625c86..60b55719 100644 --- a/default_sequence_kind.go +++ b/default_sequence_kind.go @@ -87,9 +87,9 @@ func (c *conn) executeDDLWithDefaultSequenceKindRetry(ctx context.Context, origi Err: err, BatchUpdateCounts: []int64{}, } - successCount := restartIndex + successCount := getSuccessCount(op) if opRetry != nil { - successCount += getSuccessCount(opRetry) + successCount = restartIndex + getSuccessCount(opRetry) } for i := 0; i < successCount; i++ { be.BatchUpdateCounts = append(be.BatchUpdateCounts, int64(-1))