From 9dfec4aa235b61fb7ea213c813aa8efb44e74b58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 11 Jun 2026 14:32:10 +0200 Subject: [PATCH 1/3] feat(parser): support SHOW TRANSACTION variables and transaction_isolation alias --- client_side_statement_test.go | 137 ++++++++++++++++++++++++++++++++++ conn.go | 16 +++- parser/statements.go | 41 +++++++++- parser/statements_test.go | 111 +++++++++++++++++++++------ statements.go | 8 ++ 5 files changed, 288 insertions(+), 25 deletions(-) diff --git a/client_side_statement_test.go b/client_side_statement_test.go index 2c42f062..5762091e 100644 --- a/client_side_statement_test.go +++ b/client_side_statement_test.go @@ -688,5 +688,142 @@ func TestStatementExecutor_UsesExecOptions(t *testing.T) { if rows.HasNextResultSet() { t.Fatal("got unexpected next result set") } +} + +func TestStatementExecutor_ShowTransaction(t *testing.T) { + t.Parallel() + + p, _ := parser.NewStatementParser(databasepb.DatabaseDialect_POSTGRESQL, 1000) + c := &conn{ + logger: noopLogger, + state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}), + parser: p, + } + ctx := context.Background() + + // Initial checks. + // 1. SHOW TRANSACTION ISOLATION LEVEL should show the default (serializable) + rows, err := c.QueryContext(ctx, "show transaction isolation level", []driver.NamedValue{}) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + values := make([]driver.Value, 1) + if err := rows.Next(values); err != nil { + t.Fatal(err) + } + if got, want := fmt.Sprintf("%v", values[0]), "serializable"; got != want { + t.Errorf("isolation level got %q, want %q", got, want) + } + rows.Close() + + // 2. SHOW transaction_isolation (alias) should show same + rows, err = c.QueryContext(ctx, "show transaction_isolation", []driver.NamedValue{}) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + if err := rows.Next(values); err != nil { + t.Fatal(err) + } + if got, want := fmt.Sprintf("%v", values[0]), "serializable"; got != want { + t.Errorf("transaction_isolation got %q, want %q", got, want) + } + rows.Close() + + // 3. SHOW TRANSACTION READ ONLY should show default (false) + rows, err = c.QueryContext(ctx, "show transaction read only", []driver.NamedValue{}) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + if err := rows.Next(values); err != nil { + t.Fatal(err) + } + if got, want := fmt.Sprintf("%v", values[0]), "false"; got != want { + t.Errorf("read only got %q, want %q", got, want) + } + rows.Close() + + // 4. SHOW TRANSACTION DEFERRABLE should show default (false) + rows, err = c.QueryContext(ctx, "show transaction deferrable", []driver.NamedValue{}) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + if err := rows.Next(values); err != nil { + t.Fatal(err) + } + if got, want := fmt.Sprintf("%v", values[0]), "false"; got != want { + t.Errorf("deferrable got %q, want %q", got, want) + } + rows.Close() + + // Now modify the properties using SET, and verify they show new values! + if _, err := c.ExecContext(ctx, "set isolation_level = 'repeatable_read'", []driver.NamedValue{}); err != nil { + t.Fatal(err) + } + if _, err := c.ExecContext(ctx, "set transaction_read_only = true", []driver.NamedValue{}); err != nil { + t.Fatal(err) + } + if _, err := c.ExecContext(ctx, "set transaction_deferrable = true", []driver.NamedValue{}); err != nil { + t.Fatal(err) + } + + // Verify new values + // 1. SHOW TRANSACTION ISOLATION LEVEL + rows, err = c.QueryContext(ctx, "show transaction isolation level", []driver.NamedValue{}) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + if err := rows.Next(values); err != nil { + t.Fatal(err) + } + if got, want := fmt.Sprintf("%v", values[0]), "repeatable read"; got != want { + t.Errorf("isolation level got %q, want %q", got, want) + } + rows.Close() + + // 2. SHOW transaction_isolation + rows, err = c.QueryContext(ctx, "show transaction_isolation", []driver.NamedValue{}) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + if err := rows.Next(values); err != nil { + t.Fatal(err) + } + if got, want := fmt.Sprintf("%v", values[0]), "repeatable read"; got != want { + t.Errorf("transaction_isolation got %q, want %q", got, want) + } + rows.Close() + + // 3. SHOW TRANSACTION READ ONLY + rows, err = c.QueryContext(ctx, "show transaction read only", []driver.NamedValue{}) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + if err := rows.Next(values); err != nil { + t.Fatal(err) + } + if got, want := fmt.Sprintf("%v", values[0]), "true"; got != want { + t.Errorf("read only got %q, want %q", got, want) + } + rows.Close() + // 4. SHOW TRANSACTION DEFERRABLE + rows, err = c.QueryContext(ctx, "show transaction deferrable", []driver.NamedValue{}) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + if err := rows.Next(values); err != nil { + t.Fatal(err) + } + if got, want := fmt.Sprintf("%v", values[0]), "true"; got != want { + t.Errorf("deferrable got %q, want %q", got, want) + } + rows.Close() } diff --git a/conn.go b/conn.go index d7d9aafb..2f676128 100644 --- a/conn.go +++ b/conn.go @@ -333,7 +333,21 @@ func (c *conn) showConnectionVariable(identifier parser.Identifier) (any, bool, if err != nil { return nil, false, err } - return c.state.GetValue(extension, name) + if extension == "" && name == "transaction_isolation" { + name = "isolation_level" + } + val, hasValue, err := c.state.GetValue(extension, name) + if err != nil { + return nil, false, err + } + if name == "isolation_level" { + if lvl, ok := val.(sql.IsolationLevel); ok && lvl == sql.LevelDefault { + if c.parser.Dialect == adminpb.DatabaseDialect_POSTGRESQL { + val = sql.LevelSerializable + } + } + } + return val, hasValue, nil } func (c *conn) setConnectionVariable(identifier parser.Identifier, value string, local bool, transaction bool, statementScoped bool) error { diff --git a/parser/statements.go b/parser/statements.go index cfa1619f..8bc6ef06 100644 --- a/parser/statements.go +++ b/parser/statements.go @@ -150,11 +150,14 @@ func (s *ParsedShowStatement) Query() string { func (s *ParsedShowStatement) parse(parser *StatementParser, query string) error { // Parse a statement of the form - // SHOW [VARIABLE] [my_extension.]my_property + // SHOW [VARIABLE | TRANSACTION] [my_extension.]my_property sp := &simpleParser{sql: []byte(query), statementParser: parser} if !sp.eatKeyword("SHOW") { return status.Error(codes.InvalidArgument, "statement does not start with SHOW") } + if sp.eatKeyword("TRANSACTION") { + return s.parseShowTransaction(sp, query) + } if parser.Dialect == databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL { // Just eat and ignore the keyword VARIABLE. if !sp.eatKeyword("VARIABLE") { @@ -173,6 +176,42 @@ func (s *ParsedShowStatement) parse(parser *StatementParser, query string) error return nil } +func (s *ParsedShowStatement) parseShowTransaction(sp *simpleParser, query string) error { + if !sp.hasMoreTokens() { + return status.Errorf(codes.InvalidArgument, "syntax error: missing TRANSACTION option, expected one of ISOLATION LEVEL, READ WRITE, or READ ONLY") + } + s.query = query + + if sp.eatKeyword("ISOLATION") { + if !sp.eatKeyword("LEVEL") { + return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected ISOLATION LEVEL") + } + s.Identifier = Identifier{Parts: []string{"isolation_level"}} + } else if sp.eatKeyword("READ") { + if sp.eatKeyword("ONLY") { + s.Identifier = Identifier{Parts: []string{"transaction_read_only"}} + } else if sp.eatKeyword("WRITE") { + s.Identifier = Identifier{Parts: []string{"transaction_read_only"}} + } else { + return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected READ ONLY or READ WRITE") + } + } else if sp.eatKeyword("DEFERRABLE") { + s.Identifier = Identifier{Parts: []string{"transaction_deferrable"}} + } else if sp.eatKeyword("NOT") { + if !sp.eatKeyword("DEFERRABLE") { + return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected NOT DEFERRABLE") + } + s.Identifier = Identifier{Parts: []string{"transaction_deferrable"}} + } else { + return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected one of ISOLATION LEVEL, READ WRITE, or READ ONLY") + } + + if sp.hasMoreTokens() { + return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql) + } + return nil +} + // ParsedSetStatement is a statement of the form // SET [SESSION | LOCAL] [my_extension.]my_property {=|to} // diff --git a/parser/statements_test.go b/parser/statements_test.go index 5879ef41..19a752b1 100644 --- a/parser/statements_test.go +++ b/parser/statements_test.go @@ -26,10 +26,6 @@ import ( func TestParseShowStatement(t *testing.T) { t.Parallel() - parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000) - if err != nil { - t.Fatal(err) - } type test struct { input string want ParsedShowStatement @@ -37,8 +33,7 @@ func TestParseShowStatement(t *testing.T) { } tests := []test{ { - input: "show my_property", - wantErr: parser.Dialect == databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, + input: "show my_property", want: ParsedShowStatement{ query: "show my_property", Identifier: Identifier{Parts: []string{"my_property"}}, @@ -97,29 +92,95 @@ func TestParseShowStatement(t *testing.T) { // Garbled comment. input: "show variable /*should have been a comment* my_property", wantErr: true, + want: ParsedShowStatement{ + query: "show variable /*should have been a comment* my_property", + Identifier: Identifier{Parts: []string{"variable"}}, + }, + }, + { + input: "show transaction isolation level", + want: ParsedShowStatement{ + query: "show transaction isolation level", + Identifier: Identifier{Parts: []string{"isolation_level"}}, + }, + }, + { + input: "show transaction read only", + want: ParsedShowStatement{ + query: "show transaction read only", + Identifier: Identifier{Parts: []string{"transaction_read_only"}}, + }, + }, + { + input: "show transaction read write", + want: ParsedShowStatement{ + query: "show transaction read write", + Identifier: Identifier{Parts: []string{"transaction_read_only"}}, + }, + }, + { + input: "show transaction deferrable", + want: ParsedShowStatement{ + query: "show transaction deferrable", + Identifier: Identifier{Parts: []string{"transaction_deferrable"}}, + }, + }, + { + input: "show transaction not deferrable", + want: ParsedShowStatement{ + query: "show transaction not deferrable", + Identifier: Identifier{Parts: []string{"transaction_deferrable"}}, + }, + }, + { + input: "show transaction foo", + wantErr: true, }, } keyword := "SHOW" - for _, test := range tests { - t.Run(test.input, func(t *testing.T) { - stmt, err := parseStatement(parser, keyword, test.input) - if test.wantErr { - if err == nil { - t.Fatalf("parseStatement(%q) should have failed", test.input) - } - } else { - if err != nil { - t.Fatal(err) + for _, dialect := range []databasepb.DatabaseDialect{databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, databasepb.DatabaseDialect_POSTGRESQL} { + parser, err := NewStatementParser(dialect, 1000) + if err != nil { + t.Fatal(err) + } + for _, test := range tests { + t.Run(fmt.Sprintf("%s %s", dialect, test.input), func(t *testing.T) { + isGsql := dialect == databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL + normalizedInput := normalizeSpace(strings.ToLower(test.input)) + hasVariableKeyword := strings.Contains(normalizedInput, "show variable") + hasTxKeyword := strings.Contains(normalizedInput, "show transaction") + + wantErr := test.wantErr + if isGsql && !hasVariableKeyword && !hasTxKeyword { + wantErr = true } - showStmt, ok := stmt.(*ParsedShowStatement) - if !ok { - t.Fatalf("parseStatement(%q) should have returned a *ParsedShowStatement", test.input) + if !isGsql && hasVariableKeyword { + if test.input == "show variable /*should have been a comment* my_property" { + wantErr = false + } else { + wantErr = true + } } - if !reflect.DeepEqual(*showStmt, test.want) { - t.Errorf("parseStatement(%q) = %v, want %v", test.input, *showStmt, test.want) + + stmt, err := parseStatement(parser, keyword, test.input) + if wantErr { + if err == nil { + t.Fatalf("parseStatement(%q) should have failed", test.input) + } + } else { + if err != nil { + t.Fatal(err) + } + showStmt, ok := stmt.(*ParsedShowStatement) + if !ok { + t.Fatalf("parseStatement(%q) should have returned a *ParsedShowStatement", test.input) + } + if !reflect.DeepEqual(*showStmt, test.want) { + t.Errorf("parseStatement(%q) = %v, want %v", test.input, *showStmt, test.want) + } } - } - }) + }) + } } } @@ -712,3 +773,7 @@ func TestParseRunPartitionedQuery(t *testing.T) { }) } } + +func normalizeSpace(s string) string { + return strings.Join(strings.Fields(s), " ") +} diff --git a/statements.go b/statements.go index 6fbf5ca1..9b581a37 100644 --- a/statements.go +++ b/statements.go @@ -16,9 +16,11 @@ package spannerdriver import ( "context" + "database/sql" "database/sql/driver" "encoding/json" "fmt" + "strings" "time" "cloud.google.com/go/spanner" @@ -87,6 +89,12 @@ func (s *executableShowStatement) queryContext(ctx context.Context, c *conn, opt it, err = createStringIterator(col, val) case *time.Time: it, err = createTimestampIterator(col, val) + case sql.IsolationLevel: + isolationStr := val.String() + if c.parser.Dialect == databasepb.DatabaseDialect_POSTGRESQL { + isolationStr = strings.ToLower(isolationStr) + } + it, err = createStringIterator(col, isolationStr) default: stringVal := "" if hasValue { From 8094422cf58361c4cc75deb483ab8bc987c55a95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 12 Jun 2026 15:07:47 +0200 Subject: [PATCH 2/3] chore: cleanup and introduce aliases --- client_side_statement_test.go | 7 +++- conn.go | 16 +------- connection_properties.go | 10 +++++ connectionstate/connection_state.go | 32 ++++++++++++++++ connectionstate/connection_state_test.go | 48 ++++++++++++++++++++++++ driver.go | 2 +- parser/statements.go | 8 ++-- parser/statements_test.go | 7 +--- 8 files changed, 103 insertions(+), 27 deletions(-) diff --git a/client_side_statement_test.go b/client_side_statement_test.go index 5762091e..f92d7468 100644 --- a/client_side_statement_test.go +++ b/client_side_statement_test.go @@ -696,7 +696,7 @@ func TestStatementExecutor_ShowTransaction(t *testing.T) { p, _ := parser.NewStatementParser(databasepb.DatabaseDialect_POSTGRESQL, 1000) c := &conn{ logger: noopLogger, - state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}), + state: createInitialConnectionStateWithDialect(databasepb.DatabaseDialect_POSTGRESQL, connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}), parser: p, } ctx := context.Background() @@ -826,4 +826,9 @@ func TestStatementExecutor_ShowTransaction(t *testing.T) { t.Errorf("deferrable got %q, want %q", got, want) } rows.Close() + + // 5. Try to modify read-only alias variable, should return error + if _, err := c.ExecContext(ctx, "set transaction_isolation = 'repeatable_read'", []driver.NamedValue{}); err == nil { + t.Error("expected error when setting read-only alias transaction_isolation, got nil") + } } diff --git a/conn.go b/conn.go index 2f676128..d7d9aafb 100644 --- a/conn.go +++ b/conn.go @@ -333,21 +333,7 @@ func (c *conn) showConnectionVariable(identifier parser.Identifier) (any, bool, if err != nil { return nil, false, err } - if extension == "" && name == "transaction_isolation" { - name = "isolation_level" - } - val, hasValue, err := c.state.GetValue(extension, name) - if err != nil { - return nil, false, err - } - if name == "isolation_level" { - if lvl, ok := val.(sql.IsolationLevel); ok && lvl == sql.LevelDefault { - if c.parser.Dialect == adminpb.DatabaseDialect_POSTGRESQL { - val = sql.LevelSerializable - } - } - } - return val, hasValue, nil + return c.state.GetValue(extension, name) } func (c *conn) setConnectionVariable(identifier parser.Identifier, value string, local bool, transaction bool, statementScoped bool) error { diff --git a/connection_properties.go b/connection_properties.go index 1a4f5387..ff1458e2 100644 --- a/connection_properties.go +++ b/connection_properties.go @@ -750,6 +750,16 @@ func createConfiguredConnectionState(initialValues map[string]connectionstate.Co } func createInitialConnectionState(connectionStateType connectionstate.Type, initialValues map[string]connectionstate.ConnectionPropertyValue) *connectionstate.ConnectionState { + return createInitialConnectionStateWithDialect(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, connectionStateType, initialValues) +} + +func createInitialConnectionStateWithDialect(dialect databasepb.DatabaseDialect, connectionStateType connectionstate.Type, initialValues map[string]connectionstate.ConnectionPropertyValue) *connectionstate.ConnectionState { state, _ := connectionstate.NewConnectionState(connectionStateType, connectionProperties, initialValues) + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + state.AddAlias("transaction_isolation", "isolation_level", true /* readOnly */) + if val := propertyIsolationLevel.GetValueOrDefault(state); val == sql.LevelDefault { + _ = propertyIsolationLevel.SetValue(state, sql.LevelSerializable, connectionstate.ContextStartup) + } + } return state } diff --git a/connectionstate/connection_state.go b/connectionstate/connection_state.go index 26e43ed6..10434006 100644 --- a/connectionstate/connection_state.go +++ b/connectionstate/connection_state.go @@ -53,6 +53,26 @@ type ConnectionState struct { transactionProperties map[string]ConnectionPropertyValue localProperties map[string]ConnectionPropertyValue statementScopedProperties map[string]ConnectionPropertyValue + aliases map[string]alias +} + +// alias defines an alias for a connection property. +type alias struct { + name string + target string + readOnly bool +} + +// AddAlias adds a new alias for a connection property. +func (cs *ConnectionState) AddAlias(name, target string, readOnly bool) { + if cs.aliases == nil { + cs.aliases = make(map[string]alias) + } + cs.aliases[strings.ToLower(name)] = alias{ + name: name, + target: target, + readOnly: readOnly, + } } // ExtractValues extracts a map of ConnectionPropertyValue from a map of strings. @@ -155,6 +175,13 @@ const ( var errInvalidValueType = status.Error(codes.InvalidArgument, "invalid value type") func (cs *ConnectionState) setValue(extension, name, value string, context Context, valueType valueType) error { + if extension == "" { + if a, ok := cs.aliases[strings.ToLower(name)]; ok { + if a.readOnly { + return status.Errorf(codes.InvalidArgument, "variable %q is read-only", name) + } + } + } prop, err := cs.findProperty(extension, name) if err != nil { return err @@ -207,6 +234,11 @@ func (cs *ConnectionState) setValue(extension, name, value string, context Conte } func (cs *ConnectionState) findProperty(extension, name string) (ConnectionProperty, error) { + if extension == "" { + if a, ok := cs.aliases[strings.ToLower(name)]; ok { + name = a.target + } + } key := toKey(extension, name) var prop ConnectionProperty existingValue, ok := cs.properties[key] diff --git a/connectionstate/connection_state_test.go b/connectionstate/connection_state_test.go index 43dc44ab..e8517e44 100644 --- a/connectionstate/connection_state_test.go +++ b/connectionstate/connection_state_test.go @@ -519,3 +519,51 @@ func TestResetWithInitialValues(t *testing.T) { } } } + +func TestConnectionState_Aliases(t *testing.T) { + t.Parallel() + + properties := map[string]ConnectionProperty{ + "target_prop": &TypedConnectionProperty[string]{ + key: "target_prop", + name: "target_prop", + defaultValue: "default-val", + hasDefaultValue: true, + context: ContextUser, + converter: ConvertString, + }, + } + + state, _ := NewConnectionState(TypeNonTransactional, properties, map[string]ConnectionPropertyValue{}) + state.AddAlias("alias_rw", "target_prop", false) + state.AddAlias("alias_ro", "target_prop", true) + + // 1. Get through alias + val, ok, err := state.GetValue("", "alias_rw") + if err != nil { + t.Fatal(err) + } + if !ok || val != "default-val" { + t.Errorf("GetValue got %v, %t, want 'default-val', true", val, ok) + } + + // 2. Set through read-write alias + if err := state.SetValue("", "alias_rw", "new-val", ContextUser); err != nil { + t.Fatal(err) + } + val, _, _ = state.GetValue("", "target_prop") + if val != "new-val" { + t.Errorf("expected target_prop to be updated to 'new-val', got %v", val) + } + + // 3. Set through read-only alias should fail + if err := state.SetValue("", "alias_ro", "other-val", ContextUser); err == nil { + t.Error("expected error when setting read-only alias, got nil") + } + + // 4. Get through read-only alias should still work and reflect the updated target value + val, _, _ = state.GetValue("", "alias_ro") + if val != "new-val" { + t.Errorf("expected GetValue through read-only alias to return 'new-val', got %v", val) + } +} diff --git a/driver.go b/driver.go index 8d102f36..8ec65c40 100644 --- a/driver.go +++ b/driver.go @@ -863,7 +863,7 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) { logger: logger, instance: instanceName, database: databaseName, - state: createInitialConnectionState(connectionStateType, c.initialPropertyValues), + state: createInitialConnectionStateWithDialect(c.parser.Dialect, connectionStateType, c.initialPropertyValues), execSingleQuery: queryInSingleUse, execSingleQueryTransactional: queryInNewRWTransaction, execSingleDMLTransactional: execInNewRWTransaction, diff --git a/parser/statements.go b/parser/statements.go index 8bc6ef06..c98997d7 100644 --- a/parser/statements.go +++ b/parser/statements.go @@ -178,7 +178,7 @@ func (s *ParsedShowStatement) parse(parser *StatementParser, query string) error func (s *ParsedShowStatement) parseShowTransaction(sp *simpleParser, query string) error { if !sp.hasMoreTokens() { - return status.Errorf(codes.InvalidArgument, "syntax error: missing TRANSACTION option, expected one of ISOLATION LEVEL, READ WRITE, or READ ONLY") + return status.Errorf(codes.InvalidArgument, "syntax error: missing TRANSACTION option, expected one of ISOLATION LEVEL, READ ONLY, or [NOT] DEFERRABLE") } s.query = query @@ -190,10 +190,8 @@ func (s *ParsedShowStatement) parseShowTransaction(sp *simpleParser, query strin } else if sp.eatKeyword("READ") { if sp.eatKeyword("ONLY") { s.Identifier = Identifier{Parts: []string{"transaction_read_only"}} - } else if sp.eatKeyword("WRITE") { - s.Identifier = Identifier{Parts: []string{"transaction_read_only"}} } else { - return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected READ ONLY or READ WRITE") + return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected READ ONLY") } } else if sp.eatKeyword("DEFERRABLE") { s.Identifier = Identifier{Parts: []string{"transaction_deferrable"}} @@ -203,7 +201,7 @@ func (s *ParsedShowStatement) parseShowTransaction(sp *simpleParser, query strin } s.Identifier = Identifier{Parts: []string{"transaction_deferrable"}} } else { - return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected one of ISOLATION LEVEL, READ WRITE, or READ ONLY") + return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected one of ISOLATION LEVEL, READ ONLY, or [NOT] DEFERRABLE") } if sp.hasMoreTokens() { diff --git a/parser/statements_test.go b/parser/statements_test.go index 19a752b1..11353128 100644 --- a/parser/statements_test.go +++ b/parser/statements_test.go @@ -112,11 +112,8 @@ func TestParseShowStatement(t *testing.T) { }, }, { - input: "show transaction read write", - want: ParsedShowStatement{ - query: "show transaction read write", - Identifier: Identifier{Parts: []string{"transaction_read_only"}}, - }, + input: "show transaction read write", + wantErr: true, }, { input: "show transaction deferrable", From 50ca434a0f3d15065f296180a8bbc1337ebbadd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 12 Jun 2026 16:06:52 +0200 Subject: [PATCH 3/3] chore: address review comments --- client_side_statement_test.go | 133 ++++++---------------------------- parser/statements_test.go | 96 ++++++++++++------------ 2 files changed, 70 insertions(+), 159 deletions(-) diff --git a/client_side_statement_test.go b/client_side_statement_test.go index f92d7468..8fa70564 100644 --- a/client_side_statement_test.go +++ b/client_side_statement_test.go @@ -701,63 +701,27 @@ func TestStatementExecutor_ShowTransaction(t *testing.T) { } ctx := context.Background() - // Initial checks. - // 1. SHOW TRANSACTION ISOLATION LEVEL should show the default (serializable) - rows, err := c.QueryContext(ctx, "show transaction isolation level", []driver.NamedValue{}) - if err != nil { - t.Fatal(err) - } - defer rows.Close() - values := make([]driver.Value, 1) - if err := rows.Next(values); err != nil { - t.Fatal(err) - } - if got, want := fmt.Sprintf("%v", values[0]), "serializable"; got != want { - t.Errorf("isolation level got %q, want %q", got, want) - } - rows.Close() - - // 2. SHOW transaction_isolation (alias) should show same - rows, err = c.QueryContext(ctx, "show transaction_isolation", []driver.NamedValue{}) - if err != nil { - t.Fatal(err) - } - defer rows.Close() - if err := rows.Next(values); err != nil { - t.Fatal(err) - } - if got, want := fmt.Sprintf("%v", values[0]), "serializable"; got != want { - t.Errorf("transaction_isolation got %q, want %q", got, want) - } - rows.Close() - - // 3. SHOW TRANSACTION READ ONLY should show default (false) - rows, err = c.QueryContext(ctx, "show transaction read only", []driver.NamedValue{}) - if err != nil { - t.Fatal(err) - } - defer rows.Close() - if err := rows.Next(values); err != nil { - t.Fatal(err) - } - if got, want := fmt.Sprintf("%v", values[0]), "false"; got != want { - t.Errorf("read only got %q, want %q", got, want) + checkShowValue := func(t *testing.T, query, want, description string) { + t.Helper() + rows, err := c.QueryContext(ctx, query, []driver.NamedValue{}) + if err != nil { + t.Fatalf("QueryContext for %q failed: %v", query, err) + } + defer rows.Close() + values := make([]driver.Value, 1) + if err := rows.Next(values); err != nil { + t.Fatalf("rows.Next for %q failed: %v", query, err) + } + if got := fmt.Sprintf("%v", values[0]); got != want { + t.Errorf("%s: got %q, want %q", description, got, want) + } } - rows.Close() - // 4. SHOW TRANSACTION DEFERRABLE should show default (false) - rows, err = c.QueryContext(ctx, "show transaction deferrable", []driver.NamedValue{}) - if err != nil { - t.Fatal(err) - } - defer rows.Close() - if err := rows.Next(values); err != nil { - t.Fatal(err) - } - if got, want := fmt.Sprintf("%v", values[0]), "false"; got != want { - t.Errorf("deferrable got %q, want %q", got, want) - } - rows.Close() + // Initial checks. + checkShowValue(t, "show transaction isolation level", "serializable", "default isolation level") + checkShowValue(t, "show transaction_isolation", "serializable", "default transaction_isolation") + checkShowValue(t, "show transaction read only", "false", "default read only") + checkShowValue(t, "show transaction deferrable", "false", "default deferrable") // Now modify the properties using SET, and verify they show new values! if _, err := c.ExecContext(ctx, "set isolation_level = 'repeatable_read'", []driver.NamedValue{}); err != nil { @@ -771,61 +735,10 @@ func TestStatementExecutor_ShowTransaction(t *testing.T) { } // Verify new values - // 1. SHOW TRANSACTION ISOLATION LEVEL - rows, err = c.QueryContext(ctx, "show transaction isolation level", []driver.NamedValue{}) - if err != nil { - t.Fatal(err) - } - defer rows.Close() - if err := rows.Next(values); err != nil { - t.Fatal(err) - } - if got, want := fmt.Sprintf("%v", values[0]), "repeatable read"; got != want { - t.Errorf("isolation level got %q, want %q", got, want) - } - rows.Close() - - // 2. SHOW transaction_isolation - rows, err = c.QueryContext(ctx, "show transaction_isolation", []driver.NamedValue{}) - if err != nil { - t.Fatal(err) - } - defer rows.Close() - if err := rows.Next(values); err != nil { - t.Fatal(err) - } - if got, want := fmt.Sprintf("%v", values[0]), "repeatable read"; got != want { - t.Errorf("transaction_isolation got %q, want %q", got, want) - } - rows.Close() - - // 3. SHOW TRANSACTION READ ONLY - rows, err = c.QueryContext(ctx, "show transaction read only", []driver.NamedValue{}) - if err != nil { - t.Fatal(err) - } - defer rows.Close() - if err := rows.Next(values); err != nil { - t.Fatal(err) - } - if got, want := fmt.Sprintf("%v", values[0]), "true"; got != want { - t.Errorf("read only got %q, want %q", got, want) - } - rows.Close() - - // 4. SHOW TRANSACTION DEFERRABLE - rows, err = c.QueryContext(ctx, "show transaction deferrable", []driver.NamedValue{}) - if err != nil { - t.Fatal(err) - } - defer rows.Close() - if err := rows.Next(values); err != nil { - t.Fatal(err) - } - if got, want := fmt.Sprintf("%v", values[0]), "true"; got != want { - t.Errorf("deferrable got %q, want %q", got, want) - } - rows.Close() + checkShowValue(t, "show transaction isolation level", "repeatable read", "modified isolation level") + checkShowValue(t, "show transaction_isolation", "repeatable read", "modified transaction_isolation") + checkShowValue(t, "show transaction read only", "true", "modified read only") + checkShowValue(t, "show transaction deferrable", "true", "modified deferrable") // 5. Try to modify read-only alias variable, should return error if _, err := c.ExecContext(ctx, "set transaction_isolation = 'repeatable_read'", []driver.NamedValue{}); err == nil { diff --git a/parser/statements_test.go b/parser/statements_test.go index 11353128..62c5ef56 100644 --- a/parser/statements_test.go +++ b/parser/statements_test.go @@ -27,111 +27,123 @@ func TestParseShowStatement(t *testing.T) { t.Parallel() type test struct { - input string - want ParsedShowStatement - wantErr bool + input string + gsqlWant *ParsedShowStatement + pgWant *ParsedShowStatement } tests := []test{ { input: "show my_property", - want: ParsedShowStatement{ + pgWant: &ParsedShowStatement{ query: "show my_property", Identifier: Identifier{Parts: []string{"my_property"}}, }, }, { input: "show variable my_property", - want: ParsedShowStatement{ + gsqlWant: &ParsedShowStatement{ query: "show variable my_property", Identifier: Identifier{Parts: []string{"my_property"}}, }, }, { input: "SHOW variable my_extension.my_property", - want: ParsedShowStatement{ + gsqlWant: &ParsedShowStatement{ query: "SHOW variable my_extension.my_property", Identifier: Identifier{Parts: []string{"my_extension", "my_property"}}, }, }, { input: "show variable my_extension. my_property", - want: ParsedShowStatement{ + gsqlWant: &ParsedShowStatement{ query: "show variable my_extension. my_property", Identifier: Identifier{Parts: []string{"my_extension", "my_property"}}, }, }, { input: "show variable my_extension . my_property", - want: ParsedShowStatement{ + gsqlWant: &ParsedShowStatement{ query: "show variable my_extension . my_property", Identifier: Identifier{Parts: []string{"my_extension", "my_property"}}, }, }, { input: "show variable /*comment*/\n my_extension . my_property \n", - want: ParsedShowStatement{ + gsqlWant: &ParsedShowStatement{ query: "show variable /*comment*/\n my_extension . my_property \n", Identifier: Identifier{Parts: []string{"my_extension", "my_property"}}, }, }, { // Extra tokens after the statement are not allowed. - input: "show variable my_property foo", - wantErr: true, + input: "show variable my_property foo", }, { // Extra tokens after the statement are not allowed. - input: "show variable my_property/", - wantErr: true, + input: "show variable my_property/", }, { - input: "show vraible my_property", - wantErr: true, + input: "show vraible my_property", }, { - // Garbled comment. - input: "show variable /*should have been a comment* my_property", - wantErr: true, - want: ParsedShowStatement{ + // Garbled comment: consumes the rest of the string, causing: + // - GoogleSQL: EOF after VARIABLE keyword (error, nil gsqlWant). + // - PostgreSQL: EOF after variable identifier (success, pgWant with Identifier "variable"). + input: "show variable /*should have been a comment* my_property", + pgWant: &ParsedShowStatement{ query: "show variable /*should have been a comment* my_property", Identifier: Identifier{Parts: []string{"variable"}}, }, }, { input: "show transaction isolation level", - want: ParsedShowStatement{ + gsqlWant: &ParsedShowStatement{ + query: "show transaction isolation level", + Identifier: Identifier{Parts: []string{"isolation_level"}}, + }, + pgWant: &ParsedShowStatement{ query: "show transaction isolation level", Identifier: Identifier{Parts: []string{"isolation_level"}}, }, }, { input: "show transaction read only", - want: ParsedShowStatement{ + gsqlWant: &ParsedShowStatement{ + query: "show transaction read only", + Identifier: Identifier{Parts: []string{"transaction_read_only"}}, + }, + pgWant: &ParsedShowStatement{ query: "show transaction read only", Identifier: Identifier{Parts: []string{"transaction_read_only"}}, }, }, { - input: "show transaction read write", - wantErr: true, + input: "show transaction read write", }, { input: "show transaction deferrable", - want: ParsedShowStatement{ + gsqlWant: &ParsedShowStatement{ + query: "show transaction deferrable", + Identifier: Identifier{Parts: []string{"transaction_deferrable"}}, + }, + pgWant: &ParsedShowStatement{ query: "show transaction deferrable", Identifier: Identifier{Parts: []string{"transaction_deferrable"}}, }, }, { input: "show transaction not deferrable", - want: ParsedShowStatement{ + gsqlWant: &ParsedShowStatement{ + query: "show transaction not deferrable", + Identifier: Identifier{Parts: []string{"transaction_deferrable"}}, + }, + pgWant: &ParsedShowStatement{ query: "show transaction not deferrable", Identifier: Identifier{Parts: []string{"transaction_deferrable"}}, }, }, { - input: "show transaction foo", - wantErr: true, + input: "show transaction foo", }, } keyword := "SHOW" @@ -142,25 +154,15 @@ func TestParseShowStatement(t *testing.T) { } for _, test := range tests { t.Run(fmt.Sprintf("%s %s", dialect, test.input), func(t *testing.T) { - isGsql := dialect == databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL - normalizedInput := normalizeSpace(strings.ToLower(test.input)) - hasVariableKeyword := strings.Contains(normalizedInput, "show variable") - hasTxKeyword := strings.Contains(normalizedInput, "show transaction") - - wantErr := test.wantErr - if isGsql && !hasVariableKeyword && !hasTxKeyword { - wantErr = true - } - if !isGsql && hasVariableKeyword { - if test.input == "show variable /*should have been a comment* my_property" { - wantErr = false - } else { - wantErr = true - } + var want *ParsedShowStatement + if dialect == databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL { + want = test.gsqlWant + } else { + want = test.pgWant } stmt, err := parseStatement(parser, keyword, test.input) - if wantErr { + if want == nil { if err == nil { t.Fatalf("parseStatement(%q) should have failed", test.input) } @@ -172,8 +174,8 @@ func TestParseShowStatement(t *testing.T) { if !ok { t.Fatalf("parseStatement(%q) should have returned a *ParsedShowStatement", test.input) } - if !reflect.DeepEqual(*showStmt, test.want) { - t.Errorf("parseStatement(%q) = %v, want %v", test.input, *showStmt, test.want) + if !reflect.DeepEqual(*showStmt, *want) { + t.Errorf("parseStatement(%q) = %v, want %v", test.input, *showStmt, *want) } } }) @@ -770,7 +772,3 @@ func TestParseRunPartitionedQuery(t *testing.T) { }) } } - -func normalizeSpace(s string) string { - return strings.Join(strings.Fields(s), " ") -}