-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmake_inference.py
88 lines (76 loc) · 3.02 KB
/
make_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import os
import sys
import random
# Visualise detection results from a submission CSV.
#
# Usage:
#     python make_inference.py SUBMISSION_CSV IMAGE_FOLDER [IMAGE_ID ...]
#
# Up to NUM_IMAGES image IDs may be given; otherwise that many distinct IDs are
# sampled at random from the CSV. The CSV is read with the columns Image_ID,
# class, xmin, ymin, xmax, ymax. Each image is drawn with its boxes and labels,
# and the figure is written to OUTPUT_FILE before being shown.

NUM_IMAGES = 3
OUTPUT_FILE = "inference.png"


def parse_args(argv):
    """Split a command line into ``(csv_path, image_folder, requested_ids)``.

    Args:
        argv: The full argument vector, program name first (like ``sys.argv``).

    Returns:
        A tuple of the CSV path, the image folder path, and the list of image IDs
        given after the folder (possibly empty).

    Raises:
        ValueError: If the CSV path or the image folder is missing; the message
            is suitable for showing to the user.
    """
    if len(argv) < 2:
        raise ValueError("Please provide the path to the CSV file.")
    if len(argv) < 3:
        raise ValueError("Please provide the path to the image folder.")
    return argv[1], argv[2], list(argv[3:])


def select_image_ids(df, requested, num_images=NUM_IMAGES, rng=random):
    """Choose which images to plot.

    Args:
        df: Submission DataFrame with an ``Image_ID`` column.
        requested: Image IDs supplied by the user; may be empty.
        num_images: Maximum number of IDs to return.
        rng: Object with a ``sample`` method (the ``random`` module by default);
            injectable for reproducible tests.

    Returns:
        The first ``num_images`` requested IDs if any were given, otherwise up to
        ``num_images`` distinct IDs sampled from ``df``. Returns fewer (or none)
        instead of raising when the CSV has too few distinct images.
    """
    if requested:
        return list(requested)[:num_images]
    unique_ids = list(df["Image_ID"].unique())
    # Cap the sample size: random.sample raises if k exceeds the population.
    return rng.sample(unique_ids, min(num_images, len(unique_ids)))


def class_label(class_name):
    """Return the text drawn next to a box; 'Trophozoite' is shortened to 'TPZ'."""
    return "TPZ" if class_name == "Trophozoite" else class_name


def boxes_for(df, image_id):
    """Return the rows of ``df`` belonging to ``image_id``.

    Both sides are compared as strings so that IDs typed on the command line
    still match when pandas parsed the ``Image_ID`` column as numbers.
    """
    return df[df["Image_ID"].astype(str) == str(image_id)]


def draw_detections(ax, image_path, title, boxes):
    """Show one image on ``ax`` with its bounding boxes and class labels overlaid.

    Args:
        ax: Matplotlib Axes to draw on.
        image_path: Path of the image file to display.
        title: Axes title (the image ID).
        boxes: DataFrame rows with ``xmin``, ``ymin``, ``xmax``, ``ymax`` and
            ``class`` columns, in image (pixel) coordinates.
    """
    with Image.open(image_path) as image:
        pixels = image.copy()  # force a full read so the file handle can close
    ax.imshow(pixels)
    ax.set_title(title)
    for _, row in boxes.iterrows():
        xmin, ymin, xmax, ymax = row["xmin"], row["ymin"], row["xmax"], row["ymax"]
        ax.add_patch(
            patches.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                linewidth=2,
                edgecolor="r",
                facecolor="none",
            )
        )
        # imshow's default origin puts y=0 at the top, so ymin - 10 sits just
        # above the box's top edge.
        ax.text(xmin, ymin - 10, class_label(row["class"]), color="r", fontsize=12, weight="bold")


def main(argv=None):
    """Run the script; return a process exit status (0 on success, 1 on error)."""
    argv = sys.argv if argv is None else argv
    try:
        submission_file, image_folder_path, requested_ids = parse_args(argv)
    except ValueError as err:
        print(err)
        return 1

    df = pd.read_csv(submission_file)
    image_ids = select_image_ids(df, requested_ids)
    if not image_ids:
        print("No images to plot.")
        return 1

    # squeeze=False keeps `axes` a 2-D array even if NUM_IMAGES is 1.
    fig, axes = plt.subplots(1, NUM_IMAGES, figsize=(5 * NUM_IMAGES, 5), squeeze=False)
    flat_axes = axes.flatten()
    for ax, image_id in zip(flat_axes, image_ids):
        image_path = os.path.join(image_folder_path, str(image_id))
        draw_detections(ax, image_path, str(image_id), boxes_for(df, image_id))
    # Hide the frames of any panels left empty when fewer images were chosen.
    for ax in flat_axes[len(image_ids):]:
        ax.axis("off")

    plt.tight_layout()
    # Save before show(): once show() returns the figure may be closed, and a
    # later savefig would write a blank image.
    fig.savefig(OUTPUT_FILE)
    print(f"Plot saved as {OUTPUT_FILE}")
    plt.show()
    return 0


if __name__ == "__main__":
    sys.exit(main())