Skip to content

Commit f9b81bf

Browse files
authored
Add files via upload
1 parent 3afea9f commit f9b81bf

1 file changed

Lines changed: 93 additions & 0 deletions

File tree

drift_detection.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Created on Tue Feb 8 21:27:34 2022
5+
6+
@author: lingxiaoli
7+
"""
8+
9+
import numpy as np
10+
from scipy.stats import chi2_contingency, ks_2samp, entropy
11+
from scipy.special import softmax
12+
13+
14+
class ChiSquareDrift:
15+
16+
def __init__(self, x_ref, threshold: float = .05):
17+
self.x_ref = self.process_data(x_ref)
18+
self.threshold = threshold
19+
20+
def updata_ref(self, x_ref):
21+
self.x_ref = self.process_data(x_ref)
22+
23+
def update_threshold(self, threshold):
24+
self.threshold = threshold
25+
26+
def process_data(self, x):
27+
margin_width = 0.1
28+
temp = softmax(x, axis=-1)
29+
top_2_probs = -np.partition(-temp, kth=1, axis=-1)[:, :2]
30+
diff = top_2_probs[:, 0] - top_2_probs[:, 1]
31+
x_logist = (diff < margin_width).astype(int)
32+
return x_logist[:, None]
33+
34+
def feature_score_Chi(self, x):
35+
x = self.process_data(x)
36+
vals = [0, 1]
37+
x_ref_count = self.get_counts(self.x_ref, vals)
38+
x_count = self.get_counts(x, vals)
39+
p_val = np.zeros(1, dtype=np.float32)
40+
dist = np.zeros_like(p_val)
41+
contingency_table = np.vstack((x_ref_count, x_count))
42+
dist, p_val, _, _ = chi2_contingency(contingency_table)
43+
return p_val, dist
44+
45+
def get_counts(self, x, vals):
46+
return [(x[:] == v).sum() for v in vals]
47+
48+
def get_result(self, x):
49+
p_vals, dist = self.feature_score_Chi(x)
50+
threshold = self.threshold
51+
drift_pred = int((p_vals < threshold).any()) # type: ignore[assignment]
52+
cd = {}
53+
cd['is_drift'] = drift_pred
54+
cd['p_val'] = p_vals
55+
cd['threshold'] = threshold
56+
cd['distance'] = dist
57+
return cd
58+
59+
class KSDrift:
60+
61+
def __init__(self, x_ref, threshold: float = .05):
62+
self.x_ref = entropy(softmax(x_ref, axis=-1), axis=-1)
63+
self.threshold = threshold
64+
65+
def updata_ref(self, x_ref):
66+
self.x_ref = entropy(softmax(x_ref, axis=-1), axis=-1)
67+
68+
def update_threshold(self, threshold):
69+
self.threshold = threshold
70+
71+
def feature_score_KS(self, x):
72+
x = entropy(softmax(x, axis=-1), axis=-1)
73+
p_val = np.zeros(1, dtype=np.float32)
74+
dist = np.zeros_like(p_val)
75+
dist, p_val = ks_2samp(self.x_ref, x, alternative='two-sided', mode='asymp')
76+
return p_val, dist
77+
78+
def get_result(self, x):
79+
p_vals, dist = self.feature_score_KS(x)
80+
threshold = self.threshold
81+
drift_pred = int((p_vals < threshold).any()) # type: ignore[assignment]
82+
cd = {}
83+
cd['is_drift'] = drift_pred
84+
cd['p_val'] = p_vals
85+
cd['threshold'] = threshold
86+
cd['distance'] = dist
87+
return cd
88+
89+
def drift_detection(x_ref, threshold: float = .05, method='KSDrift'):
90+
if method == 'KSDrift':
91+
return KSDrift(x_ref, threshold)
92+
elif method == "ChiSquareDrift":
93+
return ChiSquareDrift(x_ref, threshold)

0 commit comments

Comments
 (0)