from pathlib import Path
from threading import Event, Lock, Thread
from typing import Optional, Union

from cv_bridge import CvBridge
import numpy as np
from rclpy.callback_groups import MutuallyExclusiveCallbackGroup, ReentrantCallbackGroup
from rclpy.node import Node, ParameterDescriptor, Parameter
from sensor_msgs.msg import Image
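# Assumption: the hand detection checkpoint (see the "hand_net_checkpoint"
# parameter below) is loadable as an Ultralytics YOLO model; its prediction
# results are accessed via `.boxes` in the runtime loop.
from ultralytics import YOLO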

from yolov7.detect_ptg import load_model, predict_image, predict_hands
from yolov7.models.experimental import attempt_load
import yolov7.models.yolo
from yolov7.utils.torch_utils import TracedModel

from angel_system.utils.event import WaitAndClearEvent
from angel_system.utils.simple_timer import SimpleTimer

from angel_msgs.msg import ObjectDetection2dSet
from angel_utils import declare_and_get_parameters, RateTracker, DYNAMIC_TYPE
from angel_utils import make_default_main


BRIDGE = CvBridge()


class YoloObjectDetector(Node):
    """
    ROS node that runs the yolov7 object detector model and outputs
    `ObjectDetection2dSet` messages.
    """

    def __init__(self):
        super().__init__(self.__class__.__name__)
        log = self.get_logger()

        # Inputs
        param_values = declare_and_get_parameters(
            self,
            [
                ##################################
                # Required parameters (no defaults)
                ("image_topic",),
                ("det_topic",),
                ("net_checkpoint",),
                # Hand detection model checkpoint -- an assumed parameter,
                # needed so the hand model used by predict_hands below can be
                # loaded.
                ("hand_net_checkpoint",),
                ##################################
                # Defaulted parameters
                ("inference_img_size", 1280),  # inference size (pixels)
                ("det_conf_threshold", 0.7),  # object confidence threshold
                ("iou_threshold", 0.45),  # IOU threshold for NMS
                ("cuda_device_id", 0, DYNAMIC_TYPE),  # CUDA device ID (int) or "cpu"
                ("no_trace", True),  # don't trace model
                ("agnostic_nms", False),  # class-agnostic NMS
                # Runtime thread check-in heartbeat interval in seconds.
                ("rt_thread_heartbeat", 0.1),
                # If we should enable additional logging at the info level
                # about when we receive and process data.
                ("enable_time_trace_logging", False),
            ],
        )
        self._image_topic = param_values["image_topic"]
        self._det_topic = param_values["det_topic"]
        self._model_ckpt_fp = Path(param_values["net_checkpoint"])

        self._inference_img_size = param_values["inference_img_size"]
        self._det_conf_thresh = param_values["det_conf_threshold"]
        self._iou_thr = param_values["iou_threshold"]
        self._cuda_device_id = param_values["cuda_device_id"]
        self._no_trace = param_values["no_trace"]
        self._agnostic_nms = param_values["agnostic_nms"]

        self._enable_trace_logging = param_values["enable_time_trace_logging"]

        # Model
        self.model: Union[yolov7.models.yolo.Model, TracedModel]
        if not self._model_ckpt_fp.is_file():
            raise ValueError(
                f"Model checkpoint file does not exist: {self._model_ckpt_fp}"
            )
        (self.device, self.model, self.stride, self.imgsz) = load_model(
            str(self._cuda_device_id), self._model_ckpt_fp, self._inference_img_size
        )
        log.info(
            "Loaded model with classes:\n"
            + "\n".join(f'\t- "{n}"' for n in self.model.names)
        )

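        # Hand detection model -- a sketch under assumptions: the checkpoint
        # path comes from the assumed "hand_net_checkpoint" parameter above,
        # and the checkpoint is assumed to be loadable as an Ultralytics YOLO
        # model (predict_hands results are accessed via `.boxes` below).
        self._hand_model_ckpt_fp = Path(param_values["hand_net_checkpoint"])
        if not self._hand_model_ckpt_fp.is_file():
            raise ValueError(
                f"Hand model checkpoint file does not exist: {self._hand_model_ckpt_fp}"
            )
        self.hand_model = YOLO(str(self._hand_model_ckpt_fp))
        log.info("Loaded hand detection model")
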
        # Single slot for latest image message to process detection over.
        self._cur_image_msg: Optional[Image] = None
        self._cur_image_msg_lock = Lock()

        # Initialize ROS hooks
        self._subscription = self.create_subscription(
            Image,
            self._image_topic,
            self.listener_callback,
            1,
            callback_group=MutuallyExclusiveCallbackGroup(),
        )
        self._det_publisher = self.create_publisher(
            ObjectDetection2dSet,
            self._det_topic,
            1,
            callback_group=MutuallyExclusiveCallbackGroup(),
        )

        if not self._no_trace:
            self.model = TracedModel(self.model, self.device, self._inference_img_size)

        # Half precision is only supported on CUDA devices.
        self.half = half = self.device.type != "cpu"
        if half:
            self.model.half()  # to FP16

        self._rate_tracker = RateTracker()
        log.info("Detector initialized")

        # Create and start detection runtime thread and loop.
        log.info("Starting runtime thread...")
        # On/off switch for the runtime loop
        self._rt_active = Event()
        self._rt_active.set()
        # Seconds to occasionally time out of the wait condition for the loop
        # to check if it is supposed to still be alive.
        self._rt_active_heartbeat = param_values["rt_thread_heartbeat"]
        # Condition that the runtime should perform processing
        self._rt_awake_evt = WaitAndClearEvent()
        self._rt_thread = Thread(target=self.rt_loop, name="prediction_runtime")
        self._rt_thread.daemon = True
        self._rt_thread.start()

    def listener_callback(self, image: Image):
        """
        Callback function for image messages. Stores the latest image and
        signals the inference runtime thread to process it.
        """
        log = self.get_logger()
        if self._enable_trace_logging:
            log.info(f"Received image with TS: {image.header.stamp}")
        with self._cur_image_msg_lock:
            self._cur_image_msg = image
            self._rt_awake_evt.set()

    def rt_alive(self) -> bool:
        """
        Check that the prediction runtime thread is still alive. If it is
        not, log a warning and join the thread.
        """
        alive = self._rt_thread.is_alive()
        if not alive:
            self.get_logger().warn("Runtime thread no longer alive.")
            self._rt_thread.join()
        return alive

    def rt_stop(self) -> None:
        """
        Indicate that the runtime loop should cease.
        """
        self._rt_active.clear()

    def rt_loop(self):
        """
        Inference runtime loop: wait for a new image from the subscriber
        callback, run object and hand detection over it, and publish an
        ObjectDetection2dSet message.
        """
        log = self.get_logger()
        log.info("Runtime loop starting")
        enable_trace_logging = self._enable_trace_logging

        while self._rt_active.wait(0):  # will quickly return False if cleared
            if self._rt_awake_evt.wait_and_clear(self._rt_active_heartbeat):
                with self._cur_image_msg_lock:
                    if self._cur_image_msg is None:
                        continue
                    image = self._cur_image_msg
                    self._cur_image_msg = None

                if enable_trace_logging:
                    log.info(f"[rt-loop] Processing image TS={image.header.stamp}")
                # Convert ROS img msg to CV2 image
                img0 = BRIDGE.imgmsg_to_cv2(image, desired_encoding="bgr8")

                if enable_trace_logging:
                    log.info(f"img0 shape: {img0.shape}")
                # Original image dimensions; the width is used below to decide
                # left vs. right for a single detected hand.
                height, width = img0.shape[:2]

                msg = ObjectDetection2dSet()
                msg.header.stamp = self.get_clock().now().to_msg()
                msg.header.frame_id = image.header.frame_id
                msg.source_stamp = image.header.stamp
                # Object class labels, plus the two hand classes appended at
                # the end (the hand label strings here are an assumption).
                msg.label_vec[:] = list(self.model.names) + [
                    "hand (right)",
                    "hand (left)",
                ]

                n_classes = len(self.model.names) + 2  # accommodate the 2 hand classes
                n_dets = 0

                dflt_conf_vec = np.zeros(n_classes, dtype=np.float64)
                right_hand_cid = n_classes - 2
                left_hand_cid = n_classes - 1

                hands_preds = predict_hands(
                    hand_model=self.hand_model,
                    img0=img0,
                    img_size=self._inference_img_size,
                    device=self.device,
                )

                # X center of each predicted hand box (consider at most two).
                hand_centers = [
                    box.xywh.tolist()[0][0] for box in hands_preds.boxes
                ][:2]
                # Assign hand class IDs by horizontal position: the hand
                # further to the right in the image is labeled as the right
                # hand. The labeled boxes are appended to the message after
                # the object detections below.
                hands_label = []
                if len(hand_centers) == 2:
                    if hand_centers[0] > hand_centers[1]:
                        hands_label.append(right_hand_cid)
                        hands_label.append(left_hand_cid)
                    else:
                        hands_label.append(left_hand_cid)
                        hands_label.append(right_hand_cid)
                elif len(hand_centers) == 1:
                    if hand_centers[0] > width // 2:
                        hands_label.append(right_hand_cid)
                    else:
                        hands_label.append(left_hand_cid)
                for xyxy, conf, cls_id in predict_image(
                    img0,
                    self.device,
                    self.model,
                    self.stride,
                    self.imgsz,
                    self.half,
                    False,
                    self._det_conf_thresh,
                    self._iou_thr,
                    None,
                    self._agnostic_nms,
                ):
                    n_dets += 1
                    msg.left.append(xyxy[0])
                    msg.top.append(xyxy[1])
                    msg.right.append(xyxy[2])
                    msg.bottom.append(xyxy[3])

                    dflt_conf_vec[cls_id] = conf
                    # Copy the per-class confidence vector for this detection
                    # into the message.
                    msg.label_confidences.extend(dflt_conf_vec)
                    # Reset before the next detection.
                    dflt_conf_vec[cls_id] = 0.0

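                # Append the labeled hand detections to the message as
                # additional detections. NOTE: a sketch assuming the
                # Ultralytics results API (`.xyxy` / `.conf`) for the boxes
                # returned by predict_hands.
                for hand_box, hand_cid in zip(hands_preds.boxes, hands_label):
                    n_dets += 1
                    x1, y1, x2, y2 = hand_box.xyxy.tolist()[0]
                    msg.left.append(x1)
                    msg.top.append(y1)
                    msg.right.append(x2)
                    msg.bottom.append(y2)

                    dflt_conf_vec[hand_cid] = float(hand_box.conf)
                    msg.label_confidences.extend(dflt_conf_vec)
                    dflt_conf_vec[hand_cid] = 0.0
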
                msg.num_detections = n_dets

                self._det_publisher.publish(msg)

                self._rate_tracker.tick()
                log.info(
                    f"Object detection rate: {self._rate_tracker.get_rate_avg()} Hz",
                )

    def destroy_node(self):
        print("Stopping runtime")
        self.rt_stop()  # make the runtime-active flag "False"
        print("Shutting down runtime thread...")
        self._rt_thread.join()
        print("Shutting down runtime thread... Done")
        super().destroy_node()


# Don't really want to use *all* available threads...
# 3 threads because:
# - 1 known subscriber which has its own group
# - 1 for the default group
# - 1 for publishers
main = make_default_main(YoloObjectDetector, multithreaded_executor=3)


if __name__ == "__main__":
    main()