From 3f947933545064420a957f640ff9a83571929a09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 11 Jun 2026 17:58:11 +0200 Subject: [PATCH 1/2] feat(spannerlib): add Cancel function to allow statement cancellation --- spannerlib/api/connection.go | 71 ++++++++++++++++++++++++++++++- spannerlib/api/connection_test.go | 46 ++++++++++++++++++++ spannerlib/api/rows.go | 5 +++ 3 files changed, 120 insertions(+), 2 deletions(-) diff --git a/spannerlib/api/connection.go b/spannerlib/api/connection.go index 43b56233..e86febf1 100644 --- a/spannerlib/api/connection.go +++ b/spannerlib/api/connection.go @@ -99,6 +99,17 @@ func Rollback(ctx context.Context, poolId, connId int64) error { return conn.rollback(ctx) } +// Cancel cancels the currently running statement on the given connection. +// This function is a no-op if there is no statement running. +func Cancel(poolId, connId int64) error { + conn, err := findConnection(poolId, connId) + if err != nil { + return err + } + conn.Cancel() + return nil +} + func Execute(ctx context.Context, poolId, connId int64, executeSqlRequest *spannerpb.ExecuteSqlRequest) (int64, error) { return ExecuteWithDirectExecuteContext(ctx, nil, poolId, connId, executeSqlRequest) } @@ -126,6 +137,9 @@ type Connection struct { // backend is the database/sql connection of this connection. backend *sql.Conn + + mu sync.Mutex + cancelActive context.CancelFunc } // spannerConn is an internal interface that contains the internal functions that are used by this API. @@ -312,14 +326,66 @@ func (conn *Connection) closeResults(ctx context.Context) { } func (conn *Connection) Execute(ctx, directExecuteContext context.Context, statement *spannerpb.ExecuteSqlRequest) (int64, error) { - return execute(ctx, directExecuteContext, conn, conn.backend, statement) + ctx, cancel := context.WithCancel(ctx) + if err := conn.setActiveCancel(cancel); err != nil { + cancel() + return 0, err + } + defer conn.clearActiveCancel() + + var returnedSuccess bool + defer func() { + if !returnedSuccess { + cancel() + } + }() + + id, err := execute(ctx, directExecuteContext, conn, conn.backend, statement, cancel) + if err != nil { + return 0, err + } + returnedSuccess = true + return id, nil } func (conn *Connection) ExecuteBatch(ctx context.Context, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + ctx, cancel := context.WithCancel(ctx) + if err := conn.setActiveCancel(cancel); err != nil { + cancel() + return nil, err + } + defer func() { + conn.clearActiveCancel() + cancel() + }() return executeBatch(ctx, conn, conn.backend, statements) } -func execute(ctx, directExecuteContext context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest) (int64, error) { +func (conn *Connection) setActiveCancel(cancel context.CancelFunc) error { + conn.mu.Lock() + defer conn.mu.Unlock() + if conn.cancelActive != nil { + return status.Error(codes.FailedPrecondition, "connection is already executing a statement") + } + conn.cancelActive = cancel + return nil +} + +func (conn *Connection) clearActiveCancel() { + conn.mu.Lock() + defer conn.mu.Unlock() + conn.cancelActive = nil +} + +func (conn *Connection) Cancel() { + conn.mu.Lock() + defer conn.mu.Unlock() + if conn.cancelActive != nil { + conn.cancelActive() + } +} + +func execute(ctx, directExecuteContext context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest, cancel context.CancelFunc) (int64, error) { params := extractParams(directExecuteContext, statement) it, err := executor.QueryContext(ctx, statement.Sql, params...) if err != nil { @@ -336,6 +402,7 @@ func execute(ctx, directExecuteContext context.Context, conn *Connection, execut // No rows returned. Read the stats now. _ = res.readStats(ctx) } + res.cancel = cancel conn.results.Store(id, res) return id, nil } diff --git a/spannerlib/api/connection_test.go b/spannerlib/api/connection_test.go index 29e20bd4..c3f4575f 100644 --- a/spannerlib/api/connection_test.go +++ b/spannerlib/api/connection_test.go @@ -19,6 +19,7 @@ import ( "fmt" "reflect" "testing" + "time" "cloud.google.com/go/longrunning/autogen/longrunningpb" "cloud.google.com/go/spanner" @@ -498,3 +499,48 @@ func TestCreateDatabase(t *testing.T) { t.Fatalf("ClosePool returned unexpected error: %v", err) } } + +func TestCancelStatement(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{MinimumExecutionTime: 200 * time.Millisecond}) + + poolId, err := CreatePool(ctx, "test", dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + defer ClosePool(ctx, poolId) + + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + defer CloseConnection(ctx, poolId, connId) + + errChan := make(chan error, 1) + go func() { + req := &spannerpb.ExecuteSqlRequest{ + Sql: "SELECT * FROM Singers", + } + _, err := Execute(ctx, poolId, connId, req) + errChan <- err + }() + + time.Sleep(50 * time.Millisecond) + if err := Cancel(poolId, connId); err != nil { + t.Fatalf("Cancel returned unexpected error: %v", err) + } + + err = <-errChan + if err == nil { + t.Fatal("expected statement to be cancelled, but got no error") + } + if g, w := spanner.ErrCode(err), codes.Canceled; g != w { + t.Errorf("error code mismatch\n Got: %v\nWant: %v (error: %v)", g, w, err) + } +} diff --git a/spannerlib/api/rows.go b/spannerlib/api/rows.go index 1935836e..1a95fd0b 100644 --- a/spannerlib/api/rows.go +++ b/spannerlib/api/rows.go @@ -154,9 +154,14 @@ type rows struct { buffer []any values *structpb.ListValue marshalBuffer []byte + + cancel context.CancelFunc } func (rows *rows) Close(ctx context.Context) error { + if rows.cancel != nil { + rows.cancel() + } err := rows.backend.Close() if err != nil { return err From 8497ca7022d44670f6b30b68b3e85e52eae15f47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 12 Jun 2026 09:26:25 +0200 Subject: [PATCH 2/2] fix(test): allow Code.Cancelled in TestCloseConnectionWithOpenRows for dotnet wrapper --- .../spannerlib-dotnet/spannerlib-dotnet-tests/RowsTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-tests/RowsTests.cs b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-tests/RowsTests.cs index d7a33685..26a18201 100644 --- a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-tests/RowsTests.cs +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-tests/RowsTests.cs @@ -316,7 +316,7 @@ public void TestCloseConnectionWithOpenRows([Values] LibType libType) }); // The error is 'Connection not found' or an internal exception from the underlying driver, depending on exactly // when the driver detects that the connection and all related objects have been closed. - Assert.That(exception.Code is Code.NotFound or Code.Unknown, Is.True); + Assert.That(exception.Code is Code.NotFound or Code.Unknown or Code.Cancelled, Is.True); Assert.That(foundRows, Is.LessThan(numRows)); }