#!../tb-venv/bin/python3
"""Classify every captured JPEG image with a TensorFlow Lite model.

Resources:
    * https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/raspberry_pi
"""
import argparse
import glob
import os
import time

import numpy as np
from PIL import Image
import tflite_runtime.interpreter as tflite


def set_input_tensor(interpreter, image):
    """Copy ``image`` into the first input tensor of ``interpreter``.

    Args:
        interpreter: Interpreter whose tensors have already been allocated.
        image: Array-like image data; it must be assignable (broadcastable)
            to the first batch element of the input tensor.
    """
    tensor_index = interpreter.get_input_details()[0]['index']
    # ``tensor()`` returns a callable giving a writable view of the buffer;
    # ``[0]`` selects the first (only) element of the batch dimension.
    input_tensor = interpreter.tensor(tensor_index)()[0]
    input_tensor[:, :] = image


def classify_image(interpreter, image, top_k=1):
    """Run inference on ``image`` and return the best ``top_k`` results.

    Args:
        interpreter: Interpreter whose tensors have already been allocated.
        image: Image data accepted by ``set_input_tensor``.
        top_k: Maximum number of results to return. Values larger than the
            number of model outputs are clamped; values <= 0 give ``[]``.

    Returns:
        A list of ``(class_index, score)`` tuples ordered by decreasing score.
    """
    set_input_tensor(interpreter, image)
    interpreter.invoke()

    output_details = interpreter.get_output_details()[0]
    # ``atleast_1d`` keeps a single-value output indexable after the squeeze.
    output = np.atleast_1d(
        np.squeeze(interpreter.get_tensor(output_details['index'])))

    # If the model is quantized (uint8 data), then dequantize the results.
    if output_details['dtype'] == np.uint8:
        scale, zero_point = output_details['quantization']
        # Cast before subtracting: unsigned arithmetic would wrap around for
        # raw values smaller than ``zero_point``.
        output = scale * (output.astype(np.float32) - zero_point)

    top_k = max(0, min(top_k, output.size))
    # A full sort (instead of ``argpartition``) guarantees the documented
    # ordering and stays valid when ``top_k`` equals the number of outputs.
    # Reversing an ascending sort avoids negating possibly unsigned data.
    ordered = np.argsort(output, kind='stable')[::-1]
    return [(i, output[i]) for i in ordered[:top_k]]


def main():
    """Classify each ``*.jpg`` capture and print the prediction and timing."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--model', help='File path of .tflite file.', required=True)
    parser.add_argument(
        '--captures-dir', default="/home/pi/captures",
        help='Directory containing the .jpg images to classify.')
    # parser.add_argument(
    #     '--labels', help='File path of labels file.', required=True)
    args = parser.parse_args()

    interpreter = tflite.Interpreter(args.model)
    interpreter.allocate_tensors()
    _, height, width, _ = interpreter.get_input_details()[0]['shape']

    # Sorted so that the output order is reproducible between runs.
    image_paths = sorted(glob.glob(os.path.join(args.captures_dir, "*.jpg")))

    for image_path in image_paths:
        # The context manager closes the underlying file handle; ``convert``
        # returns an independent in-memory image. Forcing RGB lets greyscale
        # or CMYK JPEGs fit a three-channel input tensor.
        with Image.open(image_path) as im:
            im = im.convert('RGB').resize((width, height))

        start_time = time.perf_counter()
        results = classify_image(interpreter, im)
        elapsed_time = (time.perf_counter() - start_time) * 1000
        label_id, prob = results[0]

        print(image_path)
        print('-' * 20)
        print(f"Prediction : {label_id} ({prob}) -- computed in {elapsed_time}ms.")


if __name__ == "__main__":
    main()