Commit 87d8672
Feat/real time int (#7)
* unified environment, GSP, real time integration
* hand detector ros node
* submodule URLs corrected, hand detection node, initial pose estimation node
* yolov7 submodule hash update
1 parent 7e986b9 commit 87d8672

File tree

4 files changed: +320 −41 lines changed

.gitmodules (+1 −1)
@@ -10,7 +10,7 @@
 	branch = main
 [submodule "python-tpl/yolov7"]
 	path = python-tpl/yolov7
-	url = https://github.com/cameron-a-johnson/yolov7.git
+	url = https://github.com/PTG-Kitware/yolov7
 [submodule "ros/rosbag2"]
 	path = ros/rosbag2
 	url = https://github.com/ros2/rosbag2.git
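Existing checkouts keep fetching the submodule from the old remote until they re-sync. A small sketch of the standard recovery steps, driven from Python purely for illustration:

# Re-point and refresh submodules after a .gitmodules URL change (sketch).
import subprocess

for cmd in (
    # Copy the updated URLs from .gitmodules into .git/config.
    ["git", "submodule", "sync", "--recursive"],
    # Fetch and check out the pinned submodule hashes.
    ["git", "submodule", "update", "--init", "--recursive"],
):
    subprocess.run(cmd, check=True)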

pyproject.toml (+5 −9)
@@ -42,6 +42,7 @@ ultralytics = "==8.1.27"
 rootutils = "==1.0.7"
 torchmetrics = "==0.11.4"
 rich = "==13.7.1"
+chumpy = "==0.70"
 
 
 ## For UHO Activity Classifier
@@ -78,16 +79,11 @@ kwimage = ">=0.9.18"
 networkx = ">=3.1"
 
 
-# TCN Activity classifier /home/local/KHQ/peri.akiva/projects/TCN_HPL
-tcn-hpl = {path = "/home/local/KHQ/peri.akiva/projects/TCN_HPL", develop = true}
-
-# Yolo v7 and Yolo v8 object detection
-yolov7 = {path = "/home/local/KHQ/peri.akiva/projects/yolov7", develop = true}
-#ultralytics = {path = "/home/local/KHQ/peri.akiva/projects/ultralytics", develop = true}
-
-# For Yolo V8
-# thop = ">=0.1.1"  # FLOPs computation
+# TCN Activity classifier
+tcn-hpl = {path = "python-tpl/TCN_HPL", develop = true}
 
+# Yolo v7 object detection
+yolov7 = {path = "python-tpl/yolov7", develop = true}
 
 [tool.poetry.dev-dependencies]
 ipython = "*"
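This moves the editable path dependencies off one developer's absolute /home/local/... paths and onto the repo's own python-tpl submodules, so `poetry install` resolves on any checkout. A quick post-install sanity check, as a sketch; `yolov7` matches the imports used by the detection node below, while `tcn_hpl` as the import name of the tcn-hpl package is an assumption:

# Sanity-check sketch: editable path dependencies should resolve from
# inside the python-tpl/ submodule tree after `poetry install`.
# NOTE: the `tcn_hpl` import name is an assumption.
import yolov7
import tcn_hpl

print(yolov7.__file__)   # expected under python-tpl/yolov7
print(tcn_hpl.__file__)  # expected under python-tpl/TCN_HPL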

ros/angel_system_nodes/angel_system_nodes/object_detection/object_hand_detection.py (+45 −31)
@@ -192,38 +192,52 @@ def rt_loop(self):
                 dflt_conf_vec = np.zeros(n_classes, dtype=np.float64)
                 right_hand_cid = n_classes - 2
                 left_hand_cid = n_classes - 1
-
-                hands_preds = predict_hands(hand_model=self.hand_model, img0=img0,
-                                            img_size=self._inference_img_size, device=self.device)
-
-                hand_centers = [center.xywh.tolist()[0][0] for center in hands_preds.boxes][:2]
-                hands_label = []
-                if len(hand_centers) == 2:
-                    if hand_centers[0] > hand_centers[1]:
-                        hands_label.append(right_hand_cid)
-                        hands_label.append(left_hand_cid)
-                    elif hand_centers[0] <= hand_centers[1]:
-                        hands_label.append(left_hand_cid)
-                        hands_label.append(right_hand_cid)
-                elif len(hand_centers) == 1:
-                    if hand_centers[0] > width // 2:
-                        hands_label.append(right_hand_cid)
-                    elif hand_centers[0] <= width // 2:
-                        hands_label.append(left_hand_cid)
 
-                for xyxy, conf, cls_id in predict_image(
-                    img0,
-                    self.device,
-                    self.model,
-                    self.stride,
-                    self.imgsz,
-                    self.half,
-                    False,
-                    self._det_conf_thresh,
-                    self._iou_thr,
-                    None,
-                    self._agnostic_nms,
-                ):
+                hand_cid_label_dict = {
+                    "hand (right)": right_hand_cid,
+                    "hand (left)": left_hand_cid,
+                }
+
+                hand_boxes, hand_labels, hand_confs = predict_hands(
+                    hand_model=self.hand_model,
+                    img0=img0,
+                    img_size=self._inference_img_size,
+                    device=self.device,
+                )
+
+                hand_classids = [hand_cid_label_dict[label] for label in hand_labels]
+
+                object_boxes, object_confs, object_classids = predict_image(
+                    img0,
+                    self.device,
+                    self.model,
+                    self.stride,
+                    self.imgsz,
+                    self.half,
+                    False,
+                    self._det_conf_thresh,
+                    self._iou_thr,
+                    None,
+                    self._agnostic_nms,
+                )
+
+                object_boxes.extend(hand_boxes)
+                object_confs.extend(hand_confs)
+                object_classids.extend(hand_classids)
+                for xyxy, conf, cls_id in zip(object_boxes, object_confs, object_classids):
+                # for xyxy, conf, cls_id in predict_image(
+                #     img0,
+                #     self.device,
+                #     self.model,
+                #     self.stride,
+                #     self.imgsz,
+                #     self.half,
+                #     False,
+                #     self._det_conf_thresh,
+                #     self._iou_thr,
+                #     None,
+                #     self._agnostic_nms,
+                # ):
+
                     n_dets += 1
                     msg.left.append(xyxy[0])
                     msg.top.append(xyxy[1])
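The reworked loop gives the two hand classes the last two class IDs and simply concatenates hand and object detections before a single message-filling pass. A self-contained sketch of that merge, plus the one-hot confidence packing the loop body uses; the detection values are hypothetical stand-ins, and the list-style outputs of predict_hands/predict_image are as implied by the diff:

# Sketch of the hand/object merge with hypothetical stand-in detections.
import numpy as np

n_object_classes = 3                 # hypothetical model class count
n_classes = n_object_classes + 2     # two trailing slots for hands
hand_cid_label_dict = {
    "hand (right)": n_classes - 2,
    "hand (left)": n_classes - 1,
}

# Hypothetical detector outputs, shaped as the diff implies.
object_boxes = [[10.0, 20.0, 110.0, 220.0]]
object_confs = [0.9]
object_classids = [1]
hand_boxes = [[300.0, 40.0, 380.0, 160.0]]
hand_confs = [0.8]
hand_labels = ["hand (left)"]

# Hands are appended after the objects so one loop fills the message.
object_boxes.extend(hand_boxes)
object_confs.extend(hand_confs)
object_classids.extend(hand_cid_label_dict[lbl] for lbl in hand_labels)

# One-hot confidence packing, as in the loop body: set, copy out, reset.
label_confidences = []
dflt_conf_vec = np.zeros(n_classes, dtype=np.float64)
for xyxy, conf, cls_id in zip(object_boxes, object_confs, object_classids):
    dflt_conf_vec[cls_id] = conf
    label_confidences.extend(dflt_conf_vec)  # extend copies the values out
    dflt_conf_vec[cls_id] = 0.0              # reset for the next detection

# One row of n_classes confidences per detection.
print(np.asarray(label_confidences).reshape(-1, n_classes))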
New file (+269 −0)

from pathlib import Path
from threading import Event, Lock, Thread
from typing import Union

from cv_bridge import CvBridge
import numpy as np
from rclpy.callback_groups import MutuallyExclusiveCallbackGroup, ReentrantCallbackGroup
from rclpy.node import Node, ParameterDescriptor, Parameter
from sensor_msgs.msg import Image

from yolov7.detect_ptg import load_model, predict_image, predict_hands
from yolov7.models.experimental import attempt_load
import yolov7.models.yolo
from yolov7.utils.torch_utils import TracedModel

from angel_system.utils.event import WaitAndClearEvent
from angel_system.utils.simple_timer import SimpleTimer

from angel_msgs.msg import ObjectDetection2dSet
from angel_utils import declare_and_get_parameters, RateTracker, DYNAMIC_TYPE
from angel_utils import make_default_main


BRIDGE = CvBridge()


class YoloObjectDetector(Node):
    """
    ROS node that runs the yolov7 object detector model and outputs
    `ObjectDetection2dSet` messages.
    """

    def __init__(self):
        super().__init__(self.__class__.__name__)
        log = self.get_logger()

        # Inputs
        param_values = declare_and_get_parameters(
            self,
            [
                ##################################
                # Required parameters (no defaults)
                ("image_topic",),
                ("det_topic",),
                ("net_checkpoint",),
                ##################################
                # Defaulted parameters
                ("inference_img_size", 1280),  # inference size (pixels)
                ("det_conf_threshold", 0.7),  # object confidence threshold
                ("iou_threshold", 0.45),  # IOU threshold for NMS
                ("cuda_device_id", 0, DYNAMIC_TYPE),  # cuda device: ID int or CPU
                ("no_trace", True),  # don't trace model
                ("agnostic_nms", False),  # class-agnostic NMS
                # Runtime thread check-in heartbeat interval in seconds.
                ("rt_thread_heartbeat", 0.1),
                # If we should enable additional logging to the info level
                # about when we receive and process data.
                ("enable_time_trace_logging", False),
            ],
        )
        self._image_topic = param_values["image_topic"]
        self._det_topic = param_values["det_topic"]
        self._model_ckpt_fp = Path(param_values["net_checkpoint"])

        self._inference_img_size = param_values["inference_img_size"]
        self._det_conf_thresh = param_values["det_conf_threshold"]
        self._iou_thr = param_values["iou_threshold"]
        self._cuda_device_id = param_values["cuda_device_id"]
        self._no_trace = param_values["no_trace"]
        self._agnostic_nms = param_values["agnostic_nms"]

        self._enable_trace_logging = param_values["enable_time_trace_logging"]

        # Model
        self.model: Union[yolov7.models.yolo.Model, TracedModel]
        if not self._model_ckpt_fp.is_file():
            raise ValueError(
                f"Model checkpoint file did not exist: {self._model_ckpt_fp}"
            )
        (self.device, self.model, self.stride, self.imgsz) = load_model(
            str(self._cuda_device_id), self._model_ckpt_fp, self._inference_img_size
        )
        log.info(
            f"Loaded model with classes:\n"
            + "\n".join(f'\t- "{n}"' for n in self.model.names)
        )

        # Single slot for latest image message to process detection over.
        self._cur_image_msg: Image = None
        self._cur_image_msg_lock = Lock()

        # Initialize ROS hooks
        self._subscription = self.create_subscription(
            Image,
            self._image_topic,
            self.listener_callback,
            1,
            callback_group=MutuallyExclusiveCallbackGroup(),
        )
        self._det_publisher = self.create_publisher(
            ObjectDetection2dSet,
            self._det_topic,
            1,
            callback_group=MutuallyExclusiveCallbackGroup(),
        )

        if not self._no_trace:
            self.model = TracedModel(self.model, self.device, self._inference_img_size)

        self.half = half = (
            self.device.type != "cpu"
        )  # half precision only supported on CUDA
        if half:
            self.model.half()  # to FP16

        self._rate_tracker = RateTracker()
        log.info("Detector initialized")

        # Create and start detection runtime thread and loop.
        log.info("Starting runtime thread...")
        # On/Off Switch for runtime loop
        self._rt_active = Event()
        self._rt_active.set()
        # seconds to occasionally time out of the wait condition for the loop
        # to check if it is supposed to still be alive.
        self._rt_active_heartbeat = param_values["rt_thread_heartbeat"]
        # Condition that the runtime should perform processing
        self._rt_awake_evt = WaitAndClearEvent()
        self._rt_thread = Thread(target=self.rt_loop, name="prediction_runtime")
        self._rt_thread.daemon = True
        self._rt_thread.start()

    def listener_callback(self, image: Image):
        """
        Callback function for image messages. Stores the image and wakes the
        runtime thread, which publishes an ObjectDetection2dSet message for it.
        """
        log = self.get_logger()
        if self._enable_trace_logging:
            log.info(f"Received image with TS: {image.header.stamp}")
        with self._cur_image_msg_lock:
            self._cur_image_msg = image
            self._rt_awake_evt.set()

    def rt_alive(self) -> bool:
        """
        Check that the prediction runtime is still alive and join the thread
        if it is not.
        """
        alive = self._rt_thread.is_alive()
        if not alive:
            self.get_logger().warn("Runtime thread no longer alive.")
            self._rt_thread.join()
        return alive

    def rt_stop(self) -> None:
        """
        Indicate that the runtime loop should cease.
        """
        self._rt_active.clear()

    def rt_loop(self):
        log = self.get_logger()
        log.info("Runtime loop starting")
        enable_trace_logging = self._enable_trace_logging

        while self._rt_active.wait(0):  # will quickly return false if cleared.
            if self._rt_awake_evt.wait_and_clear(self._rt_active_heartbeat):
                with self._cur_image_msg_lock:
                    if self._cur_image_msg is None:
                        continue
                    image = self._cur_image_msg
                    self._cur_image_msg = None

                if enable_trace_logging:
                    log.info(f"[rt-loop] Processing image TS={image.header.stamp}")
                # Convert ROS img msg to CV2 image
                img0 = BRIDGE.imgmsg_to_cv2(image, desired_encoding="bgr8")

                print(f"img0: {img0.shape}")
                width, height = self._inference_img_size

                msg = ObjectDetection2dSet()
                msg.header.stamp = self.get_clock().now().to_msg()
                msg.header.frame_id = image.header.frame_id
                msg.source_stamp = image.header.stamp
                msg.label_vec[:] = self.model.names

                n_classes = len(self.model.names) + 2  # accommodate 2 hands
                n_dets = 0

                dflt_conf_vec = np.zeros(n_classes, dtype=np.float64)
                right_hand_cid = n_classes - 2
                left_hand_cid = n_classes - 1

                # NOTE: `self.hand_model` is referenced here but never assigned
                # in `__init__`; the hand model is assumed to be loaded
                # elsewhere.
                hands_preds = predict_hands(hand_model=self.hand_model, img0=img0,
                                            img_size=self._inference_img_size,
                                            device=self.device)

                # Assign left/right class IDs from horizontal box centers:
                # two hands are ordered by x-center; a single hand is labeled
                # by which half of the frame it falls in.
                hand_centers = [center.xywh.tolist()[0][0] for center in hands_preds.boxes][:2]
                hands_label = []
                if len(hand_centers) == 2:
                    if hand_centers[0] > hand_centers[1]:
                        hands_label.append(right_hand_cid)
                        hands_label.append(left_hand_cid)
                    elif hand_centers[0] <= hand_centers[1]:
                        hands_label.append(left_hand_cid)
                        hands_label.append(right_hand_cid)
                elif len(hand_centers) == 1:
                    if hand_centers[0] > width // 2:
                        hands_label.append(right_hand_cid)
                    elif hand_centers[0] <= width // 2:
                        hands_label.append(left_hand_cid)

                for xyxy, conf, cls_id in predict_image(
                    img0,
                    self.device,
                    self.model,
                    self.stride,
                    self.imgsz,
                    self.half,
                    False,
                    self._det_conf_thresh,
                    self._iou_thr,
                    None,
                    self._agnostic_nms,
                ):
                    n_dets += 1
                    msg.left.append(xyxy[0])
                    msg.top.append(xyxy[1])
                    msg.right.append(xyxy[2])
                    msg.bottom.append(xyxy[3])

                    dflt_conf_vec[cls_id] = conf
                    # copies data into array
                    msg.label_confidences.extend(dflt_conf_vec)
                    # reset before next passthrough
                    dflt_conf_vec[cls_id] = 0.0

                msg.num_detections = n_dets

                self._det_publisher.publish(msg)

                self._rate_tracker.tick()
                log.info(
                    f"Objects Detection Rate: {self._rate_tracker.get_rate_avg()} Hz",
                )

    def destroy_node(self):
        print("Stopping runtime")
        self.rt_stop()
        print("Shutting down runtime thread...")
        self._rt_active.clear()  # make RT active flag "False"
        self._rt_thread.join()
        print("Shutting down runtime thread... Done")
        super().destroy_node()


# Don't really want to use *all* available threads...
# 3 threads because:
# - 1 known subscriber which has their own group
# - 1 for default group
# - 1 for publishers
main = make_default_main(YoloObjectDetector, multithreaded_executor=3)


if __name__ == "__main__":
    main()
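The node's runtime pattern is worth seeing in isolation: a single-slot latest-image buffer guarded by a lock, a wake-up event, and a heartbeat timeout so the worker can notice shutdown. A minimal stdlib-only sketch; WaitAndClearEvent is approximated with threading.Event, and all names here are illustrative, not from angel_system:

# Sketch of the single-slot worker pattern used by the runtime thread.
from threading import Event, Lock, Thread
import time

class LatestSlotWorker:
    def __init__(self, heartbeat=0.1):
        self._slot = None          # single slot: only newest input matters
        self._lock = Lock()
        self._awake = Event()      # stands in for WaitAndClearEvent
        self._active = Event()
        self._active.set()
        self._heartbeat = heartbeat
        self._thread = Thread(target=self._loop, daemon=True)
        self._thread.start()

    def submit(self, item):
        # Overwrites any unprocessed item, then wakes the worker.
        with self._lock:
            self._slot = item
        self._awake.set()

    def _loop(self):
        while self._active.is_set():
            # Time out periodically so the loop can notice shutdown.
            if not self._awake.wait(self._heartbeat):
                continue
            self._awake.clear()
            with self._lock:
                item, self._slot = self._slot, None
            if item is not None:
                print(f"processing {item}")  # inference would go here

    def stop(self):
        self._active.clear()
        self._thread.join()

worker = LatestSlotWorker()
for i in range(3):
    worker.submit(i)
    time.sleep(0.05)
worker.stop()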
