Skip to content

Commit b02c549

Browse files
committedDec 7, 2023
Added default main generator helper, updated all python nodes
1 parent 79ce797 commit b02c549

28 files changed

+210
-382
lines changed
 

‎ros/angel_system_nodes/angel_system_nodes/activity_classification/activity_classifier_tcn.py

+4-33
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,19 @@
55
"""
66
import json
77
from heapq import heappush, heappop
8-
import logging
98
from pathlib import Path
109
from threading import Condition, Event, Lock, Thread
1110
from typing import Callable
1211
from typing import Dict
1312
from typing import List
1413
from typing import Optional
15-
from typing import Tuple
1614

1715
import kwcoco
1816
from builtin_interfaces.msg import Time
1917
import numpy as np
2018
import numpy.typing as npt
21-
import rclpy
2219
from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
23-
import rclpy.logging
2420
from rclpy.node import Node
25-
from rclpy.executors import MultiThreadedExecutor
2621
import torch
2722

2823
from angel_system.activity_classification.tcn_hpl.predict import (
@@ -39,7 +34,7 @@
3934
ObjectDetection2dSet,
4035
ActivityDetection,
4136
)
42-
from angel_utils import declare_and_get_parameters, RateTracker
37+
from angel_utils import declare_and_get_parameters, make_default_main, RateTracker
4338
from angel_utils.activity_classification import InputWindow, InputBuffer
4439
from angel_utils.conversion import time_to_int
4540
from angel_utils.object_detection import max_labels_and_confs
@@ -683,40 +678,16 @@ def _save_results(self):
683678

684679
def destroy_node(self):
685680
log = self.get_logger()
681+
log.info("Stopping node runtime")
682+
self.rt_stop()
686683
with SimpleTimer("Shutting down runtime thread...", log.info):
687684
self._rt_active.clear() # make RT active flag "False"
688685
self._rt_thread.join()
689686
self._save_results()
690687
super().destroy_node()
691688

692689

693-
def main():
694-
logging.basicConfig(
695-
format="[%(levelname)s] [%(asctime)s] [%(name)s.%(funcName)s]: %(message)s"
696-
)
697-
logging.getLogger().setLevel(logging.INFO)
698-
699-
rclpy.init()
700-
log = rclpy.logging.get_logger("main")
701-
702-
activity_classifier = ActivityClassifierTCN()
703-
704-
executor = MultiThreadedExecutor(num_threads=4)
705-
executor.add_node(activity_classifier)
706-
try:
707-
executor.spin()
708-
except KeyboardInterrupt:
709-
log.info("Keyboard interrupt, shutting down.\n")
710-
finally:
711-
log.info("Stopping node runtime")
712-
activity_classifier.rt_stop()
713-
714-
# Destroy the node explicitly
715-
# (optional - otherwise it will be done automatically
716-
# when the garbage collector destroys the node object)
717-
activity_classifier.destroy_node()
718-
719-
rclpy.shutdown()
690+
main = make_default_main(ActivityClassifierTCN, multithreaded_executor=4)
720691

721692

722693
if __name__ == "__main__":

‎ros/angel_system_nodes/angel_system_nodes/annotation_event_monitor.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from threading import Lock
33

44
from pynput import keyboard
5-
import rclpy
65
from rclpy.node import Node
76

87
from angel_msgs.msg import AnnotationEvent
8+
from angel_utils import make_default_main
99

1010

1111
class AnnotationEventMonitor(Node):
@@ -90,23 +90,16 @@ def on_press(self, key):
9090
self._publisher.publish(msg)
9191

9292

93-
def main():
94-
rclpy.init()
95-
96-
event_monitor = AnnotationEventMonitor()
97-
98-
keyboard_t = threading.Thread(target=event_monitor.monitor_keypress)
93+
def init_kb_thread(node):
94+
"""
95+
Initialize the
96+
"""
97+
keyboard_t = threading.Thread(target=node.monitor_keypress)
9998
keyboard_t.daemon = True
10099
keyboard_t.start()
101100

102-
rclpy.spin(event_monitor)
103-
104-
# Destroy the node explicitly
105-
# (optional - otherwise it will be done automatically
106-
# when the garbage collector destroys the node object)
107-
event_monitor.destroy_node()
108101

109-
rclpy.shutdown()
102+
main = make_default_main(AnnotationEventMonitor, pre_spin_callback=init_kb_thread)
110103

111104

112105
if __name__ == "__main__":

‎ros/angel_system_nodes/angel_system_nodes/audio/asr.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import simpleaudio as sa
1313

1414
from angel_msgs.msg import HeadsetAudioData, Utterance
15+
from angel_utils import make_default_main
1516

1617

1718
AUDIO_TOPIC = "audio_topic"
@@ -215,12 +216,7 @@ def asr_server_request_thread(self, audio_data, num_channels, sample_rate):
215216
self._publisher.publish(utterance_msg)
216217

217218

218-
def main():
219-
rclpy.init()
220-
asr = ASR()
221-
rclpy.spin(asr)
222-
asr.destroy_node()
223-
rclpy.shutdown()
219+
main = make_default_main(ASR)
224220

225221

226222
if __name__ == "__main__":

‎ros/angel_system_nodes/angel_system_nodes/audio/audio_player.py

+3-14
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import threading
2-
import time
32

43
import simpleaudio as sa
54

6-
import rclpy
75
from rclpy.node import Node
6+
87
from angel_msgs.msg import HeadsetAudioData
8+
from angel_utils import make_default_main
99

1010

1111
class AudioPlayer(Node):
@@ -101,18 +101,7 @@ def audio_playback_thread(self, audio_data, num_channels, sample_rate):
101101
audio_player_object.wait_done()
102102

103103

104-
def main():
105-
rclpy.init()
106-
107-
audio_player = AudioPlayer()
108-
109-
rclpy.spin(audio_player)
110-
111-
# Destroy the node explicitly
112-
# (optional - otherwise it will be done automatically
113-
# when the garbage collector destroys the node object)
114-
audio_player.destroy_node()
115-
rclpy.shutdown()
104+
main = make_default_main(AudioPlayer)
116105

117106

118107
if __name__ == "__main__":

‎ros/angel_system_nodes/angel_system_nodes/audio/emotion/base_emotion_detector.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import queue
2-
import rclpy
32
from rclpy.node import Node
43
from termcolor import colored
54
import threading
65
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
76

87
from angel_msgs.msg import InterpretedAudioUserEmotion, InterpretedAudioUserIntent
98
from angel_utils import declare_and_get_parameters
9+
from angel_utils import make_default_main
10+
1011

1112
IN_EXPECT_USER_INTENT_TOPIC = "expect_user_intent_topic"
1213
IN_INTERP_USER_INTENT_TOPIC = "interp_user_intent_topic"
@@ -156,12 +157,7 @@ def _apply_filter(self, msg):
156157
return msg
157158

158159

159-
def main():
160-
rclpy.init()
161-
emotion_detector = BaseEmotionDetector()
162-
rclpy.spin(emotion_detector)
163-
emotion_detector.destroy_node()
164-
rclpy.shutdown()
160+
main = make_default_main(BaseEmotionDetector)
165161

166162

167163
if __name__ == "__main__":

‎ros/angel_system_nodes/angel_system_nodes/audio/emotion/gpt_emotion_detector.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from langchain.chat_models import ChatOpenAI
44
import openai
55
import os
6-
import rclpy
76

8-
from ros.angel_system_nodes.angel_system_nodes.audio.emotion.base_emotion_detector import (
7+
from angel_system_nodes.audio.emotion.base_emotion_detector import (
98
BaseEmotionDetector,
109
LABEL_MAPPINGS,
1110
)
11+
from angel_utils import make_default_main
1212

1313
openai.organization = os.getenv("OPENAI_ORG_ID")
1414
openai.api_key = os.getenv("OPENAI_API_KEY")
@@ -90,12 +90,7 @@ def get_inference(self, msg):
9090
return (self.chain.run(utterance=msg.utterance_text), 0.5)
9191

9292

93-
def main():
94-
rclpy.init()
95-
emotion_detector = GptEmotionDetector()
96-
rclpy.spin(emotion_detector)
97-
emotion_detector.destroy_node()
98-
rclpy.shutdown()
93+
main = make_default_main(GptEmotionDetector)
9994

10095

10196
if __name__ == "__main__":

‎ros/angel_system_nodes/angel_system_nodes/audio/intent/base_intent_detector.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import queue
2-
import rclpy
32
from rclpy.node import Node
43
from termcolor import colored
54
import threading
65

76
from angel_msgs.msg import InterpretedAudioUserIntent, Utterance
87
from angel_utils import declare_and_get_parameters
8+
from angel_utils import make_default_main
99

1010
NEXT_STEP_KEYPHRASES = ["skip", "next", "next step"]
1111
PREV_STEP_KEYPHRASES = ["previous", "previous step", "last step", "go back"]
@@ -152,12 +152,7 @@ def _contains_phrase(self, utterance, phrases):
152152
return False
153153

154154

155-
def main():
156-
rclpy.init()
157-
intent_detector = BaseIntentDetector()
158-
rclpy.spin(intent_detector)
159-
intent_detector.destroy_node()
160-
rclpy.shutdown()
155+
main = make_default_main(BaseIntentDetector)
161156

162157

163158
if __name__ == "__main__":

‎ros/angel_system_nodes/angel_system_nodes/audio/intent/gpt_intent_detector.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from langchain import PromptTemplate, FewShotPromptTemplate
22
from langchain.chains import LLMChain
33
from langchain.chat_models import ChatOpenAI
4-
from langchain.llms import OpenAI
54
import openai
65
import os
76
import rclpy
87

9-
from ros.angel_system_nodes.angel_system_nodes.audio.intent.base_intent_detector import (
8+
from angel_system_nodes.audio.intent.base_intent_detector import (
109
BaseIntentDetector,
1110
INTENT_LABELS,
1211
)
12+
from angel_utils import make_default_main
13+
1314

1415
openai.organization = os.getenv("OPENAI_ORG_ID")
1516
openai.api_key = os.getenv("OPENAI_API_KEY")
@@ -91,12 +92,7 @@ def detect_intents(self, msg):
9192
return self.chain.run(utterance=msg), 0.5
9293

9394

94-
def main():
95-
rclpy.init()
96-
intent_detector = GptIntentDetector()
97-
rclpy.spin(intent_detector)
98-
intent_detector.destroy_node()
99-
rclpy.shutdown()
95+
main = make_default_main(GptIntentDetector)
10096

10197

10298
if __name__ == "__main__":

‎ros/angel_system_nodes/angel_system_nodes/audio/intent/intent_detector.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from rclpy.node import Node
33

44
from angel_msgs.msg import InterpretedAudioUserIntent, Utterance
5+
from angel_utils import make_default_main
6+
57

68
# Please refer to labels defined in
79
# https://docs.google.com/document/d/1uuvSL5de3LVM9c0tKpRKYazDxckffRHf7IAcabSw9UA .
@@ -118,15 +120,7 @@ def contains_phrase(self, utterance, phrases):
118120
return False
119121

120122

121-
def main():
122-
rclpy.init()
123-
124-
intentDetector = IntentDetector()
125-
126-
rclpy.spin(intentDetector)
127-
128-
intentDetector.destroy_node()
129-
rclpy.shutdown()
123+
main = make_default_main(IntentDetector)
130124

131125

132126
if __name__ == "__main__":

‎ros/angel_system_nodes/angel_system_nodes/audio/intent/intent_to_command.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
InterpretedAudioUserIntent,
77
SystemCommands,
88
)
9+
from angel_utils import make_default_main
910

1011

1112
# Parameter name constants
@@ -113,18 +114,7 @@ def intent_callback(self, intent: InterpretedAudioUserIntent) -> None:
113114
self._sys_cmd_publisher.publish(sys_cmd_msg)
114115

115116

116-
def main():
117-
rclpy.init()
118-
119-
node = IntentToCommand()
120-
rclpy.spin(node)
121-
122-
# Destroy the node explicitly
123-
# (optional - otherwise it will be done automatically
124-
# when the garbage collector destroys the node object)
125-
node.destroy_node()
126-
127-
rclpy.shutdown()
117+
main = make_default_main(IntentToCommand)
128118

129119

130120
if __name__ == "__main__":

‎ros/angel_system_nodes/angel_system_nodes/audio/question_answerer.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
import openai
33
import os
44
import queue
5-
import rclpy
65
from rclpy.node import Node
76
import requests
87
from termcolor import colored
98
import threading
109

1110
from angel_msgs.msg import InterpretedAudioUserEmotion, SystemTextResponse
1211
from angel_utils import declare_and_get_parameters
12+
from angel_utils import make_default_main
13+
1314

1415
openai.organization = os.getenv("OPENAI_ORG_ID")
1516
openai.api_key = os.getenv("OPENAI_API_KEY")
@@ -153,12 +154,7 @@ def _apply_filter(self, msg):
153154
return msg
154155

155156

156-
def main():
157-
rclpy.init()
158-
question_answerer = QuestionAnswerer()
159-
rclpy.spin(question_answerer)
160-
question_answerer.destroy_node()
161-
rclpy.shutdown()
157+
main = make_default_main(QuestionAnswerer)
162158

163159

164160
if __name__ == "__main__":

0 commit comments

Comments
 (0)