Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions spannerlib/api/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,17 @@ func Rollback(ctx context.Context, poolId, connId int64) error {
return conn.rollback(ctx)
}

// Cancel cancels the currently running statement on the given connection.
// This function is a no-op if there is no statement running.
func Cancel(poolId, connId int64) error {
conn, err := findConnection(poolId, connId)
if err != nil {
return err
}
conn.Cancel()
return nil
}

func Execute(ctx context.Context, poolId, connId int64, executeSqlRequest *spannerpb.ExecuteSqlRequest) (int64, error) {
return ExecuteWithDirectExecuteContext(ctx, nil, poolId, connId, executeSqlRequest)
}
Expand Down Expand Up @@ -126,6 +137,9 @@ type Connection struct {

// backend is the database/sql connection of this connection.
backend *sql.Conn

mu sync.Mutex
cancelActive context.CancelFunc
}

// spannerConn is an internal interface that contains the internal functions that are used by this API.
Expand Down Expand Up @@ -312,14 +326,66 @@ func (conn *Connection) closeResults(ctx context.Context) {
}

func (conn *Connection) Execute(ctx, directExecuteContext context.Context, statement *spannerpb.ExecuteSqlRequest) (int64, error) {
return execute(ctx, directExecuteContext, conn, conn.backend, statement)
ctx, cancel := context.WithCancel(ctx)
if err := conn.setActiveCancel(cancel); err != nil {
cancel()
return 0, err
}
defer conn.clearActiveCancel()

var returnedSuccess bool
defer func() {
if !returnedSuccess {
cancel()
}
}()

id, err := execute(ctx, directExecuteContext, conn, conn.backend, statement, cancel)
if err != nil {
return 0, err
}
returnedSuccess = true
return id, nil
}

func (conn *Connection) ExecuteBatch(ctx context.Context, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) {
ctx, cancel := context.WithCancel(ctx)
if err := conn.setActiveCancel(cancel); err != nil {
cancel()
return nil, err
}
defer func() {
conn.clearActiveCancel()
cancel()
}()
return executeBatch(ctx, conn, conn.backend, statements)
}

func execute(ctx, directExecuteContext context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest) (int64, error) {
func (conn *Connection) setActiveCancel(cancel context.CancelFunc) error {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.cancelActive != nil {
return status.Error(codes.FailedPrecondition, "connection is already executing a statement")
}
conn.cancelActive = cancel
return nil
}

func (conn *Connection) clearActiveCancel() {
conn.mu.Lock()
defer conn.mu.Unlock()
conn.cancelActive = nil
}

func (conn *Connection) Cancel() {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.cancelActive != nil {
conn.cancelActive()
}
}
Comment on lines +380 to +386

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Currently, Cancel() only cancels conn.cancelActive, which is cleared via defer conn.clearActiveCancel() as soon as Execute() returns. This means that once the query execution completes and row streaming begins, calling Cancel() will be a no-op and will not cancel the active query/rows context.

To ensure that Cancel() can also cancel active queries during row streaming, we should also iterate over and cancel any active results in conn.results.

func (conn *Connection) Cancel() {
	conn.mu.Lock()
	defer conn.mu.Unlock()
	if conn.cancelActive != nil {
		conn.cancelActive()
	}
	if conn.results != nil {
		conn.results.Range(func(key, value interface{}) bool {
			if r, ok := value.(*rows); ok {
				if r.cancel != nil {
					r.cancel()
				}
			}
			return true
		})
	}
}


func execute(ctx, directExecuteContext context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest, cancel context.CancelFunc) (int64, error) {
params := extractParams(directExecuteContext, statement)
it, err := executor.QueryContext(ctx, statement.Sql, params...)
if err != nil {
Expand All @@ -336,6 +402,7 @@ func execute(ctx, directExecuteContext context.Context, conn *Connection, execut
// No rows returned. Read the stats now.
_ = res.readStats(ctx)
}
res.cancel = cancel
conn.results.Store(id, res)
return id, nil
}
Expand Down
46 changes: 46 additions & 0 deletions spannerlib/api/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"reflect"
"testing"
"time"

"cloud.google.com/go/longrunning/autogen/longrunningpb"
"cloud.google.com/go/spanner"
Expand Down Expand Up @@ -498,3 +499,48 @@ func TestCreateDatabase(t *testing.T) {
t.Fatalf("ClosePool returned unexpected error: %v", err)
}
}

func TestCancelStatement(t *testing.T) {
t.Parallel()

ctx := context.Background()
server, teardown := setupMockServer(t)
defer teardown()
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)

server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{MinimumExecutionTime: 200 * time.Millisecond})

poolId, err := CreatePool(ctx, "test", dsn)
if err != nil {
t.Fatalf("CreatePool returned unexpected error: %v", err)
}
defer ClosePool(ctx, poolId)

connId, err := CreateConnection(ctx, poolId)
if err != nil {
t.Fatalf("CreateConnection returned unexpected error: %v", err)
}
defer CloseConnection(ctx, poolId, connId)

errChan := make(chan error, 1)
go func() {
req := &spannerpb.ExecuteSqlRequest{
Sql: "SELECT * FROM Singers",
}
_, err := Execute(ctx, poolId, connId, req)
errChan <- err
}()

time.Sleep(50 * time.Millisecond)
if err := Cancel(poolId, connId); err != nil {
t.Fatalf("Cancel returned unexpected error: %v", err)
}

err = <-errChan
if err == nil {
t.Fatal("expected statement to be cancelled, but got no error")
}
if g, w := spanner.ErrCode(err), codes.Canceled; g != w {
t.Errorf("error code mismatch\n Got: %v\nWant: %v (error: %v)", g, w, err)
}
}
5 changes: 5 additions & 0 deletions spannerlib/api/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,14 @@ type rows struct {
buffer []any
values *structpb.ListValue
marshalBuffer []byte

cancel context.CancelFunc
}

func (rows *rows) Close(ctx context.Context) error {
if rows.cancel != nil {
rows.cancel()
}
err := rows.backend.Close()
if err != nil {
return err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ public void TestCloseConnectionWithOpenRows([Values] LibType libType)
});
// The error is 'Connection not found' or an internal exception from the underlying driver, depending on exactly
// when the driver detects that the connection and all related objects have been closed.
Assert.That(exception.Code is Code.NotFound or Code.Unknown, Is.True);
Assert.That(exception.Code is Code.NotFound or Code.Unknown or Code.Cancelled, Is.True);
Assert.That(foundRows, Is.LessThan(numRows));
}

Expand Down
Loading