|
| 1 | +import os |
| 2 | +from PIL import Image |
| 3 | +import numpy as np |
| 4 | +import random |
| 5 | +import pickle |
| 6 | + |
| 7 | +IMAGENET_PATH = "/MEng/Data/ILSVRC2012_img_val/" |
| 8 | +MEAN = [0.485, 0.456, 0.406] |
| 9 | +STD = [0.229, 0.224, .225] |
| 10 | + |
| 11 | +CALIB_BASE_PATH=os.getenv("CALIB_BASE_PATH") |
| 12 | +if CALIB_BASE_PATH is None: |
| 13 | + raise ValueError("Environment variable CALIB_BASE_PATH not set") |
| 14 | + |
| 15 | +CALIB_MODEL_SPLIT=os.getenv("CALIB_MODEL_SPLIT") |
| 16 | +if CALIB_MODEL_SPLIT is None: |
| 17 | + raise ValueError("Environment variable CALIB_MODEL_SPLIT not set") |
| 18 | + |
| 19 | +quantize_info_path = os.path.join(CALIB_BASE_PATH, f"model_tf_split_{CALIB_MODEL_SPLIT}/quantize_info.txt") |
| 20 | +input_info_path = os.path.join(CALIB_BASE_PATH, f"model_tf_split_{CALIB_MODEL_SPLIT}/inputs.pickle") |
| 21 | + |
| 22 | +input_shapes = {} |
| 23 | +with open(quantize_info_path) as f: |
| 24 | + lines = f.readlines() |
| 25 | + raw_input_names = [] |
| 26 | + raw_input_shapes = [] |
| 27 | + for i in range(len(lines)): |
| 28 | + if "--input_nodes" in lines[i]: |
| 29 | + raw_input_names = lines[i+1].rstrip() |
| 30 | + if "--input_shapes" in lines[i]: |
| 31 | + raw_input_shapes = lines[i+1].rstrip() |
| 32 | + |
| 33 | + raw_input_names = raw_input_names.split(",") |
| 34 | + raw_input_shapes = raw_input_shapes.split(":") |
| 35 | + raw_input_shapes = [[int(x) for x in shape.split(',')] for shape in raw_input_shapes] |
| 36 | + input_shapes = dict(zip(raw_input_names, raw_input_shapes)) |
| 37 | + |
| 38 | + |
| 39 | +input_data = {} |
| 40 | +# shift_concat, resid |
| 41 | +with open(input_info_path, 'rb') as f: |
| 42 | + input_data = pickle.load(f) |
| 43 | + |
| 44 | +def input_fn(iter): |
| 45 | + #files = sorted(os.listdir(IMAGENET_PATH)) |
| 46 | + #img = Image.open(os.path.join(IMAGENET_PATH,files[iter])).resize((224, 224)) |
| 47 | + #img = np.array(img) / 255.0 |
| 48 | + ##img = (img - MEAN) / STD |
| 49 | + #img = np.transpose(img, axes=[2, 0, 1]) |
| 50 | + #img = np.expand_dims(img, axis=0) |
| 51 | + #return {"input_node": img} |
| 52 | + inputs = {} |
| 53 | + for name,shape in input_shapes.items(): |
| 54 | + if "/input" in name: |
| 55 | + inputs[name] = np.array(input_data[iter]["resid"]) |
| 56 | + #inputs[name] = np.array(input_data["0"]["resid"]) |
| 57 | + else: |
| 58 | + inputs[name] = np.array(input_data[iter]["shift_concat"]) |
| 59 | + #inputs[name] = np.array(input_data["0"]["shift_concat"]) |
| 60 | + |
| 61 | + #inputs = {name: np.random.rand(*shape) for name,shape in input_shapes.items()} |
| 62 | + |
| 63 | + return inputs |
0 commit comments