44
55import numpy as np
66import pandas as pd
7- from multimethod import multimethod
87
98from ydata_profiling .config import Settings
109
1413 from pandas .errors import DataError
1514
1615
class CorrelationBackend:
    """Pick the correlation backend module for a dataset and keep it.

    The backend is chosen once, at construction time, from the type of the
    dataset; correlation implementations are then looked up on it by name
    through :meth:`get_method`.

    Attributes:
        backend: The imported backend module (``correlations_pandas`` or
            ``correlations_spark``).
    """

    def __init__(self, df: Sized) -> None:
        """Resolve the backend module for ``df``.

        Args:
            df: The dataset. A ``pandas.DataFrame`` selects the pandas
                backend; any other object selects the Spark backend.
        """
        # The imports are local so that only the selected backend module is
        # imported when an instance is created.
        if isinstance(df, pd.DataFrame):
            from ydata_profiling.model.pandas import (
                correlations_pandas as correlation_backend,  # type: ignore
            )
        else:
            from ydata_profiling.model.spark import (
                correlations_spark as correlation_backend,  # type: ignore
            )

        self.backend = correlation_backend

    def get_method(self, method_name: str):
        """Return the attribute called ``method_name`` from the backend.

        Args:
            method_name: Name of the attribute to look up on the backend.

        Returns:
            Whatever object the backend exposes under ``method_name``.

        Raises:
            AttributeError: If the backend has no attribute with that name.
        """
        try:
            return getattr(self.backend, method_name)
        except AttributeError:
            # Re-raise with a message that names the missing correlation
            # method; the low-level lookup error adds nothing for the caller.
            raise AttributeError(
                f"Correlation method '{method_name}' is not available in the backend."
            ) from None
39+
40+
class Correlation:
    """Base class for correlation measures resolved through a backend.

    Subclasses set ``_method_name`` to the name of the backend attribute
    that implements their measure; ``compute`` looks that attribute up and
    calls it.
    """

    # Name of the backend attribute implementing the measure. The base class
    # leaves it empty; subclasses are expected to override it.
    _method_name: str = ""

    def compute(
        self, config: Settings, df: Sized, summary: dict, backend: CorrelationBackend
    ) -> Optional[Sized]:
        """Compute this correlation by delegating to the given backend.

        Args:
            config: Report settings, passed through to the backend method.
            df: The dataset, passed through to the backend method.
            summary: Summary dictionary, passed through to the backend method.
            backend: Object whose ``get_method`` resolves ``_method_name``.

        Returns:
            Whatever the resolved backend method returns.

        Raises:
            NotImplementedError: If the backend cannot resolve
                ``_method_name`` (``get_method`` raised ``AttributeError``).
        """
        # Only the lookup is guarded: an AttributeError raised while the
        # correlation itself runs must not be reported as "not implemented".
        try:
            method = backend.get_method(self._method_name)
        except AttributeError as ex:
            raise NotImplementedError(
                f"{type(self).__name__} correlation is not implemented by the "
                f"selected backend (missing '{self._method_name}')."
            ) from ex
        return method(config, df, summary)
2154
2255
class Auto(Correlation):
    """Correlation measure resolved to the backend's ``auto_compute``."""

    _method_name = "auto_compute"
2860
2961
class Spearman(Correlation):
    """Spearman correlation, resolved to the backend's ``spearman_compute``."""

    _method_name = "spearman_compute"
3564
3665
class Pearson(Correlation):
    """Pearson correlation, resolved to the backend's ``pearson_compute``."""

    _method_name = "pearson_compute"
4268
4369
class Kendall(Correlation):
    """Kendall correlation, resolved to the backend's ``kendall_compute``."""

    _method_name = "kendall_compute"
4972
5073
class Cramers(Correlation):
    """Cramers correlation, resolved to the backend's ``cramers_compute``."""

    _method_name = "cramers_compute"
5676
5777
class PhiK(Correlation):
    """PhiK correlation, resolved to the backend's ``phik_compute``."""

    _method_name = "phik_compute"
6380
6481
6582def warn_correlation (correlation_name : str , error : str ) -> None :
@@ -88,6 +105,8 @@ def calculate_correlation(
88105 Returns:
89106 The correlation matrices for the given correlation measures. Return None if correlation is empty.
90107 """
108+ backend = CorrelationBackend (df )
109+
91110 correlation_measures = {
92111 "auto" : Auto ,
93112 "pearson" : Pearson ,
@@ -99,16 +118,13 @@ def calculate_correlation(
99118
100119 correlation = None
101120 try :
102- correlation = correlation_measures [correlation_name ].compute (
103- config , df , summary
121+ correlation = correlation_measures [correlation_name ]() .compute (
122+ config , df , summary , backend
104123 )
105124 except (ValueError , AssertionError , TypeError , DataError , IndexError ) as e :
106125 warn_correlation (correlation_name , str (e ))
107126
108- if correlation is not None and len (correlation ) <= 0 :
109- correlation = None
110-
111- return correlation
127+ return correlation if correlation is not None and len (correlation ) > 0 else None
112128
113129
114130def perform_check_correlation (
0 commit comments