-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathClassModel.lua
135 lines (121 loc) · 4.45 KB
/
ClassModel.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
-- Implementation of Class Model Visualisation described in:
-- Simonyan, K., Vedaldi, A., & Zisserman, A. (2013). Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps. arXiv Preprint arXiv:1312.6034, 1–8. Computer Vision and Pattern Recognition.
--
-- Code partially based on https://github.com/torch/nn/blob/master/Sequential.lua
-- Register 'nn.ClassModel' with torch's class system as a subclass of
-- 'nn.Sequential'. ClassModel is the table the methods below are attached to;
-- Parent is used to invoke the base-class constructor.
local ClassModel, Parent = torch.class('nn.ClassModel', 'nn.Sequential')
--- Build a container whose only learnable parameter is the image itself.
-- @param initial  2D (height x width) or 3D (planes x height x width) tensor
--                 used as the starting image. It is cloned, so the caller's
--                 tensor is never modified.
-- Raises an error when `initial` has any other number of dimensions.
function ClassModel:__init(initial)
   Parent.__init(self)
   local nDim = initial:dim()
   if nDim == 3 then
      self.nInputPlane = initial:size(1)
      self.iH = initial:size(2)
      -- Width is the third dimension; reading it from dimension 2 only
      -- worked for square images.
      self.iW = initial:size(3)
   elseif nDim == 2 then
      self.nInputPlane = 1
      self.iH = initial:size(1)
      self.iW = initial:size(2)
   else
      error('Need 2 or 3 input dimensions')
   end
   self.weight = initial:clone()
   -- Allocate the gradient buffer with the same tensor type as the weight so
   -- the parameter/gradient pair stays type-consistent.
   self.gradWeight = self.weight.new(self.nInputPlane, self.iH, self.iW):zero()
end
--- Return a fresh copy of the image reshaped into a single-sample batch.
-- The weight is cloned first, so the returned tensor never aliases
-- self.weight. For a 'torch.CudaTensor' weight the singleton batch dimension
-- is placed last (planes x H x W x 1); for every other type it is placed
-- first (1 x planes x H x W).
-- NOTE(review): the CUDA layout presumably targets layers that expect the
-- batch dimension last -- confirm against the networks actually used.
-- @return the reshaped clone
function ClassModel:input()
   -- Declared local: the previous bare assignment leaked a global `input`.
   local batch = self.weight:clone()
   if torch.type(self.weight) == 'torch.CudaTensor' then
      return batch:resize(self.nInputPlane, self.iH, self.iW, 1)
   end
   return batch:resize(1, self.nInputPlane, self.iH, self.iW)
end
--- Append a module to the end of the chain.
-- The first module added supplies this container's gradInput reference, and
-- the most recently added module supplies its output reference.
--
-- The module is stored exactly as given: this method does NOT strip
-- nn.LogSoftMax layers or alter nn.Dropout layers. Simonyan et al. report
-- that removing LogSoftMax works better, and dropout has to be inactive in
-- the forward pass without breaking the backward pass, so callers wanting
-- that preprocessing must apply it to `module` before adding it.
-- @param module  nn module to append
-- @return self, so calls can be chained
function ClassModel:add(module)
   if #self.modules == 0 then
      self.gradInput = module.gradInput
   end
   table.insert(self.modules, module)
   self.output = module.output
   return self
end
--- Forward pass through every contained module.
-- The `input` argument is ignored; the chain is always fed the batch built
-- from self.weight by self:input().
-- @param input  unused
-- @return output of the last module (also stored in self.output)
function ClassModel:updateOutput(input)
   local activation = self:input()
   for _, layer in ipairs(self.modules) do
      activation = layer:updateOutput(activation)
   end
   self.output = activation
   return self.output
end
--- Expose the image and its gradient buffer as the only parameters.
-- Parameters of the contained modules are not listed.
-- @return a table holding self.weight, and a table holding self.gradWeight
function ClassModel:parameters()
   local weights = { self.weight }
   local gradWeights = { self.gradWeight }
   return weights, gradWeights
