1+ #!/usr/bin/env python3
2+ # -*- coding: utf-8 -*-
3+ """
4+ Created on Tue Feb 8 21:27:34 2022
5+
6+ @author: lingxiaoli
7+ """
8+
9+ import numpy as np
10+ from scipy .stats import chi2_contingency , ks_2samp , entropy
11+ from scipy .special import softmax
12+
13+
class ChiSquareDrift:
    """Chi-square drift detector based on low-margin prediction counts.

    Each row of scores is passed through softmax over the last axis. A row is
    encoded as ``1`` when the gap between its two largest probabilities is
    below ``margin_width`` and as ``0`` otherwise. Drift is assessed by
    comparing the 0/1 counts of the reference data with those of a new batch
    using ``scipy.stats.chi2_contingency``.
    """

    def __init__(self, x_ref, threshold: float = .05, margin_width: float = 0.1):
        """Store the encoded reference data and the decision threshold.

        Args:
            x_ref: 2-D array of scores; rows are samples and softmax is taken
                over the last axis (at least two columns are required).
            threshold: p-value below which drift is reported.
            margin_width: top-1/top-2 probability gap below which a sample is
                encoded as ``1``.
        """
        # Must be set before process_data, which reads it.
        self.margin_width = margin_width
        self.x_ref = self.process_data(x_ref)
        self.threshold = threshold

    def update_ref(self, x_ref):
        """Replace the reference data with the encoding of ``x_ref``."""
        self.x_ref = self.process_data(x_ref)

    # Original (misspelled) name, kept so existing callers continue to work.
    updata_ref = update_ref

    def update_threshold(self, threshold):
        """Replace the p-value threshold used by ``get_result``."""
        self.threshold = threshold

    def process_data(self, x):
        """Encode each row of ``x`` as a 0/1 low-margin indicator.

        Returns:
            Integer array of shape ``(n_rows, 1)``; ``1`` where the difference
            between the two largest softmax probabilities is below
            ``self.margin_width``.
        """
        probs = softmax(x, axis=-1)
        # Partitioning the negated values around kth=1 puts the two largest
        # probabilities, in descending order, in the first two columns.
        top_2_probs = -np.partition(-probs, kth=1, axis=-1)[:, :2]
        diff = top_2_probs[:, 0] - top_2_probs[:, 1]
        return (diff < self.margin_width).astype(int)[:, None]

    def feature_score_Chi(self, x):
        """Run the chi-square test between the reference data and ``x``.

        Returns:
            Tuple ``(p_val, dist)`` with the p-value and the test statistic.
            When a category is absent from both samples there is nothing to
            compare, and ``(1.0, 0.0)`` is returned.

        Raises:
            ValueError: propagated from ``chi2_contingency``, e.g. when one of
                the two samples contains no rows.
        """
        x = self.process_data(x)
        vals = [0, 1]
        table = np.vstack((
            self.get_counts(self.x_ref, vals),
            self.get_counts(x, vals),
        ))
        # chi2_contingency rejects tables whose expected frequencies contain
        # a zero, which is the case for a category missing from both samples.
        table = table[:, table.sum(axis=0) > 0]
        if table.shape[1] < 2:
            # Both samples fall entirely in the same category: no evidence
            # of drift. NumPy scalars keep the return type consistent.
            return np.float64(1.0), np.float64(0.0)
        dist, p_val, _, _ = chi2_contingency(table)
        return p_val, dist

    def get_counts(self, x, vals):
        """Return, for each value in ``vals``, how many entries of ``x`` equal it."""
        return [(x == v).sum() for v in vals]

    def get_result(self, x):
        """Test ``x`` for drift against the reference data.

        Returns:
            Dict with keys ``'is_drift'`` (1 if the p-value is below the
            threshold, else 0), ``'p_val'``, ``'threshold'`` and
            ``'distance'`` (the chi-square statistic).
        """
        p_vals, dist = self.feature_score_Chi(x)
        threshold = self.threshold
        return {
            'is_drift': int(np.any(p_vals < threshold)),
            'p_val': p_vals,
            'threshold': threshold,
            'distance': dist,
        }
58+
class KSDrift:
    """Kolmogorov-Smirnov drift detector on prediction entropy.

    Scores are passed through softmax over the last axis and reduced to one
    entropy value per row. The entropy values of the reference data are
    compared with those of a new batch using ``scipy.stats.ks_2samp``.
    """

    def __init__(self, x_ref, threshold: float = .05):
        """Store the reference entropies and the decision threshold.

        Args:
            x_ref: array of scores; softmax and entropy are taken over the
                last axis.
            threshold: p-value below which drift is reported.
        """
        self.x_ref = self._prediction_entropy(x_ref)
        self.threshold = threshold

    @staticmethod
    def _prediction_entropy(x):
        """Return the entropy of ``softmax(x)`` along the last axis."""
        return entropy(softmax(x, axis=-1), axis=-1)

    def update_ref(self, x_ref):
        """Replace the reference data with the entropies of ``x_ref``."""
        self.x_ref = self._prediction_entropy(x_ref)

    # Original (misspelled) name, kept so existing callers continue to work.
    updata_ref = update_ref

    def update_threshold(self, threshold):
        """Replace the p-value threshold used by ``get_result``."""
        self.threshold = threshold

    def feature_score_KS(self, x):
        """Run the two-sided, asymptotic two-sample KS test against ``x``.

        Returns:
            Tuple ``(p_val, dist)`` with the p-value and the KS statistic.
        """
        dist, p_val = ks_2samp(
            self.x_ref,
            self._prediction_entropy(x),
            alternative='two-sided',
            mode='asymp',
        )
        return p_val, dist

    def get_result(self, x):
        """Test ``x`` for drift against the reference data.

        Returns:
            Dict with keys ``'is_drift'`` (1 if the p-value is below the
            threshold, else 0), ``'p_val'``, ``'threshold'`` and
            ``'distance'`` (the KS statistic).
        """
        p_vals, dist = self.feature_score_KS(x)
        threshold = self.threshold
        return {
            'is_drift': int(np.any(p_vals < threshold)),
            'p_val': p_vals,
            'threshold': threshold,
            'distance': dist,
        }
88+
def drift_detection(x_ref, threshold: float = .05, method='KSDrift'):
    """Build a drift detector of the requested kind.

    Args:
        x_ref: reference data handed to the detector's constructor.
        threshold: p-value threshold handed to the detector's constructor.
        method: ``'KSDrift'`` or ``'ChiSquareDrift'``.

    Returns:
        A ``KSDrift`` or ``ChiSquareDrift`` instance.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    if method == 'KSDrift':
        return KSDrift(x_ref, threshold)
    if method == "ChiSquareDrift":
        return ChiSquareDrift(x_ref, threshold)
    # Fail here rather than return None and let the caller hit an
    # AttributeError on first use of the detector.
    raise ValueError(
        f"Unknown drift detection method {method!r}; "
        "expected 'KSDrift' or 'ChiSquareDrift'."
    )
0 commit comments