From 2209fb3b459610a2610540431444526f58cb0aba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 11 Jun 2026 17:53:05 +0200 Subject: [PATCH 1/3] feat: support SAVEPOINT, ROLLBACK TO SAVEPOINT, and RELEASE SAVEPOINT using emulated savepoints --- conn.go | 16 ++-- parser/statement_parser.go | 30 +++++-- parser/statement_parser_test.go | 42 +++++++++ parser/statements.go | 127 ++++++++++++++++++++++++++- savepoint_test.go | 151 ++++++++++++++++++++++++++++++++ statements.go | 69 +++++++++++++++ transaction.go | 113 +++++++++++++++++++++++- 7 files changed, 531 insertions(+), 17 deletions(-) create mode 100644 savepoint_test.go diff --git a/conn.go b/conn.go index d7d9aafb..40f56cf9 100644 --- a/conn.go +++ b/conn.go @@ -1647,7 +1647,7 @@ func (c *conn) activateTransaction() (contextTransaction, error) { // Add the current value of transaction_timeout to the context that is registered // on the transaction. ctx, cancel := c.addTransactionTimeout(c.tx.ctx) - tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(ctx, c.client, opts, func() spanner.TransactionOptions { + txOptsCallback := func() spanner.TransactionOptions { defer func() { // Reset the transaction_tag after starting the transaction. _ = propertyTransactionTag.ResetValue(c.state, connectionstate.ContextUser) @@ -1657,17 +1657,21 @@ func (c *conn) activateTransaction() (contextTransaction, error) { execOptions = &ExecOptions{} } return c.effectiveTransactionOptions(spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, execOptions) - }) + } + tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(ctx, c.client, opts, txOptsCallback) if err != nil { cancel() return nil, err } logger := c.logger.With("tx", "rw") return &readWriteTransaction{ - ctx: ctx, - conn: c, - logger: logger, - rwTx: tx, + ctx: ctx, + conn: c, + logger: logger, + rwTx: tx, + savepoints: make(map[string]savepoint), + txOptions: opts, + txOptionsCallback: txOptsCallback, close: func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) { c.prevTx = c.tx if commitErr == nil { diff --git a/parser/statement_parser.go b/parser/statement_parser.go index 7d79d71e..4cbc4b15 100644 --- a/parser/statement_parser.go +++ b/parser/statement_parser.go @@ -36,15 +36,17 @@ var updateStatements = map[string]bool{"UPDATE": true} var deleteStatements = map[string]bool{"DELETE": true} var dmlStatements = union(insertStatements, union(updateStatements, deleteStatements)) var clientSideKeywords = map[string]bool{ - "SHOW": true, - "SET": true, - "RESET": true, - "START": true, - "RUN": true, - "ABORT": true, - "BEGIN": true, - "COMMIT": true, - "ROLLBACK": true, + "SHOW": true, + "SET": true, + "RESET": true, + "START": true, + "RUN": true, + "ABORT": true, + "BEGIN": true, + "COMMIT": true, + "ROLLBACK": true, + "SAVEPOINT": true, + "RELEASE": true, } var showStatements = map[string]bool{"SHOW": true} var setStatements = map[string]bool{"SET": true} @@ -55,6 +57,8 @@ var abortStatements = map[string]bool{"ABORT": true} var beginStatements = map[string]bool{"BEGIN": true} var commitStatements = map[string]bool{"COMMIT": true} var rollbackStatements = map[string]bool{"ROLLBACK": true} +var savepointStatements = map[string]bool{"SAVEPOINT": true} +var releaseStatements = map[string]bool{"RELEASE": true} func union(m1 map[string]bool, m2 map[string]bool) map[string]bool { res := make(map[string]bool, len(m1)+len(m2)) @@ -761,6 +765,14 @@ func isRollbackStatementKeyword(keyword string) bool { return isStatementKeyword(keyword, rollbackStatements) } +func isSavepointStatementKeyword(keyword string) bool { + return isStatementKeyword(keyword, savepointStatements) +} + +func isReleaseStatementKeyword(keyword string) bool { + return isStatementKeyword(keyword, releaseStatements) +} + func isStatementKeyword(keyword string, keywords map[string]bool) bool { _, ok := keywords[keyword] return ok diff --git a/parser/statement_parser_test.go b/parser/statement_parser_test.go index b1573851..e4162f39 100644 --- a/parser/statement_parser_test.go +++ b/parser/statement_parser_test.go @@ -1745,6 +1745,48 @@ func TestParseClientSideStatement(t *testing.T) { want: "SET", exec: true, }, + { + name: "Savepoint", + input: "SAVEPOINT s1", + want: "SAVEPOINT", + exec: true, + }, + { + name: "Savepoint quoted", + input: "savepoint `s2`", + want: "SAVEPOINT", + exec: true, + }, + { + name: "Release savepoint", + input: "RELEASE SAVEPOINT s1", + want: "RELEASE SAVEPOINT", + exec: true, + }, + { + name: "Release savepoint no keyword", + input: "release s1", + want: "RELEASE SAVEPOINT", + exec: true, + }, + { + name: "Rollback to savepoint", + input: "ROLLBACK TO SAVEPOINT s1", + want: "ROLLBACK TO SAVEPOINT", + exec: true, + }, + { + name: "Rollback to savepoint no keyword", + input: "rollback to s1", + want: "ROLLBACK TO SAVEPOINT", + exec: true, + }, + { + name: "Rollback transaction to savepoint", + input: "rollback transaction to s1", + want: "ROLLBACK TO SAVEPOINT", + exec: true, + }, } parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000) diff --git a/parser/statements.go b/parser/statements.go index cfa1619f..07c75efe 100644 --- a/parser/statements.go +++ b/parser/statements.go @@ -55,7 +55,15 @@ func parseStatement(parser *StatementParser, keyword, query string) (ParsedState } else if isCommitStatementKeyword(keyword) { stmt = &ParsedCommitStatement{} } else if isRollbackStatementKeyword(keyword) { - stmt = &ParsedRollbackStatement{} + if isRollbackToSavepoint(parser, query) { + stmt = &ParsedRollbackToSavepointStatement{} + } else { + stmt = &ParsedRollbackStatement{} + } + } else if isSavepointStatementKeyword(keyword) { + stmt = &ParsedSavepointStatement{} + } else if isReleaseStatementKeyword(keyword) { + stmt = &ParsedReleaseSavepointStatement{} } else { return nil, nil } @@ -703,3 +711,120 @@ func (s *ParsedRollbackStatement) parse(parser *StatementParser, query string) e s.query = query return nil } + +func isRollbackToSavepoint(parser *StatementParser, query string) bool { + sp := &simpleParser{sql: []byte(query), statementParser: parser} + if !sp.eatKeyword("ROLLBACK") { + return false + } + if parser.Dialect == databasepb.DatabaseDialect_POSTGRESQL { + if sp.eatKeyword("TRANSACTION") || sp.eatKeyword("WORK") { + // ignore + } + } else { + _ = sp.eatKeyword("TRANSACTION") + } + return sp.eatKeyword("TO") +} + +type ParsedSavepointStatement struct { + query string + SavepointName string +} + +func (s *ParsedSavepointStatement) Name() string { + return "SAVEPOINT" +} + +func (s *ParsedSavepointStatement) Query() string { + return s.query +} + +func (s *ParsedSavepointStatement) parse(parser *StatementParser, query string) error { + sp := &simpleParser{sql: []byte(query), statementParser: parser} + if !sp.eatKeyword("SAVEPOINT") { + return status.Error(codes.InvalidArgument, "statement does not start with SAVEPOINT") + } + name, err := sp.eatIdentifier() + if err != nil { + return err + } + if sp.hasMoreTokens() { + return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql) + } + s.query = query + s.SavepointName = name.String() + return nil +} + +type ParsedRollbackToSavepointStatement struct { + query string + SavepointName string +} + +func (s *ParsedRollbackToSavepointStatement) Name() string { + return "ROLLBACK TO SAVEPOINT" +} + +func (s *ParsedRollbackToSavepointStatement) Query() string { + return s.query +} + +func (s *ParsedRollbackToSavepointStatement) parse(parser *StatementParser, query string) error { + sp := &simpleParser{sql: []byte(query), statementParser: parser} + if !sp.eatKeyword("ROLLBACK") { + return status.Error(codes.InvalidArgument, "statement does not start with ROLLBACK") + } + if parser.Dialect == databasepb.DatabaseDialect_POSTGRESQL { + if sp.eatKeyword("TRANSACTION") || sp.eatKeyword("WORK") { + // ignore + } + } else { + _ = sp.eatKeyword("TRANSACTION") + } + if !sp.eatKeyword("TO") { + return status.Error(codes.InvalidArgument, "missing TO keyword") + } + _ = sp.eatKeyword("SAVEPOINT") // optional + name, err := sp.eatIdentifier() + if err != nil { + return err + } + if sp.hasMoreTokens() { + return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql) + } + s.query = query + s.SavepointName = name.String() + return nil +} + +type ParsedReleaseSavepointStatement struct { + query string + SavepointName string +} + +func (s *ParsedReleaseSavepointStatement) Name() string { + return "RELEASE SAVEPOINT" +} + +func (s *ParsedReleaseSavepointStatement) Query() string { + return s.query +} + +func (s *ParsedReleaseSavepointStatement) parse(parser *StatementParser, query string) error { + sp := &simpleParser{sql: []byte(query), statementParser: parser} + if !sp.eatKeyword("RELEASE") { + return status.Error(codes.InvalidArgument, "statement does not start with RELEASE") + } + _ = sp.eatKeyword("SAVEPOINT") // optional + name, err := sp.eatIdentifier() + if err != nil { + return err + } + if sp.hasMoreTokens() { + return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql) + } + s.query = query + s.SavepointName = name.String() + return nil +} diff --git a/savepoint_test.go b/savepoint_test.go new file mode 100644 index 00000000..f848bd37 --- /dev/null +++ b/savepoint_test.go @@ -0,0 +1,151 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spannerdriver + +import ( + "context" + "database/sql" + "reflect" + "testing" + + "github.com/googleapis/go-sql-spanner/testutil" + spannerpb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestSavepoints(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + dml1 := "UPDATE Foo SET Val=1 WHERE Id=1" + dml2 := "UPDATE Foo SET Val=2 WHERE Id=2" + server.TestSpanner.PutStatementResult(dml1, &testutil.StatementResult{Type: testutil.StatementResultUpdateCount, UpdateCount: 1}) + server.TestSpanner.PutStatementResult(dml2, &testutil.StatementResult{Type: testutil.StatementResultUpdateCount, UpdateCount: 1}) + + tx, err := db.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + t.Fatal(err) + } + + if _, err := tx.ExecContext(ctx, dml1); err != nil { + t.Fatalf("dml1 failed: %v", err) + } + + if _, err := tx.ExecContext(ctx, "SAVEPOINT s1"); err != nil { + t.Fatalf("savepoint s1 failed: %v", err) + } + + if _, err := tx.ExecContext(ctx, dml2); err != nil { + t.Fatalf("dml2 failed: %v", err) + } + + if _, err := tx.ExecContext(ctx, "ROLLBACK TO SAVEPOINT s1"); err != nil { + t.Fatalf("rollback to s1 failed: %v", err) + } + + if err := tx.Commit(); err != nil { + t.Fatalf("commit failed: %v", err) + } + + requests := server.TestSpanner.DrainRequestsFromServer() + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if len(commitRequests) != 1 { + t.Fatalf("expected 1 commit request, got %d", len(commitRequests)) + } + + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 3; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + + req1 := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + req2 := executeRequests[1].(*spannerpb.ExecuteSqlRequest) + req3 := executeRequests[2].(*spannerpb.ExecuteSqlRequest) + + if req1.Sql != dml1 { + t.Errorf("expected req1 to be dml1, got %q", req1.Sql) + } + if req2.Sql != dml2 { + t.Errorf("expected req2 to be dml2, got %q", req2.Sql) + } + if req3.Sql != dml1 { + t.Errorf("expected req3 to be dml1 (replayed), got %q", req3.Sql) + } + + if req1.Transaction.GetBegin() == nil { + t.Error("expected req1 to begin a transaction") + } + if len(req2.Transaction.GetId()) == 0 || req2.Transaction.GetBegin() != nil { + t.Error("expected req2 to use an existing transaction ID") + } + if req3.Transaction.GetBegin() == nil { + t.Error("expected req3 (replayed) to begin a new transaction") + } +} + +func TestSavepointErrors(t *testing.T) { + t.Parallel() + + db, _, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + // 1. SAVEPOINT outside transaction should fail + if _, err := db.ExecContext(ctx, "SAVEPOINT s1"); err == nil { + t.Error("expected error for SAVEPOINT outside transaction") + } else if g, w := status.Code(err), codes.FailedPrecondition; g != w { + t.Errorf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + tx, err := db.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + t.Fatal(err) + } + + // 2. ROLLBACK TO non-existent savepoint should fail + if _, err := tx.ExecContext(ctx, "ROLLBACK TO SAVEPOINT s_none"); err == nil { + t.Error("expected error for rolling back to non-existent savepoint") + } else if g, w := status.Code(err), codes.FailedPrecondition; g != w { + t.Errorf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + // 3. RELEASE non-existent savepoint should fail + if _, err := tx.ExecContext(ctx, "RELEASE SAVEPOINT s_none"); err == nil { + t.Error("expected error for releasing non-existent savepoint") + } else if g, w := status.Code(err), codes.FailedPrecondition; g != w { + t.Errorf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + // 4. Create and release should work, then rollback to it should fail + if _, err := tx.ExecContext(ctx, "SAVEPOINT s1"); err != nil { + t.Fatal(err) + } + if _, err := tx.ExecContext(ctx, "RELEASE SAVEPOINT s1"); err != nil { + t.Fatal(err) + } + if _, err := tx.ExecContext(ctx, "ROLLBACK TO SAVEPOINT s1"); err == nil { + t.Error("expected error for rolling back to released savepoint") + } else if g, w := status.Code(err), codes.FailedPrecondition; g != w { + t.Errorf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } +} diff --git a/statements.go b/statements.go index 6fbf5ca1..14b0e3b7 100644 --- a/statements.go +++ b/statements.go @@ -59,6 +59,12 @@ func createExecutableStatement(stmt parser.ParsedStatement) (executableStatement return &executableCommitStatement{stmt: stmt}, nil case *parser.ParsedRollbackStatement: return &executableRollbackStatement{stmt: stmt}, nil + case *parser.ParsedSavepointStatement: + return &executableSavepointStatement{stmt: stmt}, nil + case *parser.ParsedRollbackToSavepointStatement: + return &executableRollbackToSavepointStatement{stmt: stmt}, nil + case *parser.ParsedReleaseSavepointStatement: + return &executableReleaseSavepointStatement{stmt: stmt}, nil } return nil, status.Errorf(codes.Internal, "unsupported statement type: %T", stmt) } @@ -370,3 +376,66 @@ func (s *executableRollbackStatement) queryContext(ctx context.Context, c *conn, } return createEmptyRows(opts), nil } + +type executableSavepointStatement struct { + stmt *parser.ParsedSavepointStatement +} + +func (s *executableSavepointStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions, args []driver.NamedValue) (driver.Result, error) { + if !c.inTransaction() { + return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "SAVEPOINT can only be used in a transaction")) + } + if err := c.tx.Savepoint(s.stmt.SavepointName); err != nil { + return nil, err + } + return driver.ResultNoRows, nil +} + +func (s *executableSavepointStatement) queryContext(ctx context.Context, c *conn, opts *ExecOptions, args []driver.NamedValue) (driver.Rows, error) { + if _, err := s.execContext(ctx, c, opts, args); err != nil { + return nil, err + } + return createEmptyRows(opts), nil +} + +type executableRollbackToSavepointStatement struct { + stmt *parser.ParsedRollbackToSavepointStatement +} + +func (s *executableRollbackToSavepointStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions, args []driver.NamedValue) (driver.Result, error) { + if !c.inTransaction() { + return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "ROLLBACK TO SAVEPOINT can only be used in a transaction")) + } + if err := c.tx.RollbackToSavepoint(ctx, s.stmt.SavepointName); err != nil { + return nil, err + } + return driver.ResultNoRows, nil +} + +func (s *executableRollbackToSavepointStatement) queryContext(ctx context.Context, c *conn, opts *ExecOptions, args []driver.NamedValue) (driver.Rows, error) { + if _, err := s.execContext(ctx, c, opts, args); err != nil { + return nil, err + } + return createEmptyRows(opts), nil +} + +type executableReleaseSavepointStatement struct { + stmt *parser.ParsedReleaseSavepointStatement +} + +func (s *executableReleaseSavepointStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions, args []driver.NamedValue) (driver.Result, error) { + if !c.inTransaction() { + return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "RELEASE SAVEPOINT can only be used in a transaction")) + } + if err := c.tx.ReleaseSavepoint(s.stmt.SavepointName); err != nil { + return nil, err + } + return driver.ResultNoRows, nil +} + +func (s *executableReleaseSavepointStatement) queryContext(ctx context.Context, c *conn, opts *ExecOptions, args []driver.NamedValue) (driver.Rows, error) { + if _, err := s.execContext(ctx, c, opts, args); err != nil { + return nil, err + } + return createEmptyRows(opts), nil +} diff --git a/transaction.go b/transaction.go index c01698f3..e14565b6 100644 --- a/transaction.go +++ b/transaction.go @@ -58,6 +58,10 @@ type contextTransaction interface { IsInBatch() bool BufferWrite(ms []*spanner.Mutation) error + + Savepoint(name string) error + RollbackToSavepoint(ctx context.Context, name string) error + ReleaseSavepoint(name string) error } type rowIterator interface { @@ -168,6 +172,27 @@ func (d *delegatingTransaction) Rollback() error { return d.contextTransaction.Rollback() } +func (d *delegatingTransaction) Savepoint(name string) error { + if err := d.ensureActivated(); err != nil { + return err + } + return d.contextTransaction.Savepoint(name) +} + +func (d *delegatingTransaction) RollbackToSavepoint(ctx context.Context, name string) error { + if err := d.ensureActivated(); err != nil { + return err + } + return d.contextTransaction.RollbackToSavepoint(ctx, name) +} + +func (d *delegatingTransaction) ReleaseSavepoint(name string) error { + if err := d.ensureActivated(); err != nil { + return err + } + return d.contextTransaction.ReleaseSavepoint(name) +} + func (d *delegatingTransaction) resetForRetry(ctx context.Context) error { if d.contextTransaction == nil { return status.Error(codes.FailedPrecondition, "a transaction can only be reset after it has been activated") @@ -371,6 +396,18 @@ func (tx *readOnlyTransaction) BufferWrite([]*spanner.Mutation) error { return spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "read-only transactions cannot write")) } +func (tx *readOnlyTransaction) Savepoint(name string) error { + return nil +} + +func (tx *readOnlyTransaction) RollbackToSavepoint(ctx context.Context, name string) error { + return nil +} + +func (tx *readOnlyTransaction) ReleaseSavepoint(name string) error { + return nil +} + // ErrAbortedDueToConcurrentModification is returned by a read/write transaction // that was aborted by Cloud Spanner, and where the internal retry attempt // failed because it detected that the results during the retry were different @@ -405,7 +442,9 @@ type readWriteTransaction struct { close func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) // retryAborts indicates whether this transaction will automatically retry // the transaction if it is aborted by Spanner. The default is true. - retryAborts func() bool + retryAborts func() bool + txOptions spanner.TransactionOptions + txOptionsCallback func() spanner.TransactionOptions // statements contains the list of statements that has been executed on this // transaction so far. These statements will be replayed on a new read write @@ -415,6 +454,15 @@ type readWriteTransaction struct { // mutations contains the buffered mutations of this transaction. These are // added to the next transaction if the transaction executes an internal retry. mutations []*spanner.Mutation + + // savepoints maps a savepoint name to the number of statements and mutations + // that were executed before the savepoint was created. + savepoints map[string]savepoint +} + +type savepoint struct { + statementCount int + mutationCount int } // retriableStatement is the interface that is used to keep track of statements @@ -810,6 +858,69 @@ func (tx *readWriteTransaction) BufferWrite(ms []*spanner.Mutation) error { return tx.rwTx.BufferWrite(ms) } +func (tx *readWriteTransaction) Savepoint(name string) error { + tx.logger.Debug("creating savepoint", "name", name) + if tx.savepoints == nil { + tx.savepoints = make(map[string]savepoint) + } + tx.savepoints[name] = savepoint{ + statementCount: len(tx.statements), + mutationCount: len(tx.mutations), + } + return nil +} + +func (tx *readWriteTransaction) RollbackToSavepoint(ctx context.Context, name string) error { + tx.logger.Debug("rolling back to savepoint", "name", name) + if tx.savepoints == nil { + return spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "savepoint %q does not exist", name)) + } + sp, ok := tx.savepoints[name] + if !ok { + return spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "savepoint %q does not exist", name)) + } + tx.statements = tx.statements[:sp.statementCount] + tx.mutations = tx.mutations[:sp.mutationCount] + + tx.rwTx.Rollback(context.Background()) + + return tx.recreateAndReplay(ctx) +} + +func (tx *readWriteTransaction) recreateAndReplay(ctx context.Context) (err error) { + tx.logger.Log(ctx, LevelNotice, "starting transaction retry for savepoint rollback") + tx.rwTx, err = spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(ctx, tx.conn.client, tx.txOptions, tx.txOptionsCallback) + if err != nil { + tx.logger.Log(ctx, LevelNotice, "failed to recreate transaction") + return err + } + if err := tx.rwTx.BufferWrite(tx.mutations); err != nil { + return err + } + for _, stmt := range tx.statements { + tx.logger.Log(ctx, slog.LevelDebug, "retrying statement", "stmt", stmt) + err = stmt.retry(ctx, tx.rwTx) + if err != nil { + tx.logger.Log(ctx, slog.LevelDebug, "retrying statement failed", "stmt", stmt) + return err + } + } + tx.logger.Log(ctx, LevelNotice, "finished transaction retry for savepoint rollback") + return nil +} + +func (tx *readWriteTransaction) ReleaseSavepoint(name string) error { + tx.logger.Debug("releasing savepoint", "name", name) + if tx.savepoints == nil { + return spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "savepoint %q does not exist", name)) + } + if _, ok := tx.savepoints[name]; !ok { + return spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "savepoint %q does not exist", name)) + } + delete(tx.savepoints, name) + return nil +} + // errorsEqualForRetry returns true if the two errors should be considered equal // when retrying a transaction. This comparison will return true if: // - The errors are the same instances From bfd3ce0efde441086817e2dd18efde2bb0a420f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 12 Jun 2026 09:21:50 +0200 Subject: [PATCH 2/3] fix(test): resolve staticcheck deprecation warnings for spannerpb --- savepoint_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/savepoint_test.go b/savepoint_test.go index f848bd37..b0f8344e 100644 --- a/savepoint_test.go +++ b/savepoint_test.go @@ -21,7 +21,7 @@ import ( "testing" "github.com/googleapis/go-sql-spanner/testutil" - spannerpb "google.golang.org/genproto/googleapis/spanner/v1" + spannerpb "cloud.google.com/go/spanner/apiv1/spannerpb" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) From 2a99555ddced13e900852bff84aa5933da459ed2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 12 Jun 2026 09:50:06 +0200 Subject: [PATCH 3/3] fix(test): format imports in savepoint_test.go --- savepoint_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/savepoint_test.go b/savepoint_test.go index b0f8344e..93fc9fc6 100644 --- a/savepoint_test.go +++ b/savepoint_test.go @@ -20,8 +20,8 @@ import ( "reflect" "testing" - "github.com/googleapis/go-sql-spanner/testutil" spannerpb "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" )