|
| 1 | +''' |
| 2 | +Install Grad CAM : `!pip install tf-explain` |
| 3 | +* src : https://github.com/sicara/tf-explain |
| 4 | +* paper : Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization |
| 5 | +* Reference : https://arxiv.org/abs/1610.02391 |
| 6 | +* Abstract : We propose a technique for producing "visual explanations" for decisions from a large class |
| 7 | + of CNN-based models, making them more transparent. Our approach - Gradient-weighted Class Activation Mapping |
| 8 | + (Grad-CAM), uses the gradients of any target concept, flowing into the final convolutional layer to produce |
| 9 | + a coarse localization map highlighting important regions in the image for predicting the concept. Grad-CAM |
| 10 | + is applicable to a wide variety of CNN model-families: |
| 11 | + (1) CNNs with fully-connected layers, |
| 12 | + (2) CNNs used for structured outputs, |
| 13 | + (3) CNNs used in tasks with multimodal inputs or reinforcement learning, |
| 14 | + without any architectural changes or re-training. We combine Grad-CAM with fine-grained visualizations to create |
| 15 | + a high-resolution class-discriminative visualization and apply it to off-the-shelf image classification, captioning, |
| 16 | + and visual question answering (VQA) models, including ResNet-based architectures. In the context of image classification |
| 17 | + models, our visualizations (a) lend insights into their failure modes, |
| 18 | + (b) are robust to adversarial images, (c) outperform previous methods on localization, (d) are more faithful to the |
| 19 | + underlying model and (e) help achieve generalization by identifying dataset bias. For captioning and VQA, we show that even |
| 20 | + non-attention based models can localize inputs. We devise a way to identify important neurons through Grad-CAM and combine it |
| 21 | + with neuron names to provide textual explanations for model decisions. Finally, we design and conduct human studies to measure |
| 22 | + if Grad-CAM helps users establish appropriate trust in predictions from models and show that Grad-CAM helps untrained users |
| 23 | + successfully discern a 'stronger' nodel from a 'weaker' one even when both make identical predictions. |
| 24 | +
|
| 25 | +##### Note : you can pass `model` object as any tensorflow keras model. |
| 26 | +''' |
| 27 | +import tensorflow as tf |
| 28 | +import os |
| 29 | +import cv2 |
| 30 | +from tensorflow.keras.preprocessing.image import * |
| 31 | +from tensorflow.keras.applications.imagenet_utils import preprocess_input |
| 32 | +import matplotlib.pyplot as plt |
| 33 | +from tf_explain.core.grad_cam import GradCAM |
| 34 | +import numpy as np |
| 35 | + |
| 36 | + |
| 37 | +class_mapping = {0: 'Cat', 1: 'Dog'} |
| 38 | + |
| 39 | + |
| 40 | +def preprocessing_image(instancePath): |
| 41 | + original_image = plt.imread(instancePath) |
| 42 | + image = load_img(instancePath, target_size=(224, 224)) |
| 43 | + image = img_to_array(image) |
| 44 | + image = tf.expand_dims(image, 0) |
| 45 | + image /= 255.0 |
| 46 | + image = preprocess_input(image) |
| 47 | + return image, original_image |
| 48 | + |
| 49 | + |
| 50 | +def predict_per(IMAGE_PATH): |
| 51 | + image, o_image = preprocessing_image(IMAGE_PATH) |
| 52 | + prediction = np.argmax(model.predict(image)) |
| 53 | + prediction_per = np.max(model.predict(image)) |
| 54 | + return class_mapping[prediction], prediction, prediction_per |
| 55 | + |
| 56 | + |
| 57 | +def Grad_cam_vis(IMAGE_PATH, OUTPUT_PATH, ACTUAL_LABEL: str): |
| 58 | + prediction, index, confidense = predict_per(IMAGE_PATH) |
| 59 | + img = tf.keras.preprocessing.image.load_img(IMAGE_PATH, target_size=(224, 224)) |
| 60 | + img = tf.keras.preprocessing.image.img_to_array(img) |
| 61 | + data = ([img], None) |
| 62 | + |
| 63 | + # Start explainer |
| 64 | + explainer = GradCAM() |
| 65 | + grid = explainer.explain(data, model, class_index=index) |
| 66 | + |
| 67 | + explainer.save(grid, ".", OUTPUT_PATH) |
| 68 | + |
| 69 | + im = cv2.cvtColor(cv2.imread(IMAGE_PATH), cv2.COLOR_BGR2RGB) |
| 70 | + im1 = cv2.cvtColor(cv2.imread(OUTPUT_PATH), cv2.COLOR_BGR2RGB) |
| 71 | + plt.figure(figsize=(15, 8)) |
| 72 | + plt.subplot(1, 2, 1) |
| 73 | + plt.imshow(im) |
| 74 | + plt.xlabel(f"Actaul: Healthy") |
| 75 | + plt.subplot(1, 2, 2) |
| 76 | + plt.imshow(im1) |
| 77 | + plt.xlabel(f"predict:{prediction}\nConfidence: {confidense}") |
| 78 | + plt.show() |
| 79 | + |
| 80 | + |
| 81 | +images = os.listdir('/content/Images/') |
| 82 | +for img in images: |
| 83 | + # print(img) |
| 84 | + os.mkdirs("Output", exist_ok = True) |
| 85 | + maping_result = {'c':'Cat', 'd':'Dog'} |
| 86 | + actual_label = maping_result[img[0].lower()] |
| 87 | + IMAGE_PATH = f"/content/Images/{img}" |
| 88 | + OUTPUT_PATH = f"/content/Output/output_{img[:-4]}.jpg" |
| 89 | + pred_class, prediction, prediction_per = predict_per(IMAGE_PATH) |
| 90 | + Grad_cam_vis(IMAGE_PATH, OUTPUT_PATH, actual_label) |
0 commit comments