-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train_test_val_split.py
74 lines (60 loc) · 3.12 KB
/
train_test_val_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import shutil
# Dataset root: holds the source ``images`` / ``labels`` folders and the
# ``Split_declaration`` lists. Every other path below is derived from it so the
# spelling (previously ``Blood_Cell_detection`` in two places) cannot drift.
root = r"D:\kaggle_practice\Blood_Cell_Detection"

# Split lists: text files read line by line; each stripped line is compared
# against file names with their last extension removed.
test_set = rf"{root}\Split_declaration\test.txt"
train_set = rf"{root}\Split_declaration\train.txt"
val_set = rf"{root}\Split_declaration\val.txt"

# Destination folders for each split (images and their label files).
train_path = rf"{root}\train\images"
train_label = rf"{root}\train\labels"
test_path = rf"{root}\test\images"
test_label = rf"{root}\test\labels"
val_path = rf"{root}\valid\images"
val_label = rf"{root}\valid\labels"
def createTrainingSet():
train_file = open(train_set, 'r')
for line in train_file:
# print(line.strip())
for filename in os.listdir(os.path.join(root, "images")):
# f = os.path.join(root, filename)
# print(filename.rsplit('.', 1)[0])
if filename.rsplit('.', 1)[0] == line.strip():
# print("True")
# print(os.path.join(root, filename))
shutil.copy(os.path.join(os.path.join(root, "images"), filename), train_path)
print("moved ", filename, " To training set")
for filename in os.listdir(os.path.join(root, "labels")):
if filename.rsplit('.', 1)[0] == line.strip():
shutil.copy(os.path.join(os.path.join(root, "labels"), filename), train_label)
def createTestingSet():
test_file = open(test_set, 'r')
for line in test_file:
# print(line.strip())
for filename in os.listdir(os.path.join(root, "images")):
# f = os.path.join(root, filename)
# print(filename.rsplit('.', 1)[0])
if filename.rsplit('.', 1)[0] == line.strip():
# print("True")
# print(os.path.join(root, filename))
shutil.copy(os.path.join(os.path.join(root, "images"), filename), test_path)
print("moved ", filename, " To testing set")
for filename in os.listdir(os.path.join(root, "labels")):
if filename.rsplit('.', 1)[0] == line.strip():
shutil.copy(os.path.join(os.path.join(root, "labels"), filename), test_label)
def createValidationSet():
val_file = open(val_set, 'r')
for line in val_file:
# print(line.strip())
for filename in os.listdir(os.path.join(root, "images")):
# f = os.path.join(root, filename)
# print(filename.rsplit('.', 1)[0])
if filename.rsplit('.', 1)[0] == line.strip():
# print("True")
# print(os.path.join(root, filename))
shutil.copy(os.path.join(os.path.join(root, "images"), filename), val_path)
print("moved ", filename, " To Validation set")
for filename in os.listdir(os.path.join(root, "labels")):
if filename.rsplit('.', 1)[0] == line.strip():
shutil.copy(os.path.join(os.path.join(root, "labels"), filename), val_label)
# Build all three splits only when run as a script, so importing this module
# (e.g. to reuse its paths or functions) does not start copying files.
if __name__ == "__main__":
    createTrainingSet()
    createTestingSet()
    createValidationSet()