"""Vehicle tracking pipeline."""
import os
import time
from functools import partial
import cv2
import fire
import torch
from tqdm.auto import tqdm
from ultralytics import YOLO
from ultralytics.engine.results import Boxes
from yaml import safe_load
from somecompany.logger import logging
from somecompany.segmentation import RoadPlaneSegmenter, test_intersection
from somecompany.tracking import VehicleTracker
from somecompany.viz import overlay_mask, plot_vehicle_tracks
LOG = logging.getLogger(__name__)
@torch.inference_mode()
def process_video(
    source: str, checkpoint_name: str = "yolov8n.pt", device: str = "cuda", output: str = "screen"
) -> None:
"""Process a video and detect vehicles.
Args:
source (str): Path to the video file.
checkpoint_name (str): Name of the checkpoint to load.
device (str): Device to run inference on.
output (str): Output path for the processed video. If "screen" the video will be displayed on screen.
"""
if not os.path.exists(source):
ValueError(f"Video path {source} doesn't exist")
    t0_all = time.time()
    with open("config.yml", encoding="utf-8") as f:
        config = safe_load(f)
    config.update({"source": source, "checkpoint_name": checkpoint_name, "device": device, "output": output})
    LOG.info(f"Using config: {config}")

    model = YOLO(checkpoint_name).to(device)
    model.fuse()  # Fuse Conv2d + BatchNorm2d layers for faster inference
    # Freeze the tracking arguments so the loop below only has to pass frames
    get_tracks = partial(
        model.track,
        classes=config["coco_relevant_classes"],
        tracker="bytetrack.yaml",  # Ultralytics' built-in ByteTrack config
        persist=True,
        verbose=LOG.isEnabledFor(logging.DEBUG),
    )
    road_plane_segmenter = RoadPlaneSegmenter(**config["road_plane_segmenter"])
    tracker = VehicleTracker(**config["tracker"])
    cap = cv2.VideoCapture(source)
    frames_count, fps, width, height = (
        cap.get(cv2.CAP_PROP_FRAME_COUNT),
        cap.get(cv2.CAP_PROP_FPS),
        cap.get(cv2.CAP_PROP_FRAME_WIDTH),
        cap.get(cv2.CAP_PROP_FRAME_HEIGHT),
    )
    LOG.info(f"Input video: #frames={frames_count}, fps={fps}, width={width}, height={height}")

    frame_number = 0
    if output == "screen":
        cv2.namedWindow("output", cv2.WINDOW_AUTOSIZE)
    else:
        # Guard against a bare filename, for which dirname() is "" and makedirs() would raise
        if os.path.dirname(output):
            os.makedirs(os.path.dirname(output), exist_ok=True)
        writer = cv2.VideoWriter(output, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height)))
    bar = tqdm(total=int(frames_count), desc="Processing frames", unit="frames")
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        mask = road_plane_segmenter.update(frame)
        detections: Boxes = get_tracks(frame)[0].boxes
        if detections.id is None:
            LOG.debug("No detections in frame, skipping tracker update")
        else:
            LOG.debug(f"Found {len(detections)} detections in frame")
            # Move boxes to CPU before converting to numpy (tensors may live on CUDA)
            keep_detections = test_intersection(detections.xyxy.long().cpu().numpy(), mask)
            detections = detections[keep_detections]
            LOG.debug(f"Kept {len(detections)} detections after testing the intersection with the road plane")
            tracker.update(detections)
        frame = overlay_mask(frame, mask)
        frame = plot_vehicle_tracks(frame, tracker)
        current_fps = frame_number / (time.time() - t0_all)
        cv2.putText(
            frame,
            (
                f"Frame#: {frame_number}/{int(frames_count)}, incoming: {tracker.num_incoming}, "
                f"outgoing: {tracker.num_outgoing}, FPS: {current_fps:.1f}"
            ),
            (0, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (2, 10, 200),
            2,
        )
if output == "screen":
cv2.imshow("output", frame)
else:
writer.write(frame)
key = cv2.waitKey(1)
# Quit when 'q' is pressed
if key == ord("q") or key == ord("Q") or key == 27:
break
elif key == ord("k") or key == ord("K") or key == 32:
cv2.waitKey(0)
frameNumber = frameNumber + 1
bar.update(1)
    bar.close()
    cap.release()
    cv2.destroyAllWindows()
    if output != "screen":
        writer.release()
    t1_all = time.time()
    time_taken = t1_all - t0_all
    print(f"Done. process_video took {time_taken:.3f}s @ {frame_number / time_taken:.1f} FPS")


if __name__ == "__main__":
    fire.Fire(process_video)
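# Example invocations via python-fire (paths below are illustrative, not from the source):
#   python main.py --source data/traffic.mp4 --output out/annotated.mp4
#   python main.py --source data/traffic.mp4 --device cpu --output screen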