Skip to content

Commit 36ee01b

Browse files
Create GRADCAM.py
1 parent 1cf00d8 commit 36ee01b

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed

GRADCAM.py

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
'''
2+
Install Grad CAM : `!pip install tf-explain`
3+
* src : https://github.com/sicara/tf-explain
4+
* paper : Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
5+
* Reference : https://arxiv.org/abs/1610.02391
6+
* Abstract : We propose a technique for producing "visual explanations" for decisions from a large class
7+
of CNN-based models, making them more transparent. Our approach - Gradient-weighted Class Activation Mapping
8+
(Grad-CAM), uses the gradients of any target concept, flowing into the final convolutional layer to produce
9+
a coarse localization map highlighting important regions in the image for predicting the concept. Grad-CAM
10+
is applicable to a wide variety of CNN model-families:
11+
(1) CNNs with fully-connected layers,
12+
(2) CNNs used for structured outputs,
13+
(3) CNNs used in tasks with multimodal inputs or reinforcement learning,
14+
without any architectural changes or re-training. We combine Grad-CAM with fine-grained visualizations to create
15+
a high-resolution class-discriminative visualization and apply it to off-the-shelf image classification, captioning,
16+
and visual question answering (VQA) models, including ResNet-based architectures. In the context of image classification
17+
models, our visualizations (a) lend insights into their failure modes,
18+
(b) are robust to adversarial images, (c) outperform previous methods on localization, (d) are more faithful to the
19+
underlying model and (e) help achieve generalization by identifying dataset bias. For captioning and VQA, we show that even
20+
non-attention based models can localize inputs. We devise a way to identify important neurons through Grad-CAM and combine it
21+
with neuron names to provide textual explanations for model decisions. Finally, we design and conduct human studies to measure
22+
if Grad-CAM helps users establish appropriate trust in predictions from models and show that Grad-CAM helps untrained users
23+
successfully discern a 'stronger' nodel from a 'weaker' one even when both make identical predictions.
24+
25+
##### Note : you can pass `model` object as any tensorflow keras model.
26+
'''
27+
import tensorflow as tf
28+
import os
29+
import cv2
30+
from tensorflow.keras.preprocessing.image import *
31+
from tensorflow.keras.applications.imagenet_utils import preprocess_input
32+
import matplotlib.pyplot as plt
33+
from tf_explain.core.grad_cam import GradCAM
34+
import numpy as np
35+
36+
37+
class_mapping = {0: 'Cat', 1: 'Dog'}
38+
39+
40+
def preprocessing_image(instancePath):
41+
original_image = plt.imread(instancePath)
42+
image = load_img(instancePath, target_size=(224, 224))
43+
image = img_to_array(image)
44+
image = tf.expand_dims(image, 0)
45+
image /= 255.0
46+
image = preprocess_input(image)
47+
return image, original_image
48+
49+
50+
def predict_per(IMAGE_PATH):
51+
image, o_image = preprocessing_image(IMAGE_PATH)
52+
prediction = np.argmax(model.predict(image))
53+
prediction_per = np.max(model.predict(image))
54+
return class_mapping[prediction], prediction, prediction_per
55+
56+
57+
def Grad_cam_vis(IMAGE_PATH, OUTPUT_PATH, ACTUAL_LABEL: str):
58+
prediction, index, confidense = predict_per(IMAGE_PATH)
59+
img = tf.keras.preprocessing.image.load_img(IMAGE_PATH, target_size=(224, 224))
60+
img = tf.keras.preprocessing.image.img_to_array(img)
61+
data = ([img], None)
62+
63+
# Start explainer
64+
explainer = GradCAM()
65+
grid = explainer.explain(data, model, class_index=index)
66+
67+
explainer.save(grid, ".", OUTPUT_PATH)
68+
69+
im = cv2.cvtColor(cv2.imread(IMAGE_PATH), cv2.COLOR_BGR2RGB)
70+
im1 = cv2.cvtColor(cv2.imread(OUTPUT_PATH), cv2.COLOR_BGR2RGB)
71+
plt.figure(figsize=(15, 8))
72+
plt.subplot(1, 2, 1)
73+
plt.imshow(im)
74+
plt.xlabel(f"Actaul: Healthy")
75+
plt.subplot(1, 2, 2)
76+
plt.imshow(im1)
77+
plt.xlabel(f"predict:{prediction}\nConfidence: {confidense}")
78+
plt.show()
79+
80+
81+
images = os.listdir('/content/Images/')
82+
for img in images:
83+
# print(img)
84+
os.mkdirs("Output", exist_ok = True)
85+
maping_result = {'c':'Cat', 'd':'Dog'}
86+
actual_label = maping_result[img[0].lower()]
87+
IMAGE_PATH = f"/content/Images/{img}"
88+
OUTPUT_PATH = f"/content/Output/output_{img[:-4]}.jpg"
89+
pred_class, prediction, prediction_per = predict_per(IMAGE_PATH)
90+
Grad_cam_vis(IMAGE_PATH, OUTPUT_PATH, actual_label)

0 commit comments

Comments
 (0)