diff --git a/conn.go b/conn.go index d7d9aafb..d1c4fa75 100644 --- a/conn.go +++ b/conn.go @@ -639,9 +639,10 @@ func (c *conn) runDMLBatch(ctx context.Context) (SpannerResult, error) { statements := c.batch.statements options := c.batch.options - options.QueryOptions.LastStatement = true + localOptions := *options + localOptions.QueryOptions.LastStatement = true c.batch = nil - return c.execBatchDML(ctx, statements, options) + return c.execBatchDML(ctx, statements, &localOptions) } func (c *conn) abortBatch() (driver.Result, error) { @@ -1785,13 +1786,14 @@ func (c *conn) executeAutoPartitionedQuery(ctx context.Context, cancel context.C func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (rowIterator, *spanner.CommitResponse, error) { var result *wrappedRowIterator - options.QueryOptions.LastStatement = true + queryOptions := options.QueryOptions + queryOptions.LastStatement = true fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { if result != nil { // in case of a retry result.Stop() } - it := tx.QueryWithOptions(ctx, statement, options.QueryOptions) + it := tx.QueryWithOptions(ctx, statement, queryOptions) row, err := it.Next() if err == iterator.Done { result = &wrappedRowIterator{ @@ -1825,10 +1827,11 @@ var errInvalidDmlForExecContext = spanner.ToSpannerError(status.Error(codes.Fail func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) { var res *result - options.QueryOptions.LastStatement = true + queryOptions := options.QueryOptions + queryOptions.LastStatement = true fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { var err error - res, err = execTransactionalDML(ctx, tx, statement, statementInfo, options.QueryOptions) + res, err = execTransactionalDML(ctx, tx, statement, statementInfo, queryOptions) if err != nil { return err } diff --git a/stmt_with_mockserver_test.go b/stmt_with_mockserver_test.go index d9ade319..b1bc77d6 100644 --- a/stmt_with_mockserver_test.go +++ b/stmt_with_mockserver_test.go @@ -943,3 +943,83 @@ func executeParamTest(t *testing.T, test paramTest, server *testutil.MockedSpann } } } + +func TestPrepareStmtGORMSharedOptionsBug(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + // Mock the update query. + query := "update test set value = ? where id = ?" + if err := server.TestSpanner.PutStatementResult( + "update test set value = @p1 where id = @p2", + &testutil.StatementResult{ + Type: testutil.StatementResultUpdateCount, + UpdateCount: 1, + }, + ); err != nil { + t.Fatal(err) + } + + // 1. Prepare statement. + stmt, err := db.PrepareContext(ctx, query) + if err != nil { + t.Fatalf("failed to prepare query: %v", err) + } + defer stmt.Close() + + // 2. Execute statement outside transaction (auto-commit). + _, err = stmt.ExecContext(ctx, "val1", int64(1)) + if err != nil { + t.Fatalf("failed auto-commit execution: %v", err) + } + + // 3. Start explicit transaction. + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("failed to begin transaction: %v", err) + } + + // 4. Execute the cached statement inside transaction using tx.Stmt(). + txStmt := tx.Stmt(stmt) + _, err = txStmt.ExecContext(ctx, "val2", int64(1)) + if err != nil { + _ = tx.Rollback() + t.Fatalf("failed in-transaction statement execution: %v", err) + } + + // 5. Execute another statement in the same transaction. + _, err = tx.ExecContext(ctx, "update test set value = ? where id = ?", "val3", int64(2)) + if err != nil { + _ = tx.Rollback() + t.Fatalf("failed second in-transaction execution: %v", err) + } + + // 6. Commit transaction. + if err := tx.Commit(); err != nil { + t.Fatalf("failed to commit: %v", err) + } + + // 7. Verify mock server requests. + requests := server.TestSpanner.DrainRequestsFromServer() + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 3; g != w { + t.Fatalf("number of execute requests mismatch\n Got: %v\nWant: %v", g, w) + } + + req1 := executeRequests[0].(*spannerpb.ExecuteSqlRequest) // auto-commit outside transaction + req2 := executeRequests[1].(*spannerpb.ExecuteSqlRequest) // first DML in transaction (prepared statement) + req3 := executeRequests[2].(*spannerpb.ExecuteSqlRequest) // second DML in transaction + + if !req1.LastStatement { + t.Error("Expected LastStatement=true for auto-commit execution, got false") + } + if req2.LastStatement { + t.Error("Expected LastStatement=false for first in-transaction statement, got true") + } + if req3.LastStatement { + t.Error("Expected LastStatement=false for second in-transaction statement, got true") + } +}