Skip to content

Commit b78b32e

Browse files
authored
Merge pull request #72 from fzi-forschungszentrum-informatik/Debug_SETS
Debug sets
2 parents 4a553dd + 2bfd101 commit b78b32e

File tree

10 files changed

+1411
-125
lines changed

10 files changed

+1411
-125
lines changed

TSInterpret/InterpretabilityModels/counterfactual/SETS/ContractedST.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -662,8 +662,8 @@ def transform(self, X, y=None):
662662
shapelet_distances.append(dist)
663663

664664
min_dist = min(min_dist, dist)
665-
666-
output[i][s] = dist
665+
#TODO THIS WAS CHANGED
666+
output[i][s] = min_dist
667667

668668
self.shapelets[s].distances[i] = np.asarray(shapelet_distances)
669669

TSInterpret/InterpretabilityModels/counterfactual/SETS/sets.py

+37-19
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,13 @@
1313
get_all_shapelet_locations_scaled_threshold,
1414
get_all_shapelet_locations_scaled_threshold_test,
1515
get_nearest_neighbor,
16-
get_shapelets_locations_test,
16+
get_shapelets_locations_test,get_shapelets_distances
1717
)
1818

1919

2020
# cast to tf format
2121
def to_tff(x):
22-
return np.expand_dims(np.swapaxes(x, 0, 1), axis=0)
23-
22+
return np.expand_dims(np.swapaxes(x, 0, 1), axis=0)
2423

