104 lines
3.7 KiB
Python
Executable File
104 lines
3.7 KiB
Python
Executable File
import tensorflow as tf
|
|
|
|
# TensorFlow session options: allow_growth asks TF to start with a small GPU
# memory allocation and grow it on demand instead of reserving the whole GPU.
# Bound to `tf_config` (not `config`) because `config` is reassigned further
# down to the Mask R-CNN InferenceConfig instance.
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True

# NOTE(review): this session is never handed to Keras explicitly (e.g. via
# keras.backend.set_session); confirm the allow_growth option actually applies
# to the session Mask R-CNN runs its graph in.
sess = tf.Session(config=tf_config)

import os
|
|
import sys
|
|
import random
|
|
import math
|
|
import numpy as np
|
|
import skimage.io
|
|
import matplotlib
|
|
import matplotlib.pyplot as plt
|
|
from datetime import datetime
|
|
from numpy import asarray
|
|
|
|
# Root directory of the project. os.path.abspath("../") is resolved against
# the current working directory, so the script presumably has to be launched
# from a direct subdirectory of the project root.
# NOTE(review): depends on the CWD rather than on this file's location — confirm.
ROOT_DIR = os.path.abspath("../")

# Import Mask RCNN from the local checkout. insert(0, ...) instead of append()
# so ROOT_DIR is searched before site-packages; with append() an installed
# copy of mrcnn would silently take precedence over the local version.
sys.path.insert(0, ROOT_DIR)
import mrcnn.model as modellib
from mrcnn import visualize

# Import the COCO sample config, likewise preferring the local copy.
sys.path.insert(0, os.path.join(ROOT_DIR, "samples/coco/"))
import coco

# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs")

# Path to the trained COCO weights file. Defaults to the fixed location used
# on the original machine; set MRCNN_COCO_MODEL_PATH to point elsewhere, e.g.
# os.path.join(ROOT_DIR, "model/mask_rcnn_coco.h5").
# The weights are not downloaded automatically here;
# mrcnn.utils.download_trained_weights(COCO_MODEL_PATH) can fetch them.
COCO_MODEL_PATH = os.environ.get(
    "MRCNN_COCO_MODEL_PATH", "/root/cnnmodel/mask_rcnn_coco.h5")

# Directory of images to run detection on (override with MRCNN_IMAGE_DIR).
IMAGE_DIR = os.environ.get(
    "MRCNN_IMAGE_DIR", os.path.join(ROOT_DIR, "TUM_Images"))

class InferenceConfig(coco.CocoConfig):
    """COCO configuration with the batch settings used for this inference run.

    Batch size = GPU_COUNT * IMAGES_PER_GPU. With the values below that is
    10 images per batch, and the script loads exactly that many images and
    passes them to a single ``model.detect()`` call.
    """

    # Number of GPUs to use.
    GPU_COUNT = 1

    # Images per GPU; with one GPU this is the whole batch (10 images, not 1).
    IMAGES_PER_GPU = 10


# Instantiate the inference configuration and print its settings.
config = InferenceConfig()
config.display()

# How many images each model.detect() call in this script receives.
batch_size = config.IMAGES_PER_GPU * config.GPU_COUNT
print("batch size: ", batch_size)

# Fail fast with a clear message if the weights file is missing. Otherwise the
# problem is only noticed by load_weights(), after the model has been built.
if not os.path.isfile(COCO_MODEL_PATH):
    raise FileNotFoundError(
        "Mask R-CNN COCO weights not found at %s" % COCO_MODEL_PATH)

# Create model object in inference mode.
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)

# Load weights trained on MS-COCO (matched to layers by name).
model.load_weights(COCO_MODEL_PATH, by_name=True)

# COCO class names. A name's position in this list is its class ID (index 0
# is the background entry 'BG'), so an ID can be looked up with e.g.
# class_names.index('teddy bear').
class_names = [
    'BG',
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
    'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
    'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
    'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
    'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
    'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
    'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
    'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
    'scissors', 'teddy bear', 'hair drier', 'toothbrush',
]

# Load the first `batch_size` images of IMAGE_DIR in sorted filename order.
# os.walk() lists directory entries in arbitrary, filesystem-dependent order,
# so sorting is what makes the selected batch reproducible between runs.
if not os.path.isdir(IMAGE_DIR):
    # next(os.walk(...)) on a missing directory would raise a bare StopIteration.
    raise FileNotFoundError("Image directory not found: %s" % IMAGE_DIR)
file_names = sorted(next(os.walk(IMAGE_DIR))[2])
if len(file_names) < batch_size:
    raise ValueError(
        "Need %d images in %s for one batch, found %d"
        % (batch_size, IMAGE_DIR, len(file_names)))

images = []
for file_name in file_names[:batch_size]:
    # NOTE(review): assumes every regular file in IMAGE_DIR is a readable image.
    image = skimage.io.imread(os.path.join(IMAGE_DIR, file_name))
    # Normalize to 3 channels: replicate a grayscale plane and drop an alpha
    # channel, presumably matching the RGB input the COCO model expects.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 4:
        image = image[..., :3]
    images.append(image)

import time

# Run detection on the whole batch and time it. time.perf_counter() is the
# clock intended for measuring short durations; datetime.now() reads the wall
# clock, which can jump if the system time is adjusted mid-measurement.
# NOTE(review): the first detect() call may include one-off setup cost —
# consider a warm-up call if steady-state latency is what should be reported.
start_time = time.perf_counter()
results = model.detect(images, verbose=1)
time_elapsed = time.perf_counter() - start_time

print("time cost: %f ms " % (time_elapsed * 1000))

# Report how many results detect() returned (presumably one per input image),
# then visualize the detections for the first image of the batch.
print("length of result:", len(results))
r = results[0]
visualize.display_instances(images[0], r['rois'], r['masks'], r['class_ids'],
                            class_names, r['scores'])