diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 6d90afa6..04ad02e7 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -3839,8 +3839,7 @@ func TestStressClientReuse(t *testing.T) { // Verify that each unique connection string created numSessions (10) sessions on the server. reqs := server.TestSpanner.DrainRequestsFromServer() createReqs := testutil.RequestsOfType(reqs, reflect.TypeOf(&sppb.CreateSessionRequest{})) - // TODO: Fix when the client lib has been fixed to only create max one session per client. - if g, w := len(createReqs), numClients+1; g != w { + if g, w := len(createReqs), numClients; g != w { t.Fatalf("number of CreateSessions mismatch\n Got: %v\nWant: %v", g, w) } sqlReqs := testutil.RequestsOfType(reqs, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) @@ -4571,42 +4570,36 @@ func TestTag_RunTransactionWithOptions_IsNotSticky(t *testing.T) { } func TestMaxIdleConnectionsNonZero(t *testing.T) { - db, server, teardown := setupTestDBConnection(t) - defer teardown() + db, server, teardown := setupTestDBConnectionWithParams(t, "minSessions=0") db.SetMaxIdleConns(2) for i := 0; i < 2; i++ { openAndCloseConn(t, db) } - // Verify that only one client was created. - // This happens because we have a non-zero value for the number of idle connections. + teardown() + requests := server.TestSpanner.DrainRequestsFromServer() - batchRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&sppb.CreateSessionRequest{})) - // TODO: Fix this when the client library has been fixed, so that it only creates one multiplexed - // session per client. - if g, w := len(batchRequests), 2; g != w { - t.Fatalf("CreateSession requests count mismatch\n Got: %v\nWant: %v", g, w) + created := countCreatedSessions(requests) + if g, w := created, 1; g != w { + t.Fatalf("sessions created count mismatch\n Got: %v\nWant: %v", g, w) } } func TestMaxIdleConnectionsZero(t *testing.T) { - db, server, teardown := setupTestDBConnection(t) - defer teardown() + db, server, teardown := setupTestDBConnectionWithParams(t, "minSessions=0") db.SetMaxIdleConns(0) for i := 0; i < 2; i++ { openAndCloseConn(t, db) } - // Verify that two clients were created and closed. - // This should happen because we do not keep any idle connections open. + teardown() + requests := server.TestSpanner.DrainRequestsFromServer() - batchRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&sppb.CreateSessionRequest{})) - // TODO: Fix this when the client library has been fixed, so that it only creates one multiplexed - // session per client. - if g, w := len(batchRequests), 3; g != w { - t.Fatalf("CreateSession requests count mismatch\n Got: %v\nWant: %v", g, w) + created := countCreatedSessions(requests) + if g, w := created, 2; g != w { + t.Fatalf("sessions created count mismatch\n Got: %v\nWant: %v", g, w) } } @@ -5918,7 +5911,8 @@ func setupTestDBConnectionWithParams(t testing.TB, params string) (db *sql.DB, s } func setupTestDBConnectionWithParamsAndDialect(t testing.TB, params string, dialect databasepb.DatabaseDialect) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { - server, _, serverTeardown := setupMockedTestServerWithDialect(t, dialect) + server, _, serverTeardown := testutil.NewMockedSpannerInMemTestServer(t) + server.SetupSelectDialectResult(dialect) db, err := sql.Open( "spanner", fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true;%s", server.Address, params)) @@ -6021,3 +6015,16 @@ func filterBeginReadOnlyRequests(requests []interface{}) []*sppb.BeginTransactio } return res } + +func countCreatedSessions(requests []interface{}) int { + count := 0 + for _, r := range requests { + switch req := r.(type) { + case *sppb.CreateSessionRequest: + count++ + case *sppb.BatchCreateSessionsRequest: + count += int(req.SessionCount) + } + } + return count +} diff --git a/testutil/inmem_spanner_server.go b/testutil/inmem_spanner_server.go index 02477b24..11be0c82 100644 --- a/testutil/inmem_spanner_server.go +++ b/testutil/inmem_spanner_server.go @@ -1254,10 +1254,14 @@ func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.P func (s *inMemSpannerServer) DrainRequestsFromServer() []interface{} { var reqs []interface{} + ch := s.ReceivedRequests() loop: for { select { - case req := <-s.ReceivedRequests(): + case req, ok := <-ch: + if !ok { + break loop + } reqs = append(reqs, req) default: break loop