2524
def fit_shapelets(
2625
data,
@@ -34,6 +33,7 @@ def fit_shapelets(
3433
random.seed(random_seed)
3534
X_train, y_train = data
3635

36+
3737
# make deep copy for reusability
3838
fitted_shapelets = copy.deepcopy(st_shapelets)
3939

@@ -51,6 +51,7 @@ def fit_shapelets(
5151
all_heat_maps = {}
5252

5353
for c in np.unique(y_train):
54+
#print(c)
5455
all_shapelets_class[c] = []
5556
all_heat_maps[c] = []
5657

@@ -156,12 +157,21 @@ def sets_explain(
156157
from_3d_numpy_to_nested(np.expand_dims(instance_x, axis=0))
157158
)
158159

160+
shapelet_dist=[]
161+
for st in transformer.sts:
162+
save=[]
163+
for shp in st.shapelets:
164+
save.append(shp.distances)
165+
shapelet_dist.append(save)
166+
167+
shapelets_distances_test=shapelet_dist
159168
all_shapelet_locations_test, _ = get_all_shapelet_locations_scaled_threshold_test(
160-
[np.expand_dims(shapelets_distances_test, axis=0)],
161-
instance_x.shape[1],
162-
threshhold,
169+
shapelets_distances_test,
170+
ts_length,
171+
threshhold
163172
)
164173

174+
165175
# Sort dimensions by their highest shapelet scores
166176
shapelets_best_scores = []
167177
for dim in range(len(st_shapelets)):
@@ -174,13 +184,15 @@ def sets_explain(
174184
# fit a KNN for each class
175185
for c in np.unique(y_train):
176186
knns[c] = KNeighborsTimeSeries(n_neighbors=1)
187+
if X_train.shape[1]!= ts_length:
188+
X_train=np.swapaxes(X_train,1,2)
177189
X_train_knn = X_train[np.argwhere(y_train == c)].reshape(
178-
np.argwhere(y_train == c).shape[0], X_train.shape[1], X_train.shape[2]
190+
np.argwhere(y_train == c).shape[0], ts_length,-1
179191
)
180-
X_train_knn = np.swapaxes(X_train_knn, 1, 2)
181192
knns[c].fit(X_train_knn)
182193

183-
orig_c = int(np.argmax(model.predict(to_tff(instance_x))))
194+
orig_c = int(np.argmax(model.predict(to_tff(instance_x)),axis=1)[0])
195+
184196
if len(target) > 1:
185197
target.remove(orig_c)
186198
for target_c in target:
@@ -202,7 +214,7 @@ def sets_explain(
202214
cf = instance_x.copy()
203215

204216
cf_pred = model.predict(to_tff(cf))
205-
cf_pred = np.argmax(cf_pred)
217+
cf_pred = np.argmax(cf_pred,axis=1)[0]
206218
if target_c != cf_pred:
207219
# Get the locations where the original class shapelets occur
208220
all_locs = get_shapelets_locations_test(
@@ -215,7 +227,7 @@ def sets_explain(
215227
for c_i in all_locs:
216228
for loc in all_locs.get(c_i):
217229
cf_pred = model.predict(to_tff(cf))
218-
cf_pred = np.argmax(cf_pred)
230+
cf_pred = np.argmax(cf_pred,axis=1)[0]
219231
if target_c != cf_pred:
220232
# print('Removing original shapelet')
221233
nn = X_train[nn_idx].reshape(-1)
@@ -238,15 +250,18 @@ def sets_explain(
238250

239251
start = loc[0]
240252
end = loc[1]
253+
#print('start', start)
254+
#print('end', end)
241255

242256
cf[dim][start:end] = target_shapelet
257+
assert np.any(instance_x !=cf ), f"Pertubed instance is identical to the original instance"
258+
243259

244260
# Introduce new shapelets from the target class
245261
for idx, target_shapelet_idx in enumerate(all_target_heat_maps.keys()):
246262
cf_pred = model.predict(to_tff(cf))
247-
cf_pred = np.argmax(cf_pred)
263+
cf_pred = np.argmax(cf_pred,axis=1)[0]
248264
if target_c != cf_pred:
249-
# print('Introducing new shapelet')
250265
h_m = all_target_heat_maps[target_shapelet_idx]
251266
center = (
252267
np.argwhere(h_m > 0)[-1][0] - np.argwhere(h_m > 0)[0][0]
@@ -283,10 +298,12 @@ def sets_explain(
283298

284299
cf[dim][start:end] = target_shapelet
285300

301+
assert np.any(instance_x !=cf), f"Pertubed instance is identical to the original instance"
302+
286303
# Save the perturbed dimension
287304
cf_dims[dim] = cf[dim]
288305
cf_pred = model.predict(to_tff(cf))
289-
cf_pred = np.argmax(cf_pred)
306+
cf_pred = np.argmax(cf_pred,axis=1)[0]
290307
if target_c == cf_pred:
291308
return cf, cf_pred
292309
elif target_c != cf_pred:
@@ -298,10 +315,11 @@ def sets_explain(
298315
for dim_ in subset:
299316
cf[dim_] = cf_dims[dim_]
300317
cf_pred = model.predict(to_tff(cf))
301-
cf_pred = np.argmax(cf_pred)
318+
cf_pred = np.argmax(cf_pred,axis=1)[0]
302319
if target_c == cf_pred:
303320
break
304-
if target_c == cf_pred:
305-
return cf, cf_pred
306-
else:
307-
return None, None
321+
322+
#if orig_c != cf_pred:
323+
return cf, cf_pred
324+
#else:
325+
# return None, None

TSInterpret/InterpretabilityModels/counterfactual/SETS/utils.py

+28-11
Original file line numberDiff line numberDiff line change
@@ -123,19 +123,24 @@ def remove_similar_locations(shapelet_locations, shapelet_distances):
123123

124124
# Given the shapelet_distances matrix of a given shapelet, get the locations of
125125
# the closest shapelets from the entire dataset
126-
def get_shapelet_locations_scaled_threshold(shapelet_distances, ts_length, threshold):
126+
def get_shapelet_locations_scaled_threshold(shapelet_distances, ts_length, threshold, shapelets=None):
127127
# Compute the length of the shapelet
128128
shapelet_length = ts_length - shapelet_distances.shape[1] + 1
129129

130130
# Get the indices of the n closest shapelets to the original shapelet
131131
s_indices = []
132132
for i in range(shapelet_distances.shape[0]):
133133
for j in range(shapelet_distances.shape[1]):
134+
# i Iterates Items
135+
# j iterates Shapelets
134136
# Compare to the threshold, scaled to shapelet length
137+
#shapelet_length = ts_length - len(shapelet_distances[j]) + 1
135138
if shapelet_distances[i][j] / shapelet_length <= threshold:
139+
#j is the number of the shapelet
136140
s_indices.append(np.array([i, j]))
137141

138142
if len(s_indices) > 0:
143+
# Relevant shaplet indicies
139144
s_indices = np.asarray(s_indices)
140145

141146
# Create an array to store the locations of the closest n shapelets
@@ -147,7 +152,6 @@ def get_shapelet_locations_scaled_threshold(shapelet_distances, ts_length, thres
147152
shapelet_locations[i] = np.append(
148153
s_indices[i], s_indices[i][1] + shapelet_length
149154
)
150-
151155
# Remove overlapping shapelets and keep the closest one to th original shapelet
152156
shapelet_locations = remove_similar_locations(
153157
shapelet_locations, shapelet_distances
@@ -176,9 +180,9 @@ def get_occurences_threshold(shapelets_distances, ts_length, percentage):
176180
# Sort the distances ascendingly
177181
sds.sort()
178182

183+
179184
# Number of shapelet occurences to keep (per shapelet)
180185
n = int(percentage * len(sds))
181-
182186
# Return the threshold distance to select the shapelet occurences to keep
183187
return sds[n]
184188

@@ -188,6 +192,7 @@ def get_occurences_threshold(shapelets_distances, ts_length, percentage):
188192
def get_all_shapelet_locations_scaled_threshold(
189193
shapelets_distances, ts_length, percentage
190194
):
195+
191196
# Get the threshold to be used for selecting shapelet occurences
192197
threshold = get_occurences_threshold(shapelets_distances, ts_length, percentage)
193198

@@ -213,19 +218,28 @@ def get_all_shapelet_locations_scaled_threshold(
213218
# Get the locations of the closest shapelets for each timeseries across the
214219
# entire dataset based on the training threshold
215220
def get_all_shapelet_locations_scaled_threshold_test(
216-
shapelets_distances, ts_length, threshold
221+
shapelets_distances, ts_length, threshold,shapelets =None
217222
):
223+
224+
threshold=5
218225
all_shapelet_locations = []
219226
all_no_occurences = []
220227

221228
for dim in shapelets_distances:
229+
# Itreate DIMs
222230
dim_shapelet_locations = []
223231
no_occurences = []
232+
if type(dim) == int:
233+
dim= shapelets_distances[0]
224234
for i, shapelet in enumerate(dim):
235+
236+
# Iterate the shapelet [0. Num Shapelts]?
237+
# Get the shapelet Locations
225238
sls = get_shapelet_locations_scaled_threshold(
226-
shapelet, ts_length, threshold
239+
shapelet, ts_length, threshold,shapelets
227240
)
228241
if sls[0][0] != 4294967295:
242+
#print('Append',sls)
229243
dim_shapelet_locations.append(sls)
230244
else:
231245
no_occurences.append(i)
@@ -236,31 +250,34 @@ def get_all_shapelet_locations_scaled_threshold_test(
236250

237251

238252
def get_shapelets_locations_test(idx, all_sls, dim, all_shapelets_class):
253+
if len(np.array(all_shapelets_class).shape):
254+
all_shapelets_class=[all_shapelets_class]
239255
all_locs = {}
240-
try:
256+
257+
if True:
241258
for i, s in enumerate([all_sls[dim][j] for j in all_shapelets_class[dim]]):
259+
242260
i_locs = []
243261
for loc in s:
244-
if loc[0] == idx:
262+
if True:
263+
# TODO not necessary?
264+
#if loc[0] == idx:
245265
loc = (loc[1], loc[2])
246266
i_locs.append(loc)
247267
all_locs[i] = i_locs
248-
except Exception as ex:
249-
pass
268+
250269
return all_locs
251270

252271

253272
##Optimize by fitting outside or returning a list of all nns at once
254273
## Reworked so that only training data is available.
255274
def get_nearest_neighbor(knn, instance_x, pred_label, x_train, y_train):
256-
# pred_label = y_pred[idx]
257275
target_labels = np.argwhere(y_train != pred_label)
258276

259277
X_train_knn = instance_x.reshape(1, instance_x.shape[0], instance_x.shape[1])
260278
X_train_knn = np.swapaxes(X_train_knn, 1, 2)
261279

262280
_, nn = knn.kneighbors(X_train_knn)
263-
# print("TARGETLABELS", [t[0] for t in target_labels], [int(nn[0][0])])
264281
nn_idx = None
265282
try:
266283
nn_idx = [t[0] for t in target_labels][int(nn[0][0])]

TSInterpret/InterpretabilityModels/counterfactual/SETSCF.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(
5252
remove_self_similar=True,
5353
silent=False,
5454
fit_shapelets=True,
55+
le=False
5556
) -> None:
5657
"""
5758
Arguments:
@@ -81,20 +82,24 @@ def __init__(
8182
train_x, train_y = data
8283
self.le = LabelEncoder()
8384
self.train_y = self.le.fit_transform(train_y)
85+
self.mode=mode
8486
if mode == "time":
8587
# Parse test data into (1, feat, time):
86-
change = True
88+
change = False
8789
self.train_x = np.swapaxes(train_x, 2, 1)
8890
self.ts_len = train_x.shape[1]
8991
elif mode == "feat":
90-
change = False
91-
self.train_x = np.array(train_x)
92+
change = True
93+
self.train_x = train_x
9294
self.ts_len = train_x.shape[2]
95+
#self.train_x = np.swapaxes(train_x, 2, 1)
96+
print(self.train_x.shape)
9397
self.train_x_n = from_3d_numpy_to_nested(self.train_x)
98+
print(self.train_x_n.shape)
9499
if backend == "PYT":
95-
self.predict = PyTorchModel(model, change).predict
100+
self.predict = PyTorchModel(model, change)
96101
elif backend == "TF":
97-
self.predict = TensorFlowModel(model, change).predict
102+
self.predict = TensorFlowModel(model, change)
98103
elif backend == "SK":
99104
self.predict = SklearnModel(model, change).predict
100105
# Fit Shapelet Transform
@@ -183,13 +188,20 @@ def explain(
183188
target = list(np.unique(self.train_y))
184189
else:
185190
target = [target]
191+
if self.mode == 'time':
192+
x= np.swapaxes(x, -1, -2)
193+
194+
195+
#else:
196+
# x=np.swapaxes(x,-1,-2)
197+
186198

187199
expl, label = sets_explain(
188200
x,
189201
target,
190202
(self.train_x, self.train_y),
191203
self.st_transformer,
192-
self.model,
204+
self.predict,
193205
self.ts_len,
194206
self.fitted_shapelets,
195207
self.threshhold,

TSInterpret/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
VERSION = (0, 4, 6)
1+
VERSION = (0, 4, 7)
22
__version__ = ".".join(map(str, VERSION)) # noqa: F401

0 commit comments

Comments
 (0)