diff --git a/client_side_statement_test.go b/client_side_statement_test.go index 2c42f062..8fa70564 100644 --- a/client_side_statement_test.go +++ b/client_side_statement_test.go @@ -688,5 +688,60 @@ func TestStatementExecutor_UsesExecOptions(t *testing.T) { if rows.HasNextResultSet() { t.Fatal("got unexpected next result set") } +} + +func TestStatementExecutor_ShowTransaction(t *testing.T) { + t.Parallel() + p, _ := parser.NewStatementParser(databasepb.DatabaseDialect_POSTGRESQL, 1000) + c := &conn{ + logger: noopLogger, + state: createInitialConnectionStateWithDialect(databasepb.DatabaseDialect_POSTGRESQL, connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}), + parser: p, + } + ctx := context.Background() + + checkShowValue := func(t *testing.T, query, want, description string) { + t.Helper() + rows, err := c.QueryContext(ctx, query, []driver.NamedValue{}) + if err != nil { + t.Fatalf("QueryContext for %q failed: %v", query, err) + } + defer rows.Close() + values := make([]driver.Value, 1) + if err := rows.Next(values); err != nil { + t.Fatalf("rows.Next for %q failed: %v", query, err) + } + if got := fmt.Sprintf("%v", values[0]); got != want { + t.Errorf("%s: got %q, want %q", description, got, want) + } + } + + // Initial checks. + checkShowValue(t, "show transaction isolation level", "serializable", "default isolation level") + checkShowValue(t, "show transaction_isolation", "serializable", "default transaction_isolation") + checkShowValue(t, "show transaction read only", "false", "default read only") + checkShowValue(t, "show transaction deferrable", "false", "default deferrable") + + // Now modify the properties using SET, and verify they show new values! + if _, err := c.ExecContext(ctx, "set isolation_level = 'repeatable_read'", []driver.NamedValue{}); err != nil { + t.Fatal(err) + } + if _, err := c.ExecContext(ctx, "set transaction_read_only = true", []driver.NamedValue{}); err != nil { + t.Fatal(err) + } + if _, err := c.ExecContext(ctx, "set transaction_deferrable = true", []driver.NamedValue{}); err != nil { + t.Fatal(err) + } + + // Verify new values + checkShowValue(t, "show transaction isolation level", "repeatable read", "modified isolation level") + checkShowValue(t, "show transaction_isolation", "repeatable read", "modified transaction_isolation") + checkShowValue(t, "show transaction read only", "true", "modified read only") + checkShowValue(t, "show transaction deferrable", "true", "modified deferrable") + + // 5. Try to modify read-only alias variable, should return error + if _, err := c.ExecContext(ctx, "set transaction_isolation = 'repeatable_read'", []driver.NamedValue{}); err == nil { + t.Error("expected error when setting read-only alias transaction_isolation, got nil") + } } diff --git a/connection_properties.go b/connection_properties.go index 1a4f5387..ff1458e2 100644 --- a/connection_properties.go +++ b/connection_properties.go @@ -750,6 +750,16 @@ func createConfiguredConnectionState(initialValues map[string]connectionstate.Co } func createInitialConnectionState(connectionStateType connectionstate.Type, initialValues map[string]connectionstate.ConnectionPropertyValue) *connectionstate.ConnectionState { + return createInitialConnectionStateWithDialect(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, connectionStateType, initialValues) +} + +func createInitialConnectionStateWithDialect(dialect databasepb.DatabaseDialect, connectionStateType connectionstate.Type, initialValues map[string]connectionstate.ConnectionPropertyValue) *connectionstate.ConnectionState { state, _ := connectionstate.NewConnectionState(connectionStateType, connectionProperties, initialValues) + if dialect == databasepb.DatabaseDialect_POSTGRESQL { + state.AddAlias("transaction_isolation", "isolation_level", true /* readOnly */) + if val := propertyIsolationLevel.GetValueOrDefault(state); val == sql.LevelDefault { + _ = propertyIsolationLevel.SetValue(state, sql.LevelSerializable, connectionstate.ContextStartup) + } + } return state } diff --git a/connectionstate/connection_state.go b/connectionstate/connection_state.go index 26e43ed6..10434006 100644 --- a/connectionstate/connection_state.go +++ b/connectionstate/connection_state.go @@ -53,6 +53,26 @@ type ConnectionState struct { transactionProperties map[string]ConnectionPropertyValue localProperties map[string]ConnectionPropertyValue statementScopedProperties map[string]ConnectionPropertyValue + aliases map[string]alias +} + +// alias defines an alias for a connection property. +type alias struct { + name string + target string + readOnly bool +} + +// AddAlias adds a new alias for a connection property. +func (cs *ConnectionState) AddAlias(name, target string, readOnly bool) { + if cs.aliases == nil { + cs.aliases = make(map[string]alias) + } + cs.aliases[strings.ToLower(name)] = alias{ + name: name, + target: target, + readOnly: readOnly, + } } // ExtractValues extracts a map of ConnectionPropertyValue from a map of strings. @@ -155,6 +175,13 @@ const ( var errInvalidValueType = status.Error(codes.InvalidArgument, "invalid value type") func (cs *ConnectionState) setValue(extension, name, value string, context Context, valueType valueType) error { + if extension == "" { + if a, ok := cs.aliases[strings.ToLower(name)]; ok { + if a.readOnly { + return status.Errorf(codes.InvalidArgument, "variable %q is read-only", name) + } + } + } prop, err := cs.findProperty(extension, name) if err != nil { return err @@ -207,6 +234,11 @@ func (cs *ConnectionState) setValue(extension, name, value string, context Conte } func (cs *ConnectionState) findProperty(extension, name string) (ConnectionProperty, error) { + if extension == "" { + if a, ok := cs.aliases[strings.ToLower(name)]; ok { + name = a.target + } + } key := toKey(extension, name) var prop ConnectionProperty existingValue, ok := cs.properties[key] diff --git a/connectionstate/connection_state_test.go b/connectionstate/connection_state_test.go index 43dc44ab..e8517e44 100644 --- a/connectionstate/connection_state_test.go +++ b/connectionstate/connection_state_test.go @@ -519,3 +519,51 @@ func TestResetWithInitialValues(t *testing.T) { } } } + +func TestConnectionState_Aliases(t *testing.T) { + t.Parallel() + + properties := map[string]ConnectionProperty{ + "target_prop": &TypedConnectionProperty[string]{ + key: "target_prop", + name: "target_prop", + defaultValue: "default-val", + hasDefaultValue: true, + context: ContextUser, + converter: ConvertString, + }, + } + + state, _ := NewConnectionState(TypeNonTransactional, properties, map[string]ConnectionPropertyValue{}) + state.AddAlias("alias_rw", "target_prop", false) + state.AddAlias("alias_ro", "target_prop", true) + + // 1. Get through alias + val, ok, err := state.GetValue("", "alias_rw") + if err != nil { + t.Fatal(err) + } + if !ok || val != "default-val" { + t.Errorf("GetValue got %v, %t, want 'default-val', true", val, ok) + } + + // 2. Set through read-write alias + if err := state.SetValue("", "alias_rw", "new-val", ContextUser); err != nil { + t.Fatal(err) + } + val, _, _ = state.GetValue("", "target_prop") + if val != "new-val" { + t.Errorf("expected target_prop to be updated to 'new-val', got %v", val) + } + + // 3. Set through read-only alias should fail + if err := state.SetValue("", "alias_ro", "other-val", ContextUser); err == nil { + t.Error("expected error when setting read-only alias, got nil") + } + + // 4. Get through read-only alias should still work and reflect the updated target value + val, _, _ = state.GetValue("", "alias_ro") + if val != "new-val" { + t.Errorf("expected GetValue through read-only alias to return 'new-val', got %v", val) + } +} diff --git a/driver.go b/driver.go index 8d102f36..8ec65c40 100644 --- a/driver.go +++ b/driver.go @@ -863,7 +863,7 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) { logger: logger, instance: instanceName, database: databaseName, - state: createInitialConnectionState(connectionStateType, c.initialPropertyValues), + state: createInitialConnectionStateWithDialect(c.parser.Dialect, connectionStateType, c.initialPropertyValues), execSingleQuery: queryInSingleUse, execSingleQueryTransactional: queryInNewRWTransaction, execSingleDMLTransactional: execInNewRWTransaction, diff --git a/parser/statements.go b/parser/statements.go index cfa1619f..c98997d7 100644 --- a/parser/statements.go +++ b/parser/statements.go @@ -150,11 +150,14 @@ func (s *ParsedShowStatement) Query() string { func (s *ParsedShowStatement) parse(parser *StatementParser, query string) error { // Parse a statement of the form - // SHOW [VARIABLE] [my_extension.]my_property + // SHOW [VARIABLE | TRANSACTION] [my_extension.]my_property sp := &simpleParser{sql: []byte(query), statementParser: parser} if !sp.eatKeyword("SHOW") { return status.Error(codes.InvalidArgument, "statement does not start with SHOW") } + if sp.eatKeyword("TRANSACTION") { + return s.parseShowTransaction(sp, query) + } if parser.Dialect == databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL { // Just eat and ignore the keyword VARIABLE. if !sp.eatKeyword("VARIABLE") { @@ -173,6 +176,40 @@ func (s *ParsedShowStatement) parse(parser *StatementParser, query string) error return nil } +func (s *ParsedShowStatement) parseShowTransaction(sp *simpleParser, query string) error { + if !sp.hasMoreTokens() { + return status.Errorf(codes.InvalidArgument, "syntax error: missing TRANSACTION option, expected one of ISOLATION LEVEL, READ ONLY, or [NOT] DEFERRABLE") + } + s.query = query + + if sp.eatKeyword("ISOLATION") { + if !sp.eatKeyword("LEVEL") { + return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected ISOLATION LEVEL") + } + s.Identifier = Identifier{Parts: []string{"isolation_level"}} + } else if sp.eatKeyword("READ") { + if sp.eatKeyword("ONLY") { + s.Identifier = Identifier{Parts: []string{"transaction_read_only"}} + } else { + return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected READ ONLY") + } + } else if sp.eatKeyword("DEFERRABLE") { + s.Identifier = Identifier{Parts: []string{"transaction_deferrable"}} + } else if sp.eatKeyword("NOT") { + if !sp.eatKeyword("DEFERRABLE") { + return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected NOT DEFERRABLE") + } + s.Identifier = Identifier{Parts: []string{"transaction_deferrable"}} + } else { + return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected one of ISOLATION LEVEL, READ ONLY, or [NOT] DEFERRABLE") + } + + if sp.hasMoreTokens() { + return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql) + } + return nil +} + // ParsedSetStatement is a statement of the form // SET [SESSION | LOCAL] [my_extension.]my_property {=|to} // diff --git a/parser/statements_test.go b/parser/statements_test.go index 5879ef41..62c5ef56 100644 --- a/parser/statements_test.go +++ b/parser/statements_test.go @@ -26,100 +26,160 @@ import ( func TestParseShowStatement(t *testing.T) { t.Parallel() - parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000) - if err != nil { - t.Fatal(err) - } type test struct { - input string - want ParsedShowStatement - wantErr bool + input string + gsqlWant *ParsedShowStatement + pgWant *ParsedShowStatement } tests := []test{ { - input: "show my_property", - wantErr: parser.Dialect == databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, - want: ParsedShowStatement{ + input: "show my_property", + pgWant: &ParsedShowStatement{ query: "show my_property", Identifier: Identifier{Parts: []string{"my_property"}}, }, }, { input: "show variable my_property", - want: ParsedShowStatement{ + gsqlWant: &ParsedShowStatement{ query: "show variable my_property", Identifier: Identifier{Parts: []string{"my_property"}}, }, }, { input: "SHOW variable my_extension.my_property", - want: ParsedShowStatement{ + gsqlWant: &ParsedShowStatement{ query: "SHOW variable my_extension.my_property", Identifier: Identifier{Parts: []string{"my_extension", "my_property"}}, }, }, { input: "show variable my_extension. my_property", - want: ParsedShowStatement{ + gsqlWant: &ParsedShowStatement{ query: "show variable my_extension. my_property", Identifier: Identifier{Parts: []string{"my_extension", "my_property"}}, }, }, { input: "show variable my_extension . my_property", - want: ParsedShowStatement{ + gsqlWant: &ParsedShowStatement{ query: "show variable my_extension . my_property", Identifier: Identifier{Parts: []string{"my_extension", "my_property"}}, }, }, { input: "show variable /*comment*/\n my_extension . my_property \n", - want: ParsedShowStatement{ + gsqlWant: &ParsedShowStatement{ query: "show variable /*comment*/\n my_extension . my_property \n", Identifier: Identifier{Parts: []string{"my_extension", "my_property"}}, }, }, { // Extra tokens after the statement are not allowed. - input: "show variable my_property foo", - wantErr: true, + input: "show variable my_property foo", }, { // Extra tokens after the statement are not allowed. - input: "show variable my_property/", - wantErr: true, + input: "show variable my_property/", }, { - input: "show vraible my_property", - wantErr: true, + input: "show vraible my_property", }, { - // Garbled comment. - input: "show variable /*should have been a comment* my_property", - wantErr: true, + // Garbled comment: consumes the rest of the string, causing: + // - GoogleSQL: EOF after VARIABLE keyword (error, nil gsqlWant). + // - PostgreSQL: EOF after variable identifier (success, pgWant with Identifier "variable"). + input: "show variable /*should have been a comment* my_property", + pgWant: &ParsedShowStatement{ + query: "show variable /*should have been a comment* my_property", + Identifier: Identifier{Parts: []string{"variable"}}, + }, + }, + { + input: "show transaction isolation level", + gsqlWant: &ParsedShowStatement{ + query: "show transaction isolation level", + Identifier: Identifier{Parts: []string{"isolation_level"}}, + }, + pgWant: &ParsedShowStatement{ + query: "show transaction isolation level", + Identifier: Identifier{Parts: []string{"isolation_level"}}, + }, + }, + { + input: "show transaction read only", + gsqlWant: &ParsedShowStatement{ + query: "show transaction read only", + Identifier: Identifier{Parts: []string{"transaction_read_only"}}, + }, + pgWant: &ParsedShowStatement{ + query: "show transaction read only", + Identifier: Identifier{Parts: []string{"transaction_read_only"}}, + }, + }, + { + input: "show transaction read write", + }, + { + input: "show transaction deferrable", + gsqlWant: &ParsedShowStatement{ + query: "show transaction deferrable", + Identifier: Identifier{Parts: []string{"transaction_deferrable"}}, + }, + pgWant: &ParsedShowStatement{ + query: "show transaction deferrable", + Identifier: Identifier{Parts: []string{"transaction_deferrable"}}, + }, + }, + { + input: "show transaction not deferrable", + gsqlWant: &ParsedShowStatement{ + query: "show transaction not deferrable", + Identifier: Identifier{Parts: []string{"transaction_deferrable"}}, + }, + pgWant: &ParsedShowStatement{ + query: "show transaction not deferrable", + Identifier: Identifier{Parts: []string{"transaction_deferrable"}}, + }, + }, + { + input: "show transaction foo", }, } keyword := "SHOW" - for _, test := range tests { - t.Run(test.input, func(t *testing.T) { - stmt, err := parseStatement(parser, keyword, test.input) - if test.wantErr { - if err == nil { - t.Fatalf("parseStatement(%q) should have failed", test.input) - } - } else { - if err != nil { - t.Fatal(err) - } - showStmt, ok := stmt.(*ParsedShowStatement) - if !ok { - t.Fatalf("parseStatement(%q) should have returned a *ParsedShowStatement", test.input) + for _, dialect := range []databasepb.DatabaseDialect{databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, databasepb.DatabaseDialect_POSTGRESQL} { + parser, err := NewStatementParser(dialect, 1000) + if err != nil { + t.Fatal(err) + } + for _, test := range tests { + t.Run(fmt.Sprintf("%s %s", dialect, test.input), func(t *testing.T) { + var want *ParsedShowStatement + if dialect == databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL { + want = test.gsqlWant + } else { + want = test.pgWant } - if !reflect.DeepEqual(*showStmt, test.want) { - t.Errorf("parseStatement(%q) = %v, want %v", test.input, *showStmt, test.want) + + stmt, err := parseStatement(parser, keyword, test.input) + if want == nil { + if err == nil { + t.Fatalf("parseStatement(%q) should have failed", test.input) + } + } else { + if err != nil { + t.Fatal(err) + } + showStmt, ok := stmt.(*ParsedShowStatement) + if !ok { + t.Fatalf("parseStatement(%q) should have returned a *ParsedShowStatement", test.input) + } + if !reflect.DeepEqual(*showStmt, *want) { + t.Errorf("parseStatement(%q) = %v, want %v", test.input, *showStmt, *want) + } } - } - }) + }) + } } } diff --git a/statements.go b/statements.go index 6fbf5ca1..9b581a37 100644 --- a/statements.go +++ b/statements.go @@ -16,9 +16,11 @@ package spannerdriver import ( "context" + "database/sql" "database/sql/driver" "encoding/json" "fmt" + "strings" "time" "cloud.google.com/go/spanner" @@ -87,6 +89,12 @@ func (s *executableShowStatement) queryContext(ctx context.Context, c *conn, opt it, err = createStringIterator(col, val) case *time.Time: it, err = createTimestampIterator(col, val) + case sql.IsolationLevel: + isolationStr := val.String() + if c.parser.Dialect == databasepb.DatabaseDialect_POSTGRESQL { + isolationStr = strings.ToLower(isolationStr) + } + it, err = createStringIterator(col, isolationStr) default: stringVal := "" if hasValue {