Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom simple CNN + alexnet, densenet(121+161), efficientnet, mobilenet, googlenet, vgg. InceptionV3 seems to not work yet due to AuxLogits #141

Merged
merged 1 commit into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

from util.logger import Logger
from .bagnet import BagNet, bagnet9, bagnet17, bagnet33
from .simple_cnn import SimpleCNN


def get_batch_size(model: nn.Module, device: torch.device, input_shape: Tuple[int, int, int], output_shape: Tuple[int],
dataset_size: int, min_batch_size: int = 8, max_batch_size: int = 256, num_iterations: int = 5) -> int:
dataset_size: int, min_batch_size: int = 8, max_batch_size: int = 256,
num_iterations: int = 5) -> int:
"""
https://towardsdatascience.com/a-batch-too-large-finding-the-batch-size-that-fits-on-gpus-aef70902a9f1
"""
Expand Down
29 changes: 29 additions & 0 deletions src/core/processors/cg_image_classification/nn_model/simple_cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch
import torch.nn as nn

from core.processors.cg_image_classification.dataset import ImgDataset


class SimpleCNN(torch.nn.Module):

def __init__(self, dataset: ImgDataset, size_hid: int, dropout: float):
super().__init__()

layers = []
layers.append(
nn.Conv2d(dataset.img_color_channels, size_hid, kernel_size=1, stride=1, padding=0, bias=True))
layers.append(
nn.Conv2d(size_hid, size_hid, kernel_size=1, stride=1, padding=0, bias=True))

layers.append(nn.BatchNorm2d(size_hid))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.Dropout(dropout))
layers.append(nn.AvgPool2d(1, stride=1))

self.stack = nn.Sequential(*layers)
self.fc = nn.Linear(size_hid * 10000, dataset.num_classes)

def forward(self, x):
x = self.stack(x)
x = torch.flatten(x, 1)
return self.fc(x)
39 changes: 38 additions & 1 deletion src/core/processors/cg_image_classification/train_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from core.processors.cg_image_classification.dataset import Datasets, ImgDataset
from core.processors.cg_image_classification.dataset.dataloader import create_bodmas_train_val_loader
from core.processors.cg_image_classification.hparams import HPARAMS, get_hparam_value
from core.processors.cg_image_classification.nn_model import bagnet17, bagnet9, bagnet33
from core.processors.cg_image_classification.nn_model import bagnet17, bagnet9, bagnet33, SimpleCNN
from util.logger import Logger


Expand Down Expand Up @@ -58,6 +58,43 @@ def get_model() -> torch.nn.Module:
weights = None if not hp_model_pretrained else torchvision.models.ResNet50_Weights.IMAGENET1K_V2
MODEL = torchvision.models.resnet50(weights=weights)
MODEL.fc = torch.nn.Linear(512 * 4, DATASET.num_classes)
elif hp_model.startswith("alexnet"):
weights = None if not hp_model_pretrained else torchvision.models.AlexNet_Weights.IMAGENET1K_V1
MODEL = torchvision.models.alexnet(weights=weights)
MODEL.classifier[6] = torch.nn.Linear(4096, DATASET.num_classes)
elif hp_model.startswith("densenet"):
if hp_model == "densenet121":
weights = None if not hp_model_pretrained else torchvision.models.DenseNet121_Weights.IMAGENET1K_V1
MODEL = torchvision.models.densenet121(weights=weights)
MODEL.classifier = torch.nn.Linear(MODEL.classifier.in_features, DATASET.num_classes)
elif hp_model == "densenet161":
weights = None if not hp_model_pretrained else torchvision.models.DenseNet161_Weights.IMAGENET1K_V1
MODEL = torchvision.models.densenet161(weights=weights)
MODEL.classifier = torch.nn.Linear(MODEL.classifier.in_features, DATASET.num_classes)
elif hp_model.startswith("efficientnet"):
weights = None if not hp_model_pretrained else torchvision.models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
MODEL = torchvision.models.efficientnet_v2_s(weights=weights)
MODEL.classifier[1] = torch.nn.Linear(MODEL.classifier[1].in_features, DATASET.num_classes)
elif hp_model.startswith("googlenet"):
weights = None if not hp_model_pretrained else torchvision.models.GoogLeNet_Weights.IMAGENET1K_V1
MODEL = torchvision.models.googlenet(weights=weights)
MODEL.fc = torch.nn.Linear(1024, DATASET.num_classes)
# elif hp_model.startswith("inception"):
# weights = None if not hp_model_pretrained else torchvision.models.Inception_V3_Weights.IMAGENET1K_V1
# MODEL = torchvision.models.inception_v3(weights=weights)
# MODEL.AuxLogits = InceptionAux(768, DATASET.num_classes)
# MODEL.fc = torch.nn.Linear(2048, DATASET.num_classes)
elif hp_model.startswith("mobilenet"):
weights = None if not hp_model_pretrained else torchvision.models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
MODEL = torchvision.models.mobilenet_v3_small(weights=weights)
MODEL.classifier[-1] = torch.nn.Linear(MODEL.classifier[-1].in_features, DATASET.num_classes)
elif hp_model.startswith("vgg"):
weights = None if not hp_model_pretrained else torchvision.models.VGG11_Weights.IMAGENET1K_V1
MODEL = torchvision.models.vgg11(weights=weights)
MODEL.classifier[-1] = torch.nn.Linear(4096, DATASET.num_classes)

elif hp_model.startswith("simplecnn"):
MODEL = SimpleCNN(DATASET, 32, dropout=0.5)

if MODEL is None:
raise Exception(f"Unknown model: {hp_model}")
Expand Down
Loading