diff --git a/driver.go b/driver.go index 8d102f36..a3fc5dad 100644 --- a/driver.go +++ b/driver.go @@ -1158,6 +1158,9 @@ func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOpti if err == nil { err = tx.Commit() if err == nil { + if !isSpannerConn || (opts != nil && opts.ReadOnly) { + return nil, nil + } resp, err := getCommitResponse(conn) if err != nil { return nil, err diff --git a/driver_non_spanner_test.go b/driver_non_spanner_test.go new file mode 100644 index 00000000..6cffe3fe --- /dev/null +++ b/driver_non_spanner_test.go @@ -0,0 +1,70 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spannerdriver + +import ( + "context" + "database/sql" + "database/sql/driver" + "testing" + + "cloud.google.com/go/spanner" +) + +func init() { + sql.Register("fake", &fakeDriver{}) +} + +type fakeDriver struct{} + +func (d *fakeDriver) Open(name string) (driver.Conn, error) { return &fakeDriverConn{}, nil } + +type fakeDriverConn struct{} + +func (c *fakeDriverConn) Prepare(query string) (driver.Stmt, error) { return &fakeStmt{}, nil } +func (c *fakeDriverConn) Close() error { return nil } +func (c *fakeDriverConn) Begin() (driver.Tx, error) { return &fakeTx{}, nil } + +type fakeStmt struct{} + +func (s *fakeStmt) Close() error { return nil } +func (s *fakeStmt) NumInput() int { return 0 } +func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { return nil, nil } +func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { return nil, nil } + +type fakeTx struct{} + +func (t *fakeTx) Commit() error { return nil } +func (t *fakeTx) Rollback() error { return nil } + +func TestRunTransaction_NonSpannerConnection(t *testing.T) { + db, err := sql.Open("fake", "any-dsn") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ctx := context.Background() + // RunTransaction should succeed and return nil, nil commit response for non-Spanner connections + resp, err := RunTransactionWithCommitResponse(ctx, db, nil, func(ctx context.Context, tx *sql.Tx) error { + return nil + }, spanner.TransactionOptions{}) + if err != nil { + t.Fatalf("RunTransactionWithCommitResponse failed: %v", err) + } + if resp != nil { + t.Errorf("expected nil commit response, got: %v", resp) + } +} diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 6d90afa6..b4f466d6 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -4765,6 +4765,51 @@ func TestRunTransaction(t *testing.T) { } } +func TestRunTransactionReadOnly(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + err := RunTransaction(ctx, db, &sql.TxOptions{ReadOnly: true}, func(ctx context.Context, tx *sql.Tx) error { + rows, err := tx.Query(testutil.SelectFooFromBar) + if err != nil { + return err + } + defer silentClose(rows) + for want := int64(1); rows.Next(); want++ { + var got int64 + if err := rows.Scan(&got); err != nil { + return err + } + if got != want { + return fmt.Errorf("value mismatch\nGot: %v\nWant: %v", got, want) + } + } + return rows.Err() + }) + if err != nil { + t.Fatal(err) + } + + requests := server.TestSpanner.DrainRequestsFromServer() + sqlRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(sqlRequests), 1; g != w { + t.Fatalf("ExecuteSqlRequests count mismatch\nGot: %v\nWant: %v", g, w) + } + req := sqlRequests[0].(*sppb.ExecuteSqlRequest) + if req.Transaction == nil { + t.Fatalf("missing transaction for ExecuteSqlRequest") + } + + // Verify that NO commit request was sent + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w) + } +} + func TestRunTransactionCommitAborted(t *testing.T) { t.Parallel()