end
--- Backward pass: propagate gradOutput from the last module to the first.
-- Each module receives the output of the module below it as its input; the
-- first module receives the batch built from self.weight. The `input`
-- argument supplied by the caller is ignored.
-- @param input       unused
-- @param gradOutput  gradient with respect to the output of the last module
-- @return gradient with respect to the image batch (also stored in
--         self.gradInput)
function ClassModel:updateGradInput(input, gradOutput)
   local image = self:input()
   local layers = self.modules
   local grad = gradOutput
   local idx = #layers
   -- Walk down the chain; every layer above the first takes the previous
   -- layer's cached output as its input.
   while idx > 1 do
      grad = layers[idx]:updateGradInput(layers[idx - 1].output, grad)
      idx = idx - 1
   end
   grad = layers[idx]:updateGradInput(image, grad)
   self.gradInput = grad
   return grad
end
--- Run accGradParameters on every contained module and refresh gradWeight.
-- The `input` argument supplied by the caller is ignored; the batch built
-- from self.weight is used instead. This relies on each module's gradInput
-- having been computed already by a preceding updateGradInput call.
--
-- Note that self.gradWeight is overwritten (copied from the first module's
-- gradInput) rather than accumulated, and `scale` is only forwarded to the
-- contained modules -- it is not applied to self.gradWeight.
-- @param input       unused
-- @param gradOutput  gradient with respect to the output of the last module
-- @param scale       factor forwarded to the contained modules (default 1)
-- @return self.gradWeight
function ClassModel:accGradParameters(input, gradOutput, scale)
   -- Reuse the shared helper instead of duplicating the reshape logic, the
   -- same way updateGradInput does.
   input = self:input()
   scale = scale or 1
   local currentGradOutput = gradOutput
   local currentModule = self.modules[#self.modules]
   for i = #self.modules - 1, 1, -1 do
      local previousModule = self.modules[i]
      currentModule:accGradParameters(previousModule.output, currentGradOutput, scale)
      currentGradOutput = currentModule.gradInput
      currentModule = previousModule
   end
   currentModule:accGradParameters(input, currentGradOutput, scale)
   -- The first module's gradInput is the gradient with respect to the image.
   self.gradWeight = self.gradWeight:copy(currentModule.gradInput)
   return self.gradWeight
end
--- Overwrite the current image with the contents of `initial`.
-- The values are copied into the existing weight tensor rather than
-- rebinding self.weight to a new tensor. This keeps the weight's tensor type
-- and storage, so flattened parameter views obtained earlier stay valid, and
-- keeps it consistent with nInputPlane / iH / iW.
-- @param initial  tensor with the same number of elements as the current
--                 image
-- Raises an error when the element counts differ.
function ClassModel:reset(initial)
   if initial:nElement() ~= self.weight:nElement() then
      error('reset: expected ' .. self.weight:nElement() ..
            ' elements, got ' .. initial:nElement())
   end
   self.weight:copy(initial)
end
--- Return a copy of the image rescaled to the range [0, 1].
-- The minimum is shifted to 0 and the result divided by its maximum.
-- self.weight itself is left untouched. A constant image yields all zeros.
-- @return normalised clone of self.weight
function ClassModel:image()
   -- Declared local: the previous bare assignment leaked a global `im`.
   local im = self.weight:clone()
   im:add(-im:min())
   local peak = im:max()
   -- Skip the division for a constant image to avoid 0/0 producing NaN.
   if peak > 0 then
      im:div(peak)
   end
   return im
end
--- Human-readable summary of the container.
-- The first line shows the chain as "[input -> (1) -> ... -> output]"; it is
-- followed by one indented entry per contained module, with any line breaks
-- inside a module's own description indented as well.
-- @return the description string
function ClassModel:__tostring__()
   local indent = ' '
   local newline = '\n'
   local arrow = ' -> '
   local parts = { 'nn.ClassModel', ' {', newline, indent, '[input' }
   for i = 1, #self.modules do
      parts[#parts + 1] = arrow .. '(' .. i .. ')'
   end
   parts[#parts + 1] = arrow .. 'output]'
   for i = 1, #self.modules do
      -- gsub returns the replacement count as well; the assignment keeps
      -- only the rewritten string.
      local description = tostring(self.modules[i]):gsub(newline, newline .. indent)
      parts[#parts + 1] = newline .. indent .. '(' .. i .. '): ' .. description
   end
   parts[#parts + 1] = newline .. '}'
   return table.concat(parts)
end