Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1647,7 +1647,7 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
// Add the current value of transaction_timeout to the context that is registered
// on the transaction.
ctx, cancel := c.addTransactionTimeout(c.tx.ctx)
tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(ctx, c.client, opts, func() spanner.TransactionOptions {
txOptsCallback := func() spanner.TransactionOptions {
defer func() {
// Reset the transaction_tag after starting the transaction.
_ = propertyTransactionTag.ResetValue(c.state, connectionstate.ContextUser)
Expand All @@ -1657,17 +1657,21 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
execOptions = &ExecOptions{}
}
return c.effectiveTransactionOptions(spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, execOptions)
})
}
tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(ctx, c.client, opts, txOptsCallback)
if err != nil {
cancel()
return nil, err
}
logger := c.logger.With("tx", "rw")
return &readWriteTransaction{
ctx: ctx,
conn: c,
logger: logger,
rwTx: tx,
ctx: ctx,
conn: c,
logger: logger,
rwTx: tx,
savepoints: make(map[string]savepoint),
txOptions: opts,
txOptionsCallback: txOptsCallback,
close: func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) {
c.prevTx = c.tx
if commitErr == nil {
Expand Down
30 changes: 21 additions & 9 deletions parser/statement_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,17 @@ var updateStatements = map[string]bool{"UPDATE": true}
var deleteStatements = map[string]bool{"DELETE": true}
var dmlStatements = union(insertStatements, union(updateStatements, deleteStatements))
var clientSideKeywords = map[string]bool{
"SHOW": true,
"SET": true,
"RESET": true,
"START": true,
"RUN": true,
"ABORT": true,
"BEGIN": true,
"COMMIT": true,
"ROLLBACK": true,
"SHOW": true,
"SET": true,
"RESET": true,
"START": true,
"RUN": true,
"ABORT": true,
"BEGIN": true,
"COMMIT": true,
"ROLLBACK": true,
"SAVEPOINT": true,
"RELEASE": true,
}
var showStatements = map[string]bool{"SHOW": true}
var setStatements = map[string]bool{"SET": true}
Expand All @@ -55,6 +57,8 @@ var abortStatements = map[string]bool{"ABORT": true}
var beginStatements = map[string]bool{"BEGIN": true}
var commitStatements = map[string]bool{"COMMIT": true}
var rollbackStatements = map[string]bool{"ROLLBACK": true}
var savepointStatements = map[string]bool{"SAVEPOINT": true}
var releaseStatements = map[string]bool{"RELEASE": true}

func union(m1 map[string]bool, m2 map[string]bool) map[string]bool {
res := make(map[string]bool, len(m1)+len(m2))
Expand Down Expand Up @@ -761,6 +765,14 @@ func isRollbackStatementKeyword(keyword string) bool {
return isStatementKeyword(keyword, rollbackStatements)
}

func isSavepointStatementKeyword(keyword string) bool {
return isStatementKeyword(keyword, savepointStatements)
}

func isReleaseStatementKeyword(keyword string) bool {
return isStatementKeyword(keyword, releaseStatements)
}

func isStatementKeyword(keyword string, keywords map[string]bool) bool {
_, ok := keywords[keyword]
return ok
Expand Down
42 changes: 42 additions & 0 deletions parser/statement_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1745,6 +1745,48 @@ func TestParseClientSideStatement(t *testing.T) {
want: "SET",
exec: true,
},
{
name: "Savepoint",
input: "SAVEPOINT s1",
want: "SAVEPOINT",
exec: true,
},
{
name: "Savepoint quoted",
input: "savepoint `s2`",
want: "SAVEPOINT",
exec: true,
},
{
name: "Release savepoint",
input: "RELEASE SAVEPOINT s1",
want: "RELEASE SAVEPOINT",
exec: true,
},
{
name: "Release savepoint no keyword",
input: "release s1",
want: "RELEASE SAVEPOINT",
exec: true,
},
{
name: "Rollback to savepoint",
input: "ROLLBACK TO SAVEPOINT s1",
want: "ROLLBACK TO SAVEPOINT",
exec: true,
},
{
name: "Rollback to savepoint no keyword",
input: "rollback to s1",
want: "ROLLBACK TO SAVEPOINT",
exec: true,
},
{
name: "Rollback transaction to savepoint",
input: "rollback transaction to s1",
want: "ROLLBACK TO SAVEPOINT",
exec: true,
},
}

parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
Expand Down
127 changes: 126 additions & 1 deletion parser/statements.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,15 @@ func parseStatement(parser *StatementParser, keyword, query string) (ParsedState
} else if isCommitStatementKeyword(keyword) {
stmt = &ParsedCommitStatement{}
} else if isRollbackStatementKeyword(keyword) {
stmt = &ParsedRollbackStatement{}
if isRollbackToSavepoint(parser, query) {
stmt = &ParsedRollbackToSavepointStatement{}
} else {
stmt = &ParsedRollbackStatement{}
}
} else if isSavepointStatementKeyword(keyword) {
stmt = &ParsedSavepointStatement{}
} else if isReleaseStatementKeyword(keyword) {
stmt = &ParsedReleaseSavepointStatement{}
} else {
return nil, nil
}
Expand Down Expand Up @@ -703,3 +711,120 @@ func (s *ParsedRollbackStatement) parse(parser *StatementParser, query string) e
s.query = query
return nil
}

func isRollbackToSavepoint(parser *StatementParser, query string) bool {
sp := &simpleParser{sql: []byte(query), statementParser: parser}
if !sp.eatKeyword("ROLLBACK") {
return false
}
if parser.Dialect == databasepb.DatabaseDialect_POSTGRESQL {
if sp.eatKeyword("TRANSACTION") || sp.eatKeyword("WORK") {
// ignore
}
} else {
_ = sp.eatKeyword("TRANSACTION")
}
return sp.eatKeyword("TO")
}

type ParsedSavepointStatement struct {
query string
SavepointName string
}

func (s *ParsedSavepointStatement) Name() string {
return "SAVEPOINT"
}

func (s *ParsedSavepointStatement) Query() string {
return s.query
}

func (s *ParsedSavepointStatement) parse(parser *StatementParser, query string) error {
sp := &simpleParser{sql: []byte(query), statementParser: parser}
if !sp.eatKeyword("SAVEPOINT") {
return status.Error(codes.InvalidArgument, "statement does not start with SAVEPOINT")
}
name, err := sp.eatIdentifier()
if err != nil {
return err
}
if sp.hasMoreTokens() {
return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql)
}
s.query = query
s.SavepointName = name.String()
return nil
}

type ParsedRollbackToSavepointStatement struct {
query string
SavepointName string
}

func (s *ParsedRollbackToSavepointStatement) Name() string {
return "ROLLBACK TO SAVEPOINT"
}

func (s *ParsedRollbackToSavepointStatement) Query() string {
return s.query
}

func (s *ParsedRollbackToSavepointStatement) parse(parser *StatementParser, query string) error {
sp := &simpleParser{sql: []byte(query), statementParser: parser}
if !sp.eatKeyword("ROLLBACK") {
return status.Error(codes.InvalidArgument, "statement does not start with ROLLBACK")
}
if parser.Dialect == databasepb.DatabaseDialect_POSTGRESQL {
if sp.eatKeyword("TRANSACTION") || sp.eatKeyword("WORK") {
// ignore
}
} else {
_ = sp.eatKeyword("TRANSACTION")
}
if !sp.eatKeyword("TO") {
return status.Error(codes.InvalidArgument, "missing TO keyword")
}
_ = sp.eatKeyword("SAVEPOINT") // optional
name, err := sp.eatIdentifier()
if err != nil {
return err
}
if sp.hasMoreTokens() {
return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql)
}
s.query = query
s.SavepointName = name.String()
return nil
}

type ParsedReleaseSavepointStatement struct {
query string
SavepointName string
}

func (s *ParsedReleaseSavepointStatement) Name() string {
return "RELEASE SAVEPOINT"
}

func (s *ParsedReleaseSavepointStatement) Query() string {
return s.query
}

func (s *ParsedReleaseSavepointStatement) parse(parser *StatementParser, query string) error {
sp := &simpleParser{sql: []byte(query), statementParser: parser}
if !sp.eatKeyword("RELEASE") {
return status.Error(codes.InvalidArgument, "statement does not start with RELEASE")
}
_ = sp.eatKeyword("SAVEPOINT") // optional
name, err := sp.eatIdentifier()
if err != nil {
return err
}
if sp.hasMoreTokens() {
return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql)
}
s.query = query
s.SavepointName = name.String()
return nil
}
Loading
Loading