-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path hw1_dt.py
48 lines (38 loc) · 1.4 KB
/
hw1_dt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import utils as Util
class DecisionTree():
    """Decision-tree classifier built from TreeNode objects.

    Call ``train`` once with the training data, then ``predict`` on new rows.
    """

    def __init__(self):
        self.clf_name = "DecisionTree"
        self.root_node = None  # set by train()

    def train(self, features, labels):
        """Grow a decision tree on the given training data.

        :param features: List[List[float]] -- one feature row per sample.
        :param labels: List[int] -- one label per feature row.
        :raises ValueError: if there are no samples, or if ``features`` and
            ``labels`` have different lengths.
        """
        # len() rather than truthiness so NumPy arrays are accepted as well.
        if len(features) == 0:
            raise ValueError("features must contain at least one sample")
        if len(features) != len(labels):
            raise ValueError("features and labels must have the same length")

        # Number of distinct training labels, handed to the root as num_cls.
        num_cls = len(np.unique(labels))
        self.root_node = TreeNode(features, labels, num_cls)
        if self.root_node.splittable:
            self.root_node.split()

    def predict(self, features):
        """Predict a label for every feature row.

        :param features: List[List[any]] -- rows shaped like the training rows.
        :return: List[int] -- one predicted label per row, in input order.
        :raises RuntimeError: if called before ``train``.
        """
        if self.root_node is None:
            raise RuntimeError("DecisionTree.predict called before train")
        return [self.root_node.predict(feature) for feature in features]
class TreeNode(object):
    """One node of an ID3-style decision tree with multiway splits.

    A node keeps the training samples that reach it and their majority
    label. After ``split`` runs, a splittable node has one child per
    distinct value of the chosen feature.
    """

    def __init__(self, features, labels, num_cls):
        # features: List[List[any]], labels: List[int], num_cls: int
        self.features = features
        self.labels = labels
        self.children = []
        self.num_cls = num_cls

        # Most common label in this node. np.unique returns values sorted and
        # np.argmax picks the first maximum, so ties go to the smallest label.
        # Works for lists and NumPy arrays alike; an empty node gets None.
        values, counts = np.unique(labels, return_counts=True)
        self.cls_max = values[np.argmax(counts)].item() if counts.size else None

        # Splittable is False when all samples belong to one class (or none).
        self.splittable = values.size >= 2

        self.dim_split = None  # the index of the feature to be split
        self.feature_uniq_split = None  # values of that feature, aligned with self.children

    @staticmethod
    def _entropy(labels):
        """Return the Shannon entropy, in bits, of a sequence of labels."""
        if len(labels) == 0:
            return 0.0
        _, counts = np.unique(labels, return_counts=True)
        probs = counts / counts.sum()
        return float(-np.sum(probs * np.log2(probs)))

    def _split_once(self):
        """Split this node on its best feature and return the new children.

        The best feature is the one with the lowest conditional entropy of the
        labels. This is equivalent to the highest information gain, because
        the node's own entropy is the same for every candidate. Ties go to the
        lowest feature index.

        Only features that take at least two distinct values in this node are
        candidates. A feature used higher up the tree is constant inside this
        node, so it is never reused. If no candidate exists, the node becomes
        a leaf (``splittable = False``) and an empty list is returned.
        """
        self.children = []
        if not self.splittable:
            return self.children

        n_samples = len(self.labels)
        n_dims = len(self.features[0]) if len(self.features) else 0

        best = None  # (conditional entropy, feature index, value -> row indices)
        for dim in range(n_dims):
            groups = {}  # insertion order = first appearance of each value
            for i, row in enumerate(self.features):
                groups.setdefault(row[dim], []).append(i)
            if len(groups) < 2:
                continue  # constant feature: cannot separate anything here
            cond_ent = sum(
                len(idx) / n_samples * self._entropy([self.labels[i] for i in idx])
                for idx in groups.values()
            )
            if best is None or cond_ent < best[0]:
                best = (cond_ent, dim, groups)

        if best is None:
            # Mixed labels but indistinguishable feature rows: leave as a leaf.
            self.splittable = False
            return self.children

        _, self.dim_split, groups = best
        self.feature_uniq_split = list(groups)
        self.children = [
            TreeNode(
                [self.features[i] for i in idx],
                [self.labels[i] for i in idx],
                self.num_cls,
            )
            for idx in groups.values()
        ]
        return self.children

    def split(self):
        """Grow the subtree rooted at this node.

        Growth stops at leaves that are pure or have no feature left that
        separates their samples. An explicit work stack replaces recursion,
        so trees over many features cannot exceed Python's recursion limit.
        """
        pending = [self]
        while pending:
            node = pending.pop()
            pending.extend(child for child in node._split_once() if child.splittable)

    def predict(self, feature):
        """Return the predicted label for a single feature row.

        The row is routed down the tree by the value of each node's split
        feature. If a value never appeared at that node during training, the
        majority label of the deepest node reached is returned.

        :param feature: List[any] -- one row shaped like the training rows.
        :return: the predicted label (None only if the node holds no samples).
        """
        node = self
        while node.splittable and node.children:
            try:
                branch = node.feature_uniq_split.index(feature[node.dim_split])
            except ValueError:
                break  # unseen feature value: stop at this node
            node = node.children[branch]
        return node.cls_max