Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions client_side_statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -688,5 +688,60 @@ func TestStatementExecutor_UsesExecOptions(t *testing.T) {
if rows.HasNextResultSet() {
t.Fatal("got unexpected next result set")
}
}

func TestStatementExecutor_ShowTransaction(t *testing.T) {
t.Parallel()

p, _ := parser.NewStatementParser(databasepb.DatabaseDialect_POSTGRESQL, 1000)
c := &conn{
logger: noopLogger,
state: createInitialConnectionStateWithDialect(databasepb.DatabaseDialect_POSTGRESQL, connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
parser: p,
}
ctx := context.Background()

checkShowValue := func(t *testing.T, query, want, description string) {
t.Helper()
rows, err := c.QueryContext(ctx, query, []driver.NamedValue{})
if err != nil {
t.Fatalf("QueryContext for %q failed: %v", query, err)
}
defer rows.Close()
values := make([]driver.Value, 1)
if err := rows.Next(values); err != nil {
t.Fatalf("rows.Next for %q failed: %v", query, err)
}
if got := fmt.Sprintf("%v", values[0]); got != want {
t.Errorf("%s: got %q, want %q", description, got, want)
}
}

// Initial checks.
checkShowValue(t, "show transaction isolation level", "serializable", "default isolation level")
checkShowValue(t, "show transaction_isolation", "serializable", "default transaction_isolation")
checkShowValue(t, "show transaction read only", "false", "default read only")
checkShowValue(t, "show transaction deferrable", "false", "default deferrable")

// Now modify the properties using SET, and verify they show new values!
if _, err := c.ExecContext(ctx, "set isolation_level = 'repeatable_read'", []driver.NamedValue{}); err != nil {
t.Fatal(err)
}
if _, err := c.ExecContext(ctx, "set transaction_read_only = true", []driver.NamedValue{}); err != nil {
t.Fatal(err)
}
if _, err := c.ExecContext(ctx, "set transaction_deferrable = true", []driver.NamedValue{}); err != nil {
t.Fatal(err)
}

// Verify new values
checkShowValue(t, "show transaction isolation level", "repeatable read", "modified isolation level")
checkShowValue(t, "show transaction_isolation", "repeatable read", "modified transaction_isolation")
checkShowValue(t, "show transaction read only", "true", "modified read only")
checkShowValue(t, "show transaction deferrable", "true", "modified deferrable")

// 5. Try to modify read-only alias variable, should return error
if _, err := c.ExecContext(ctx, "set transaction_isolation = 'repeatable_read'", []driver.NamedValue{}); err == nil {
t.Error("expected error when setting read-only alias transaction_isolation, got nil")
}
}
Comment thread
olavloite marked this conversation as resolved.
10 changes: 10 additions & 0 deletions connection_properties.go
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,16 @@ func createConfiguredConnectionState(initialValues map[string]connectionstate.Co
}

func createInitialConnectionState(connectionStateType connectionstate.Type, initialValues map[string]connectionstate.ConnectionPropertyValue) *connectionstate.ConnectionState {
return createInitialConnectionStateWithDialect(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, connectionStateType, initialValues)
}

func createInitialConnectionStateWithDialect(dialect databasepb.DatabaseDialect, connectionStateType connectionstate.Type, initialValues map[string]connectionstate.ConnectionPropertyValue) *connectionstate.ConnectionState {
state, _ := connectionstate.NewConnectionState(connectionStateType, connectionProperties, initialValues)
if dialect == databasepb.DatabaseDialect_POSTGRESQL {
state.AddAlias("transaction_isolation", "isolation_level", true /* readOnly */)
if val := propertyIsolationLevel.GetValueOrDefault(state); val == sql.LevelDefault {
_ = propertyIsolationLevel.SetValue(state, sql.LevelSerializable, connectionstate.ContextStartup)
}
Comment thread
olavloite marked this conversation as resolved.
}
return state
}
32 changes: 32 additions & 0 deletions connectionstate/connection_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,26 @@ type ConnectionState struct {
transactionProperties map[string]ConnectionPropertyValue
localProperties map[string]ConnectionPropertyValue
statementScopedProperties map[string]ConnectionPropertyValue
aliases map[string]alias
}

// alias defines an alias for a connection property.
type alias struct {
name string
target string
readOnly bool
}

// AddAlias adds a new alias for a connection property.
func (cs *ConnectionState) AddAlias(name, target string, readOnly bool) {
if cs.aliases == nil {
cs.aliases = make(map[string]alias)
}
cs.aliases[strings.ToLower(name)] = alias{
name: name,
target: target,
readOnly: readOnly,
}
}

// ExtractValues extracts a map of ConnectionPropertyValue from a map of strings.
Expand Down Expand Up @@ -155,6 +175,13 @@ const (
var errInvalidValueType = status.Error(codes.InvalidArgument, "invalid value type")

func (cs *ConnectionState) setValue(extension, name, value string, context Context, valueType valueType) error {
if extension == "" {
if a, ok := cs.aliases[strings.ToLower(name)]; ok {
if a.readOnly {
return status.Errorf(codes.InvalidArgument, "variable %q is read-only", name)
}
}
}
prop, err := cs.findProperty(extension, name)
if err != nil {
return err
Expand Down Expand Up @@ -207,6 +234,11 @@ func (cs *ConnectionState) setValue(extension, name, value string, context Conte
}

func (cs *ConnectionState) findProperty(extension, name string) (ConnectionProperty, error) {
if extension == "" {
if a, ok := cs.aliases[strings.ToLower(name)]; ok {
name = a.target
}
}
key := toKey(extension, name)
var prop ConnectionProperty
existingValue, ok := cs.properties[key]
Expand Down
48 changes: 48 additions & 0 deletions connectionstate/connection_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,3 +519,51 @@ func TestResetWithInitialValues(t *testing.T) {
}
}
}

func TestConnectionState_Aliases(t *testing.T) {
t.Parallel()

properties := map[string]ConnectionProperty{
"target_prop": &TypedConnectionProperty[string]{
key: "target_prop",
name: "target_prop",
defaultValue: "default-val",
hasDefaultValue: true,
context: ContextUser,
converter: ConvertString,
},
}

state, _ := NewConnectionState(TypeNonTransactional, properties, map[string]ConnectionPropertyValue{})
state.AddAlias("alias_rw", "target_prop", false)
state.AddAlias("alias_ro", "target_prop", true)

// 1. Get through alias
val, ok, err := state.GetValue("", "alias_rw")
if err != nil {
t.Fatal(err)
}
if !ok || val != "default-val" {
t.Errorf("GetValue got %v, %t, want 'default-val', true", val, ok)
}

// 2. Set through read-write alias
if err := state.SetValue("", "alias_rw", "new-val", ContextUser); err != nil {
t.Fatal(err)
}
val, _, _ = state.GetValue("", "target_prop")
if val != "new-val" {
t.Errorf("expected target_prop to be updated to 'new-val', got %v", val)
}

// 3. Set through read-only alias should fail
if err := state.SetValue("", "alias_ro", "other-val", ContextUser); err == nil {
t.Error("expected error when setting read-only alias, got nil")
}

// 4. Get through read-only alias should still work and reflect the updated target value
val, _, _ = state.GetValue("", "alias_ro")
if val != "new-val" {
t.Errorf("expected GetValue through read-only alias to return 'new-val', got %v", val)
}
}
2 changes: 1 addition & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,7 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) {
logger: logger,
instance: instanceName,
database: databaseName,
state: createInitialConnectionState(connectionStateType, c.initialPropertyValues),
state: createInitialConnectionStateWithDialect(c.parser.Dialect, connectionStateType, c.initialPropertyValues),
execSingleQuery: queryInSingleUse,
execSingleQueryTransactional: queryInNewRWTransaction,
execSingleDMLTransactional: execInNewRWTransaction,
Expand Down
39 changes: 38 additions & 1 deletion parser/statements.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,14 @@ func (s *ParsedShowStatement) Query() string {

func (s *ParsedShowStatement) parse(parser *StatementParser, query string) error {
// Parse a statement of the form
// SHOW [VARIABLE] [my_extension.]my_property
// SHOW [VARIABLE | TRANSACTION] [my_extension.]my_property
sp := &simpleParser{sql: []byte(query), statementParser: parser}
if !sp.eatKeyword("SHOW") {
return status.Error(codes.InvalidArgument, "statement does not start with SHOW")
}
if sp.eatKeyword("TRANSACTION") {
return s.parseShowTransaction(sp, query)
}
if parser.Dialect == databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL {
// Just eat and ignore the keyword VARIABLE.
if !sp.eatKeyword("VARIABLE") {
Expand All @@ -173,6 +176,40 @@ func (s *ParsedShowStatement) parse(parser *StatementParser, query string) error
return nil
}

func (s *ParsedShowStatement) parseShowTransaction(sp *simpleParser, query string) error {
if !sp.hasMoreTokens() {
return status.Errorf(codes.InvalidArgument, "syntax error: missing TRANSACTION option, expected one of ISOLATION LEVEL, READ ONLY, or [NOT] DEFERRABLE")
}
Comment thread
olavloite marked this conversation as resolved.
s.query = query

if sp.eatKeyword("ISOLATION") {
if !sp.eatKeyword("LEVEL") {
return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected ISOLATION LEVEL")
}
s.Identifier = Identifier{Parts: []string{"isolation_level"}}
} else if sp.eatKeyword("READ") {
if sp.eatKeyword("ONLY") {
s.Identifier = Identifier{Parts: []string{"transaction_read_only"}}
} else {
return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected READ ONLY")
}
Comment thread
olavloite marked this conversation as resolved.
} else if sp.eatKeyword("DEFERRABLE") {
s.Identifier = Identifier{Parts: []string{"transaction_deferrable"}}
} else if sp.eatKeyword("NOT") {
if !sp.eatKeyword("DEFERRABLE") {
return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected NOT DEFERRABLE")
}
s.Identifier = Identifier{Parts: []string{"transaction_deferrable"}}
} else {
return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected one of ISOLATION LEVEL, READ ONLY, or [NOT] DEFERRABLE")
}
Comment thread
olavloite marked this conversation as resolved.

if sp.hasMoreTokens() {
return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql)
}
return nil
}

// ParsedSetStatement is a statement of the form
// SET [SESSION | LOCAL] [my_extension.]my_property {=|to} <value>
//
Expand Down
Loading
Loading