Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -639,9 +639,10 @@ func (c *conn) runDMLBatch(ctx context.Context) (SpannerResult, error) {

statements := c.batch.statements
options := c.batch.options
options.QueryOptions.LastStatement = true
localOptions := *options
localOptions.QueryOptions.LastStatement = true
c.batch = nil
return c.execBatchDML(ctx, statements, options)
return c.execBatchDML(ctx, statements, &localOptions)
}

func (c *conn) abortBatch() (driver.Result, error) {
Expand Down Expand Up @@ -1785,13 +1786,14 @@ func (c *conn) executeAutoPartitionedQuery(ctx context.Context, cancel context.C

func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (rowIterator, *spanner.CommitResponse, error) {
var result *wrappedRowIterator
options.QueryOptions.LastStatement = true
queryOptions := options.QueryOptions
queryOptions.LastStatement = true
fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
if result != nil {
// in case of a retry
result.Stop()
}
it := tx.QueryWithOptions(ctx, statement, options.QueryOptions)
it := tx.QueryWithOptions(ctx, statement, queryOptions)
row, err := it.Next()
if err == iterator.Done {
result = &wrappedRowIterator{
Expand Down Expand Up @@ -1825,10 +1827,11 @@ var errInvalidDmlForExecContext = spanner.ToSpannerError(status.Error(codes.Fail

func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) {
var res *result
options.QueryOptions.LastStatement = true
queryOptions := options.QueryOptions
queryOptions.LastStatement = true
fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
var err error
res, err = execTransactionalDML(ctx, tx, statement, statementInfo, options.QueryOptions)
res, err = execTransactionalDML(ctx, tx, statement, statementInfo, queryOptions)
if err != nil {
return err
}
Expand Down
80 changes: 80 additions & 0 deletions stmt_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -943,3 +943,83 @@ func executeParamTest(t *testing.T, test paramTest, server *testutil.MockedSpann
}
}
}

func TestPrepareStmtGORMSharedOptionsBug(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnection(t)
defer teardown()
ctx := context.Background()

// Mock the update query.
query := "update test set value = ? where id = ?"
if err := server.TestSpanner.PutStatementResult(
"update test set value = @p1 where id = @p2",
&testutil.StatementResult{
Type: testutil.StatementResultUpdateCount,
UpdateCount: 1,
},
); err != nil {
t.Fatal(err)
}

// 1. Prepare statement.
stmt, err := db.PrepareContext(ctx, query)
if err != nil {
t.Fatalf("failed to prepare query: %v", err)
}
defer stmt.Close()

// 2. Execute statement outside transaction (auto-commit).
_, err = stmt.ExecContext(ctx, "val1", int64(1))
if err != nil {
t.Fatalf("failed auto-commit execution: %v", err)
}

// 3. Start explicit transaction.
tx, err := db.BeginTx(ctx, nil)
if err != nil {
t.Fatalf("failed to begin transaction: %v", err)
}

// 4. Execute the cached statement inside transaction using tx.Stmt().
txStmt := tx.Stmt(stmt)
_, err = txStmt.ExecContext(ctx, "val2", int64(1))
if err != nil {
_ = tx.Rollback()
t.Fatalf("failed in-transaction statement execution: %v", err)
}

// 5. Execute another statement in the same transaction.
_, err = tx.ExecContext(ctx, "update test set value = ? where id = ?", "val3", int64(2))
if err != nil {
_ = tx.Rollback()
t.Fatalf("failed second in-transaction execution: %v", err)
}

// 6. Commit transaction.
if err := tx.Commit(); err != nil {
t.Fatalf("failed to commit: %v", err)
}

// 7. Verify mock server requests.
requests := server.TestSpanner.DrainRequestsFromServer()
executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
if g, w := len(executeRequests), 3; g != w {
t.Fatalf("number of execute requests mismatch\n Got: %v\nWant: %v", g, w)
}

req1 := executeRequests[0].(*spannerpb.ExecuteSqlRequest) // auto-commit outside transaction
req2 := executeRequests[1].(*spannerpb.ExecuteSqlRequest) // first DML in transaction (prepared statement)
req3 := executeRequests[2].(*spannerpb.ExecuteSqlRequest) // second DML in transaction

if !req1.LastStatement {
t.Error("Expected LastStatement=true for auto-commit execution, got false")
}
if req2.LastStatement {
t.Error("Expected LastStatement=false for first in-transaction statement, got true")
}
if req3.LastStatement {
t.Error("Expected LastStatement=false for second in-transaction statement, got true")
}
}
Loading