diff --git a/README.md b/README.md index 7f08c01..c8b8a69 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,73 @@ -## Implementation of [MobileNetV3](https://arxiv.org/abs/1905.02244) in PyTorch +# [MobileNet V3](https://arxiv.org/abs/1905.02244) in PyTorch -**Arxiv**: https://arxiv.org/abs/1905.02244 +MobileNet V3 implementation using PyTorch -After 300 epochs MobileNetV3L reaches **Acc@1**: 74.3025 **Acc@5**: 91.8342 +**Arxiv:** https://arxiv.org/abs/1905.02244 -### Updates +## Table of Contents -* 2022.05.13: - - Weights are uploaded to the `weights` folder. `last.ckpt` is checkpoint (88.3MB) (includes model, model_ema, optimizer, ...) and last.pth is model with - Exponential Moving Average (11.2MB) and converted to `half()` tensor. +- [Installation](#installation) +- [Usage](#usage) +- [Datasets](#datasets) +- [Training](#training) +- [Reference](#reference) -### Dataset +## Installation -Specify the IMAGENET data folder in the `main.py` file. +```bash +pip install -r requirements.txt +``` + +## Usage + +```python +import torch +from torchvision import transforms + +from PIL import Image +from nets import mobilenet_v3_small, mobilenet_v3_large +from assets.meta import IMAGENET_CATEGORIES + +model = mobilenet_v3_large() +model.load_state_dict("./weights/mobilenet_v3_large.pt") # weights ported from torchvision +model.float() # converting weights to float32 + + +def preprocess_image(image_path): + transform = transforms.Compose([ + transforms.Resize((224, 224)), # Resize the image to match the model's input size + transforms.ToTensor(), # Convert the image to a PyTorch tensor + transforms.Normalize( + mean=[0.485, 0.456, 0.406], # Normalize the image using the mean and std of ImageNet + std=[0.229, 0.224, 0.225] + ), + ]) + + image = Image.open(image_path) + image = transform(image).unsqueeze(0) # Add a batch dimension + return image + + +def inference(model, image_path): + model.eval() + + input_image = preprocess_image(image_path) + with torch.no_grad(): + output = model(input_image) -``` python -parser.add_argument("--data-path", default="../../Projects/Datasets/IMAGENET/", type=str, help="dataset path") + _, predicted_class = output.max(1) + print(f"Predicted class index: {predicted_class.item()}") + + predicted_label = IMAGENET_CATEGORIES[predicted_class.item()] + print(f"Predicted class label: {predicted_label}") + + +inference(model, "assets/tabby_cat.jpg") ``` -IMAGENET folder structure: +## Datasets + +ImageNet, folder structure: ``` ├── IMAGENET @@ -26,18 +75,15 @@ IMAGENET folder structure: ├── [class_id1]/xxx.{jpg,png,jpeg} ├── [class_id2]/xxy.{jpg,png,jpeg} ├── [class_id3]/xxz.{jpg,png,jpeg} - .... + ... ├── val ├── [class_id1]/xxx1.{jpg,png,jpeg} ├── [class_id2]/xxy2.{jpg,png,jpeg} ├── [class_id3]/xxz3.{jpg,png,jpeg} + ... ``` -#### Augmentation: - -`AutoAugment` for IMAGENET is used as a default augmentation. The interpolation mode is `BILINEAR` - -### Train +## Training Run `main.sh` (for DDP) file by running the following command: @@ -51,35 +97,6 @@ bash main.sh torchrun --nproc_per_node=@num_gpu main.py --epochs 300 --batch-size 512 --lr 0.064 --lr-step-size 2 --lr-gamma 0.973 --random-erase 0.2 ``` -To resume the training add `--resume @path_to_checkpoint` to `main.sh` +## Reference -Run `main.py` for `DataParallel` training. - -The training config taken -from [official torchvision models' training config](https://github.com/pytorch/vision/tree/970ba3555794d163daca0ab95240d21e3035c304/references/classification) - -### PyTorch Implementation of [PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions](https://arxiv.org/abs/2204.12511) -```py -import torch -import torch.nn.functional as F - -class PolyLoss: - """ [https://arxiv.org/abs/2204.12511] """ - - def __init__(self, reduction='none', label_smoothing=0.0) -> None: - super().__init__() - self.reduction = reduction - self.label_smoothing = label_smoothing - self.softmax = torch.nn.Softmax(dim=-1) - - def __call__(self, prediction, target, epsilon=1.0): - ce = F.cross_entropy(prediction, target, reduction=self.reduction, label_smoothing=self.label_smoothing) - pt = torch.sum(F.one_hot(target, num_classes=1000) * self.softmax(prediction), dim=-1) - pl = torch.mean(ce + epsilon * (1 - pt)) - return pl -``` - -### Evaluation -```commandline -python main.py --test -``` +- [torchvision](https://github.com/pytorch/vision) \ No newline at end of file diff --git a/assets/meta.py b/assets/meta.py new file mode 100644 index 0000000..44b03e3 --- /dev/null +++ b/assets/meta.py @@ -0,0 +1,1002 @@ +IMAGENET_CATEGORIES = [ + "tench", + "goldfish", + "great white shark", + "tiger shark", + "hammerhead", + "electric ray", + "stingray", + "cock", + "hen", + "ostrich", + "brambling", + "goldfinch", + "house finch", + "junco", + "indigo bunting", + "robin", + "bulbul", + "jay", + "magpie", + "chickadee", + "water ouzel", + "kite", + "bald eagle", + "vulture", + "great grey owl", + "European fire salamander", + "common newt", + "eft", + "spotted salamander", + "axolotl", + "bullfrog", + "tree frog", + "tailed frog", + "loggerhead", + "leatherback turtle", + "mud turtle", + "terrapin", + "box turtle", + "banded gecko", + "common iguana", + "American chameleon", + "whiptail", + "agama", + "frilled lizard", + "alligator lizard", + "Gila monster", + "green lizard", + "African chameleon", + "Komodo dragon", + "African crocodile", + "American alligator", + "triceratops", + "thunder snake", + "ringneck snake", + "hognose snake", + "green snake", + "king snake", + "garter snake", + "water snake", + "vine snake", + "night snake", + "boa constrictor", + "rock python", + "Indian cobra", + "green mamba", + "sea snake", + "horned viper", + "diamondback", + "sidewinder", + "trilobite", + "harvestman", + "scorpion", + "black and gold garden spider", + "barn spider", + "garden spider", + "black widow", + "tarantula", + "wolf spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse", + "prairie chicken", + "peacock", + "quail", + "partridge", + "African grey", + "macaw", + "sulphur-crested cockatoo", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "drake", + "red-breasted merganser", + "goose", + "black swan", + "tusker", + "echidna", + "platypus", + "wallaby", + "koala", + "wombat", + "jellyfish", + "sea anemone", + "brain coral", + "flatworm", + "nematode", + "conch", + "snail", + "slug", + "sea slug", + "chiton", + "chambered nautilus", + "Dungeness crab", + "rock crab", + "fiddler crab", + "king crab", + "American lobster", + "spiny lobster", + "crayfish", + "hermit crab", + "isopod", + "white stork", + "black stork", + "spoonbill", + "flamingo", + "little blue heron", + "American egret", + "bittern", + "crane bird", + "limpkin", + "European gallinule", + "American coot", + "bustard", + "ruddy turnstone", + "red-backed sandpiper", + "redshank", + "dowitcher", + "oystercatcher", + "pelican", + "king penguin", + "albatross", + "grey whale", + "killer whale", + "dugong", + "sea lion", + "Chihuahua", + "Japanese spaniel", + "Maltese dog", + "Pekinese", + "Shih-Tzu", + "Blenheim spaniel", + "papillon", + "toy terrier", + "Rhodesian ridgeback", + "Afghan hound", + "basset", + "beagle", + "bloodhound", + "bluetick", + "black-and-tan coonhound", + "Walker hound", + "English foxhound", + "redbone", + "borzoi", + "Irish wolfhound", + "Italian greyhound", + "whippet", + "Ibizan hound", + "Norwegian elkhound", + "otterhound", + "Saluki", + "Scottish deerhound", + "Weimaraner", + "Staffordshire bullterrier", + "American Staffordshire terrier", + "Bedlington terrier", + "Border terrier", + "Kerry blue terrier", + "Irish terrier", + "Norfolk terrier", + "Norwich terrier", + "Yorkshire terrier", + "wire-haired fox terrier", + "Lakeland terrier", + "Sealyham terrier", + "Airedale", + "cairn", + "Australian terrier", + "Dandie Dinmont", + "Boston bull", + "miniature schnauzer", + "giant schnauzer", + "standard schnauzer", + "Scotch terrier", + "Tibetan terrier", + "silky terrier", + "soft-coated wheaten terrier", + "West Highland white terrier", + "Lhasa", + "flat-coated retriever", + "curly-coated retriever", + "golden retriever", + "Labrador retriever", + "Chesapeake Bay retriever", + "German short-haired pointer", + "vizsla", + "English setter", + "Irish setter", + "Gordon setter", + "Brittany spaniel", + "clumber", + "English springer", + "Welsh springer spaniel", + "cocker spaniel", + "Sussex spaniel", + "Irish water spaniel", + "kuvasz", + "schipperke", + "groenendael", + "malinois", + "briard", + "kelpie", + "komondor", + "Old English sheepdog", + "Shetland sheepdog", + "collie", + "Border collie", + "Bouvier des Flandres", + "Rottweiler", + "German shepherd", + "Doberman", + "miniature pinscher", + "Greater Swiss Mountain dog", + "Bernese mountain dog", + "Appenzeller", + "EntleBucher", + "boxer", + "bull mastiff", + "Tibetan mastiff", + "French bulldog", + "Great Dane", + "Saint Bernard", + "Eskimo dog", + "malamute", + "Siberian husky", + "dalmatian", + "affenpinscher", + "basenji", + "pug", + "Leonberg", + "Newfoundland", + "Great Pyrenees", + "Samoyed", + "Pomeranian", + "chow", + "keeshond", + "Brabancon griffon", + "Pembroke", + "Cardigan", + "toy poodle", + "miniature poodle", + "standard poodle", + "Mexican hairless", + "timber wolf", + "white wolf", + "red wolf", + "coyote", + "dingo", + "dhole", + "African hunting dog", + "hyena", + "red fox", + "kit fox", + "Arctic fox", + "grey fox", + "tabby", + "tiger cat", + "Persian cat", + "Siamese cat", + "Egyptian cat", + "cougar", + "lynx", + "leopard", + "snow leopard", + "jaguar", + "lion", + "tiger", + "cheetah", + "brown bear", + "American black bear", + "ice bear", + "sloth bear", + "mongoose", + "meerkat", + "tiger beetle", + "ladybug", + "ground beetle", + "long-horned beetle", + "leaf beetle", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant", + "grasshopper", + "cricket", + "walking stick", + "cockroach", + "mantis", + "cicada", + "leafhopper", + "lacewing", + "dragonfly", + "damselfly", + "admiral", + "ringlet", + "monarch", + "cabbage butterfly", + "sulphur butterfly", + "lycaenid", + "starfish", + "sea urchin", + "sea cucumber", + "wood rabbit", + "hare", + "Angora", + "hamster", + "porcupine", + "fox squirrel", + "marmot", + "beaver", + "guinea pig", + "sorrel", + "zebra", + "hog", + "wild boar", + "warthog", + "hippopotamus", + "ox", + "water buffalo", + "bison", + "ram", + "bighorn", + "ibex", + "hartebeest", + "impala", + "gazelle", + "Arabian camel", + "llama", + "weasel", + "mink", + "polecat", + "black-footed ferret", + "otter", + "skunk", + "badger", + "armadillo", + "three-toed sloth", + "orangutan", + "gorilla", + "chimpanzee", + "gibbon", + "siamang", + "guenon", + "patas", + "baboon", + "macaque", + "langur", + "colobus", + "proboscis monkey", + "marmoset", + "capuchin", + "howler monkey", + "titi", + "spider monkey", + "squirrel monkey", + "Madagascar cat", + "indri", + "Indian elephant", + "African elephant", + "lesser panda", + "giant panda", + "barracouta", + "eel", + "coho", + "rock beauty", + "anemone fish", + "sturgeon", + "gar", + "lionfish", + "puffer", + "abacus", + "abaya", + "academic gown", + "accordion", + "acoustic guitar", + "aircraft carrier", + "airliner", + "airship", + "altar", + "ambulance", + "amphibian", + "analog clock", + "apiary", + "apron", + "ashcan", + "assault rifle", + "backpack", + "bakery", + "balance beam", + "balloon", + "ballpoint", + "Band Aid", + "banjo", + "bannister", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel", + "barrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "bathing cap", + "bath towel", + "bathtub", + "beach wagon", + "beacon", + "beaker", + "bearskin", + "beer bottle", + "beer glass", + "bell cote", + "bib", + "bicycle-built-for-two", + "bikini", + "binder", + "binoculars", + "birdhouse", + "boathouse", + "bobsled", + "bolo tie", + "bonnet", + "bookcase", + "bookshop", + "bottlecap", + "bow", + "bow tie", + "brass", + "brassiere", + "breakwater", + "breastplate", + "broom", + "bucket", + "buckle", + "bulletproof vest", + "bullet train", + "butcher shop", + "cab", + "caldron", + "candle", + "cannon", + "canoe", + "can opener", + "cardigan", + "car mirror", + "carousel", + "carpenter's kit", + "carton", + "car wheel", + "cash machine", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello", + "cellular telephone", + "chain", + "chainlink fence", + "chain mail", + "chain saw", + "chest", + "chiffonier", + "chime", + "china cabinet", + "Christmas stocking", + "church", + "cinema", + "cleaver", + "cliff dwelling", + "cloak", + "clog", + "cocktail shaker", + "coffee mug", + "coffeepot", + "coil", + "combination lock", + "computer keyboard", + "confectionery", + "container ship", + "convertible", + "corkscrew", + "cornet", + "cowboy boot", + "cowboy hat", + "cradle", + "crane", + "crash helmet", + "crate", + "crib", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam", + "desk", + "desktop computer", + "dial telephone", + "diaper", + "digital clock", + "digital watch", + "dining table", + "dishrag", + "dishwasher", + "disk brake", + "dock", + "dogsled", + "dome", + "doormat", + "drilling platform", + "drum", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso maker", + "face powder", + "feather boa", + "file", + "fireboat", + "fire engine", + "fire screen", + "flagpole", + "flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster", + "freight car", + "French horn", + "frying pan", + "fur coat", + "garbage truck", + "gasmask", + "gas pump", + "goblet", + "go-kart", + "golf ball", + "golfcart", + "gondola", + "gong", + "gown", + "grand piano", + "greenhouse", + "grille", + "grocery store", + "guillotine", + "hair slide", + "hair spray", + "half track", + "hammer", + "hamper", + "hand blower", + "hand-held computer", + "handkerchief", + "hard disc", + "harmonica", + "harp", + "harvester", + "hatchet", + "holster", + "home theater", + "honeycomb", + "hook", + "hoopskirt", + "horizontal bar", + "horse cart", + "hourglass", + "iPod", + "iron", + "jack-o'-lantern", + "jean", + "jeep", + "jersey", + "jigsaw puzzle", + "jinrikisha", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat", + "ladle", + "lampshade", + "laptop", + "lawn mower", + "lens cap", + "letter opener", + "library", + "lifeboat", + "lighter", + "limousine", + "liner", + "lipstick", + "Loafer", + "lotion", + "loudspeaker", + "loupe", + "lumbermill", + "magnetic compass", + "mailbag", + "mailbox", + "maillot", + "maillot tank suit", + "manhole cover", + "maraca", + "marimba", + "mask", + "matchstick", + "maypole", + "maze", + "measuring cup", + "medicine chest", + "megalith", + "microphone", + "microwave", + "military uniform", + "milk can", + "minibus", + "miniskirt", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home", + "Model T", + "modem", + "monastery", + "monitor", + "moped", + "mortar", + "mortarboard", + "mosque", + "mosquito net", + "motor scooter", + "mountain bike", + "mountain tent", + "mouse", + "mousetrap", + "moving van", + "muzzle", + "nail", + "neck brace", + "necklace", + "nipple", + "notebook", + "obelisk", + "oboe", + "ocarina", + "odometer", + "oil filter", + "organ", + "oscilloscope", + "overskirt", + "oxcart", + "oxygen mask", + "packet", + "paddle", + "paddlewheel", + "padlock", + "paintbrush", + "pajama", + "palace", + "panpipe", + "paper towel", + "parachute", + "parallel bars", + "park bench", + "parking meter", + "passenger car", + "patio", + "pay-phone", + "pedestal", + "pencil box", + "pencil sharpener", + "perfume", + "Petri dish", + "photocopier", + "pick", + "pickelhaube", + "picket fence", + "pickup", + "pier", + "piggy bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate", + "pitcher", + "plane", + "planetarium", + "plastic bag", + "plate rack", + "plow", + "plunger", + "Polaroid camera", + "pole", + "police van", + "poncho", + "pool table", + "pop bottle", + "pot", + "potter's wheel", + "power drill", + "prayer rug", + "printer", + "prison", + "projectile", + "projector", + "puck", + "punching bag", + "purse", + "quill", + "quilt", + "racer", + "racket", + "radiator", + "radio", + "radio telescope", + "rain barrel", + "recreational vehicle", + "reel", + "reflex camera", + "refrigerator", + "remote control", + "restaurant", + "revolver", + "rifle", + "rocking chair", + "rotisserie", + "rubber eraser", + "rugby ball", + "rule", + "running shoe", + "safe", + "safety pin", + "saltshaker", + "sandal", + "sarong", + "sax", + "scabbard", + "scale", + "school bus", + "schooner", + "scoreboard", + "screen", + "screw", + "screwdriver", + "seat belt", + "sewing machine", + "shield", + "shoe shop", + "shoji", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "ski mask", + "sleeping bag", + "slide rule", + "sliding door", + "slot", + "snorkel", + "snowmobile", + "snowplow", + "soap dispenser", + "soccer ball", + "sock", + "solar dish", + "sombrero", + "soup bowl", + "space bar", + "space heater", + "space shuttle", + "spatula", + "speedboat", + "spider web", + "spindle", + "sports car", + "spotlight", + "stage", + "steam locomotive", + "steel arch bridge", + "steel drum", + "stethoscope", + "stole", + "stone wall", + "stopwatch", + "stove", + "strainer", + "streetcar", + "stretcher", + "studio couch", + "stupa", + "submarine", + "suit", + "sundial", + "sunglass", + "sunglasses", + "sunscreen", + "suspension bridge", + "swab", + "sweatshirt", + "swimming trunks", + "swing", + "switch", + "syringe", + "table lamp", + "tank", + "tape player", + "teapot", + "teddy", + "television", + "tennis ball", + "thatch", + "theater curtain", + "thimble", + "thresher", + "throne", + "tile roof", + "toaster", + "tobacco shop", + "toilet seat", + "torch", + "totem pole", + "tow truck", + "toyshop", + "tractor", + "trailer truck", + "tray", + "trench coat", + "tricycle", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus", + "trombone", + "tub", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle", + "upright", + "vacuum", + "vase", + "vault", + "velvet", + "vending machine", + "vestment", + "viaduct", + "violin", + "volleyball", + "waffle iron", + "wall clock", + "wallet", + "wardrobe", + "warplane", + "washbasin", + "washer", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "wing", + "wok", + "wooden spoon", + "wool", + "worm fence", + "wreck", + "yawl", + "yurt", + "web site", + "comic book", + "crossword puzzle", + "street sign", + "traffic light", + "book jacket", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot", + "trifle", + "ice cream", + "ice lolly", + "French loaf", + "bagel", + "pretzel", + "cheeseburger", + "hotdog", + "mashed potato", + "head cabbage", + "broccoli", + "cauliflower", + "zucchini", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber", + "artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple", + "banana", + "jackfruit", + "custard apple", + "pomegranate", + "hay", + "carbonara", + "chocolate sauce", + "dough", + "meat loaf", + "pizza", + "potpie", + "burrito", + "red wine", + "espresso", + "cup", + "eggnog", + "alp", + "bubble", + "cliff", + "coral reef", + "geyser", + "lakeside", + "promontory", + "sandbar", + "seashore", + "valley", + "volcano", + "ballplayer", + "groom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper", + "corn", + "acorn", + "hip", + "buckeye", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn", + "earthstar", + "hen-of-the-woods", + "bolete", + "ear", + "toilet tissue", +] \ No newline at end of file diff --git a/assets/tabby_cat.jpg b/assets/tabby_cat.jpg new file mode 100644 index 0000000..75b71f7 Binary files /dev/null and b/assets/tabby_cat.jpg differ diff --git a/demo.ipynb b/demo.ipynb new file mode 100644 index 0000000..8c5205a --- /dev/null +++ b/demo.ipynb @@ -0,0 +1,281 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "7e6d1aea-40d1-4752-aae7-11ed8c8a508b", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "from torchvision.models import mobilenet_v3_large, mobilenet_v3_small, MobileNet_V3_Large_Weights, MobileNet_V3_Small_Weights\n", + "from torchvision import transforms\n", + "\n", + "from PIL import Image\n", + "from nets.nn import mobilenet_v3_large as m3l\n", + "from nets.nn import mobilenet_v3_small as m3s\n", + "from assets.meta import IMAGENET_CATEGORIES" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8cdcf220-cb2d-42d3-86c7-7f0a715493ac", + "metadata": {}, + "outputs": [], + "source": [ + "mobilenet_v3_l = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.IMAGENET1K_V1)\n", + "mobilenet_v3_s = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8728e7a2-a32f-4765-9a80-ff83804acc98", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5483032" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x = lambda model: sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "x(mobilenet_v3_l)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "388d8b68-8078-44d0-8db7-f6213cf143a2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2542856" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x(mobilenet_v3_s)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8ff37c5e-76bd-4f50-8288-8dcfeadbe262", + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess_image(image_path):\n", + " transform = transforms.Compose([\n", + " transforms.Resize((224, 224)), # Resize the image to match the model's input size\n", + " transforms.ToTensor(), # Convert the image to a PyTorch tensor\n", + " transforms.Normalize(\n", + " mean=[0.485, 0.456, 0.406], # Normalize the image using the mean and std of ImageNet\n", + " std=[0.229, 0.224, 0.225]\n", + " ),\n", + " ])\n", + "\n", + " image = Image.open(image_path)\n", + " image = transform(image).unsqueeze(0) # Add a batch dimension\n", + " return image" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "fbf325f2-21f3-4aec-9c6c-052f810d4755", + "metadata": {}, + "outputs": [], + "source": [ + "def inference(model, image_path):\n", + " model.eval()\n", + "\n", + " input_image = preprocess_image(image_path)\n", + " with torch.no_grad():\n", + " output = model(input_image)\n", + "\n", + " _, predicted_class = output.max(1)\n", + " print(f\"Predicted class index: {predicted_class.item()}\")\n", + "\n", + " predicted_label = IMAGENET_CATEGORIES[predicted_class.item()]\n", + " print(f\"Predicted class label: {predicted_label}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "ad20dd28-473f-4715-81f1-a2591a4ab5c9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Predicted class index: 281\n", + "Predicted class label: tabby\n" + ] + } + ], + "source": [ + "inference(mobilenet_v3_s, \"assets/tabby_cat.jpg\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "30f88688-bc3c-4d33-af6e-b5ba2d6361a9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "local_m3l = m3l()\n", + "local_m3l.load_state_dict(mobilenet_v3_l.state_dict())" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b9a5b225-2e54-4ac5-a4f2-dbb4fe8e6a4c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Predicted class index: 281\n", + "Predicted class label: tabby\n" + ] + } + ], + "source": [ + "inference(local_m3l, \"assets/tabby_cat.jpg\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8c54a37d-1d4c-4fb9-9bb2-91589ca044a0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5483032" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x(local_m3l)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "9fe34419-611c-4b96-875f-48b55f43c1e3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "local_m3s = m3s()\n", + "local_m3s.load_state_dict(mobilenet_v3_s.state_dict())" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "d3e7e6b9-b350-4b69-b5de-2494640e187d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Predicted class index: 281\n", + "Predicted class label: tabby\n" + ] + } + ], + "source": [ + "inference(local_m3s, \"assets/tabby_cat.jpg\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "7f1834a8-703f-4e90-96da-c47bbb2fdcab", + "metadata": {}, + "outputs": [], + "source": [ + "torch.save(local_m3s.half().state_dict(), \"./weights/mobilenet_v3_small_half.pt\")\n", + "torch.save(local_m3l.half().state_dict(), \"./weights/mobilenet_v3_large_half.pt\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7a562db-08c6-4310-82d5-32c3d5ae4c42", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/main.py b/main.py index 62420ca..d835eaa 100644 --- a/main.py +++ b/main.py @@ -210,11 +210,11 @@ def main(args): model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) parameters = utils.add_weight_decay(model, weight_decay=args.weight_decay) - criterion = nn.CrossEntropyLoss() - optimizer = nn.RMSprop(parameters, lr=args.lr, alpha=0.9, eps=1e-3, weight_decay=0, momentum=args.momentum) - scheduler = nn.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma, warmup_epochs=args.warmup_epochs, + criterion = utils.CrossEntropyLoss() + optimizer = utils.RMSprop(parameters, lr=args.lr, alpha=0.9, eps=1e-3, weight_decay=0, momentum=args.momentum) + scheduler = utils.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma, warmup_epochs=args.warmup_epochs, warmup_lr_init=args.warmup_lr_init) - model_ema = nn.EMA(model, decay=0.9999) + model_ema = utils.EMA(model, decay=0.9999) if args.distributed: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) diff --git a/nets/__init__.py b/nets/__init__.py index 9d20d0c..2545d60 100644 --- a/nets/__init__.py +++ b/nets/__init__.py @@ -1,4 +1,2 @@ -from .nn import MobileNetV3S, MobileNetV3L -from .nn import CrossEntropyLoss, PolyLoss -from .nn import RMSprop, StepLR -from .nn import EMA +from nets.nn import mobilenet_v3_large, mobilenet_v3_small + diff --git a/nets/nn.py b/nets/nn.py index 9e6d977..8809862 100644 --- a/nets/nn.py +++ b/nets/nn.py @@ -1,15 +1,45 @@ -from typing import Callable, List, Optional +from typing import Any, Callable, List, Optional import torch from torch import nn, Tensor __all__ = [ + "MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small", ] -def _make_divisible(v: float, divisor: int) -> int: +class SqueezeExcitation(torch.nn.Module): + """This Squeeze-and-Excitation block + Args: + in_channels (int): Number of channels in the input image + squeeze_channels (int): Number of squeeze channels + """ + + def __init__( + self, + in_channels: int, + squeeze_channels: int, + ) -> None: + super().__init__() + self.avg_pool = torch.nn.AdaptiveAvgPool2d(1) + self.fc1 = torch.nn.Conv2d(in_channels, squeeze_channels, 1) + self.fc2 = torch.nn.Conv2d(squeeze_channels, in_channels, 1) + self.relu = nn.ReLU() # `delta` activation + self.hard = nn.Hardsigmoid() # `sigma` (aka scale) activation + + def forward(self, x: Tensor) -> Tensor: + scale = self.avg_pool(x) + scale = self.fc1(scale) + scale = self.relu(scale) + scale = self.fc2(scale) + scale = self.hard(scale) + return scale * x + + +def _make_divisible(v: float, divisor: int = 8) -> int: + """This function ensures that all layers have a channel number divisible by 8""" new_v = max(divisor, int(v + divisor / 2) // divisor * divisor) # Make sure that round down does not go down by more than 10%. if new_v < 0.9 * v: @@ -17,20 +47,23 @@ def _make_divisible(v: float, divisor: int) -> int: return new_v -class Conv2dNormActivation(nn.Module): +class Conv2dNormActivation(torch.nn.Sequential): + """Convolutional block, consists of nn.Conv2d, nn.BatchNorm2d, nn.ReLU""" + def __init__( self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, - padding: Optional[int] = None, + padding: Optional = None, groups: int = 1, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, dilation: int = 1, - activation_layer: Optional[ - Callable[..., torch.nn.Module]] = torch.nn.ReLU + inplace: Optional[bool] = True, + bias: bool = False, ) -> None: - super().__init__() + if padding is None: padding = (kernel_size - 1) // 2 * dilation @@ -43,49 +76,15 @@ def __init__( padding=padding, dilation=dilation, groups=groups, - bias=False, + bias=bias, ), nn.BatchNorm2d(num_features=out_channels, eps=0.001, momentum=0.01) ] if activation_layer is not None: - layers.append(activation_layer(inplace=True)) - self.out_channels = out_channels - self.block = nn.Sequential(*layers) - - def forward(self, x: Tensor) -> Tensor: - return self.block(x) - - -class SqueezeExcitation(torch.nn.Module): - """Squeeze-and-Excitation block - Args: - input_channels (int): Number of channels in the input image - squeeze_channels (int): Number of squeeze channels - """ - - def __init__( - self, - input_channels: int, - squeeze_channels: int, - ) -> None: - super().__init__() - self.avg_pool = torch.nn.AdaptiveAvgPool2d(1) - self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1) - self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1) - self.relu = nn.ReLU() - self.hard = nn.Hardsigmoid() - - def _scale(self, x: Tensor) -> Tensor: - scale = self.avg_pool(x) - scale = self.fc1(scale) - scale = self.relu(scale) - scale = self.fc2(scale) - return self.hard(scale) - - def forward(self, x: Tensor) -> Tensor: - scale = self._scale(x) - return scale * x + params = {} if inplace is None else {"inplace": inplace} + layers.append(activation_layer(**params)) + super().__init__(*layers) class InvertedResidual(nn.Module): @@ -93,54 +92,61 @@ class InvertedResidual(nn.Module): def __init__( self, - input_channels: int, - kernel_size: int, - expanded_channels: int, + in_channels: int, + kernel: int, + exp_channels: int, out_channels: int, use_se: bool, activation: str, stride: int, - dilation: int - ): + dilation: int, + ) -> None: super().__init__() - if not (1 <= stride <= 2): - raise ValueError("illegal stride value") + self._shortcut = stride == 1 and in_channels == out_channels + + in_channels = _make_divisible(in_channels) + exp_channels = _make_divisible(exp_channels) + out_channels = _make_divisible(out_channels) - self._shortcut = stride == 1 and input_channels == out_channels - activation_layer = nn.Hardswish if activation == "HS" else nn.ReLU layers: List[nn.Module] = [] + activation_layer = nn.Hardswish if activation == "HS" else nn.ReLU + # expand - if expanded_channels != input_channels: + if exp_channels != in_channels: layers.append( Conv2dNormActivation( - input_channels, - expanded_channels, + in_channels=in_channels, + out_channels=exp_channels, kernel_size=1, activation_layer=activation_layer, ) ) - # depth-wise + # depth-wise convolution layers.append( Conv2dNormActivation( - in_channels=expanded_channels, - out_channels=expanded_channels, - kernel_size=kernel_size, - stride=stride, + in_channels=exp_channels, + out_channels=exp_channels, + kernel_size=kernel, + stride=1 if dilation > 1 else stride, dilation=dilation, - groups=expanded_channels, + groups=exp_channels, activation_layer=activation_layer, ) ) - # squeeze excitation if use_se: - squeeze_channels = _make_divisible(expanded_channels // 4, 8) + squeeze_channels = _make_divisible(exp_channels // 4, 8) layers.append( - SqueezeExcitation(expanded_channels, squeeze_channels)) + SqueezeExcitation( + in_channels=exp_channels, + squeeze_channels=squeeze_channels + ) + ) + # project layer layers.append( Conv2dNormActivation( - in_channels=expanded_channels, + in_channels=exp_channels, out_channels=out_channels, kernel_size=1, activation_layer=None @@ -150,28 +156,35 @@ def __init__( self.block = nn.Sequential(*layers) def forward(self, x: Tensor) -> Tensor: + result = self.block(x) if self._shortcut: - return x + self.block(x) - return self.block(x) + result += x + return result class MobileNetV3(nn.Module): - """MobileNet V3 main class""" - def __init__( self, - cnf: List[List[int | bool | str]], + inverted_residual_setting: List[List[int | str | bool]], last_channel: int, num_classes: int = 1000, dropout: float = 0.2, ) -> None: - + """MobileNet V3 main class + Args: + inverted_residual_setting: network structure + last_channel: number of channels on the penultimate layer + num_classes: number of classes + dropout: dropout probability + """ super().__init__() + # building first layer + first_conv_out_channels = inverted_residual_setting[0][0] layers: List[nn.Module] = [ Conv2dNormActivation( in_channels=3, - out_channels=16, + out_channels=first_conv_out_channels, kernel_size=3, stride=2, activation_layer=nn.Hardswish, @@ -179,16 +192,16 @@ def __init__( ] # building inverted residual blocks - for args in cnf: - layers.append(InvertedResidual(*args)) + for params in inverted_residual_setting: + layers.append(InvertedResidual(*params)) - last_in_channels = cnf[-1][0] - last_out_channels = cnf[-1][2] # building last several layers + last_conv_in_channels = inverted_residual_setting[-1][3] + last_conv_out_channels = 6 * last_conv_in_channels layers.append( Conv2dNormActivation( - in_channels=last_in_channels, - out_channels=last_out_channels, + in_channels=last_conv_in_channels, + out_channels=last_conv_out_channels, kernel_size=1, activation_layer=nn.Hardswish, ) @@ -196,16 +209,13 @@ def __init__( self.features = nn.Sequential(*layers) self.avg_pool = nn.AdaptiveAvgPool2d(1) - - # classifier self.classifier = nn.Sequential( - nn.Linear(in_features=last_out_channels, out_features=last_channel), + nn.Linear(last_conv_out_channels, last_channel), nn.Hardswish(inplace=True), nn.Dropout(p=dropout, inplace=True), - nn.Linear(in_features=last_channel, out_features=num_classes), + nn.Linear(last_channel, num_classes), ) - # initialize weights for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out") @@ -220,82 +230,58 @@ def __init__( def forward(self, x: Tensor) -> Tensor: x = self.features(x) - x = self.avg_pool(x) x = torch.flatten(x, 1) - x = self.classifier(x) return x -def _mobilenet_v3_conf(arch: str): +def _mobilenet_v3(arch: str, **kwargs: Any, ) -> MobileNetV3: if arch == "mobilenet_v3_large": - inverted_residual_settings = [ + inverted_residual_setting = [ [16, 3, 16, 16, False, "RE", 1, 1], - [16, 3, 64, 24, False, "RE", 2, 1], - # C1 - + [16, 3, 64, 24, False, "RE", 2, 1], # C1 [24, 3, 72, 24, False, "RE", 1, 1], - [24, 5, 72, 40, True, "RE", 2, 1], - # C2 - + [24, 5, 72, 40, True, "RE", 2, 1], # C2 [40, 5, 120, 40, True, "RE", 1, 1], [40, 5, 120, 40, True, "RE", 1, 1], - [40, 3, 240, 80, False, "HS", 2, 1], - # C3 - + [40, 3, 240, 80, False, "HS", 2, 1], # C3 [80, 3, 200, 80, False, "HS", 1, 1], [80, 3, 184, 80, False, "HS", 1, 1], [80, 3, 184, 80, False, "HS", 1, 1], [80, 3, 480, 112, True, "HS", 1, 1], [112, 3, 672, 112, True, "HS", 1, 1], - [112, 5, 672, 160, True, "HS", 2, 1], - # C4 - + [112, 5, 672, 160, True, "HS", 2, 1], # C4 + [160, 5, 960, 160, True, "HS", 1, 1], [160, 5, 960, 160, True, "HS", 1, 1], - [160, 5, 960, 160, True, "HS", 1, 1] ] - last_channel = 1280 + last_channel = 1280 # C5 elif arch == "mobilenet_v3_small": - inverted_residual_settings = [ - [16, 3, 16, 16, True, "RE", 2, 1], - # C1 - [16, 3, 72, 24, False, "RE", 2, 1], - # C2 + inverted_residual_setting = [ + [16, 3, 16, 16, True, "RE", 2, 1], # C1 + [16, 3, 72, 24, False, "RE", 2, 1], # C2 [24, 3, 88, 24, False, "RE", 1, 1], - [24, 5, 96, 40, True, "HS", 2, 1], - # C3 + [24, 5, 96, 40, True, "HS", 2, 1], # C3 [40, 5, 240, 40, True, "HS", 1, 1], [40, 5, 240, 40, True, "HS", 1, 1], [40, 5, 120, 48, True, "HS", 1, 1], [48, 5, 144, 48, True, "HS", 1, 1], - [48, 5, 288, 96, True, "HS", 2, 1], - # C4 + [48, 5, 288, 96, True, "HS", 2, 1], # C4 [96, 5, 576, 96, True, "HS", 1, 1], [96, 5, 576, 96, True, "HS", 1, 1], ] - last_channel = 1024 + last_channel = 1024 # C5 else: raise ValueError(f"Unsupported model type {arch}") - return inverted_residual_settings, last_channel - - -def mobilenet_v3_large(**kwargs) -> MobileNetV3: - inverted_residual_settings, lc = _mobilenet_v3_conf("mobilenet_v3_large") - return MobileNetV3(inverted_residual_settings, lc, **kwargs) + model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) + return model -def mobilenet_v3_small(**kwargs) -> MobileNetV3: - inverted_residual_settings, lc = _mobilenet_v3_conf("mobilenet_v3_small") - return MobileNetV3(inverted_residual_settings, lc, **kwargs) +def mobilenet_v3_large(**kwargs: Any) -> MobileNetV3: + return _mobilenet_v3(arch="mobilenet_v3_large", **kwargs) -if __name__ == '__main__': - v3_large = mobilenet_v3_large() - v3_small = mobilenet_v3_small() - print("Number of parameters of MobileNet V3 Large: {}".format( - sum(p.numel() for p in v3_large.parameters() if p.requires_grad))) - print("Number of parameters of MobileNet V3 Small: {}".format( - sum(p.numel() for p in v3_small.parameters() if p.requires_grad))) +def mobilenet_v3_small(**kwargs: Any) -> MobileNetV3: + return _mobilenet_v3(arch="mobilenet_v3_small", **kwargs) diff --git a/requirements.txt b/requirements.txt index c1d4e22..d4b67a2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch==2.0.1 -torchvision==0.15.2 -Pillow==9.5.0 -numpy==1.24.3 \ No newline at end of file +torch>=2.0.1 +torchvision>=0.15.2 +Pillow>=9.5.0 +numpy>=1.24.3 diff --git a/utils/__init__.py b/utils/__init__.py index 97ab531..e0288b8 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,3 +1,12 @@ -from .dataset import ImageFolder -from .metrics import AverageMeter, accuracy -from .general import add_weight_decay, reduce_tensor, setup_for_distributed, init_distributed_mode +from utils.dataset import ImageFolder +from utils.metrics import AverageMeter, accuracy +from utils.general import ( + add_weight_decay, + reduce_tensor, + setup_for_distributed, + init_distributed_mode, + EMA, + StepLR, + RMSprop, + CrossEntropyLoss +) diff --git a/utils/dataset.py b/utils/dataset.py index 8646f31..3c8b275 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -6,7 +6,6 @@ class ImageFolder(data.Dataset): def __init__(self, root, transform=None): - self.transform = transform self.classes, self.class_to_idx = self.find_classes(root) self.samples = self.make_dataset(root, self.class_to_idx) @@ -21,7 +20,6 @@ def __getitem__(self, index): return image, label def __len__(self): - return len(self.samples) @staticmethod @@ -29,14 +27,12 @@ def load_image(path): with open(path, 'rb') as f: image = Image.open(f) image = image.convert('RGB') - return image @staticmethod def find_classes(directory): class_names = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir()) class_to_idx = {cls_name: idx for idx, cls_name in enumerate(class_names)} - return class_names, class_to_idx @staticmethod diff --git a/utils/general.py b/utils/general.py index 4aae5b2..fecc293 100644 --- a/utils/general.py +++ b/utils/general.py @@ -1,20 +1,12 @@ import os -import torch -import torch.distributed as distributed - +from copy import deepcopy -def _make_divisible(width, divisor=8): - new_width = max(divisor, int(width + divisor / 2) // divisor * divisor) - if new_width < 0.9 * width: - new_width += divisor - return new_width +from torch import nn +from torch.nn import functional as F - -def round_filters(filters: int, width_mult: float) -> int: - if width_mult == 1.0: - return filters - return int(_make_divisible(filters * width_mult)) +import torch +import torch.distributed as distributed def reduce_tensor(tensor, n): @@ -86,3 +78,171 @@ def init_distributed_mode(args): print("Warning: DP is On. Please use DDP(Distributed Data Parallel") setup_for_distributed(args.local_rank == 0) + + +class EMA(torch.nn.Module): + """Exponential Moving Average""" + + def __init__(self, model: nn.Module, decay: float = 0.9999) -> None: + super().__init__() + self.model = deepcopy(model) + self.model.eval() + self.decay = decay + + def _update(self, model: nn.Module, update_fn) -> None: + with torch.no_grad(): + ema_v = self.model.state_dict().values() + model_v = model.state_dict().values() + for e, m in zip(ema_v, model_v): + e.copy_(update_fn(e, m)) + + def update_parameters(self, model: nn.Module) -> None: + self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + + +class CrossEntropyLoss: + """Cross Entropy Loss""" + + def __init__(self, reduction='mean', label_smoothing=0.0) -> None: + super().__init__() + + self.label_smoothing = label_smoothing + self.reduction = reduction + + def __call__(self, prediction, target): + return F.cross_entropy( + prediction, + target, + reduction=self.reduction, + label_smoothing=self.label_smoothing + ) + + +class RMSprop(torch.optim.Optimizer): + def __init__( + self, + params, + lr=1e-2, + alpha=0.9, + eps=1e-7, + weight_decay=0, + momentum=0., + centered=False, + decoupled_decay=False, + lr_in_momentum=True + ): + + defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay, + decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum) + super(RMSprop, self).__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('momentum', 0) + group.setdefault('centered', False) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('RMSprop does not support sparse gradients') + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero + if group['momentum'] > 0: + state['momentum_buffer'] = torch.zeros_like(p.data) + if group['centered']: + state['grad_avg'] = torch.zeros_like(p.data) + + square_avg = state['square_avg'] + one_minus_alpha = 1. - group['alpha'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + if 'decoupled_decay' in group and group['decoupled_decay']: + p.data.add_(p.data, alpha=-group['weight_decay']) + else: + grad = grad.add(p.data, alpha=group['weight_decay']) + + square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha) + + if group['centered']: + grad_avg = state['grad_avg'] + grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha) + avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() + else: + avg = square_avg.add(group['eps']).sqrt_() + + if group['momentum'] > 0: + buf = state['momentum_buffer'] + if 'lr_in_momentum' in group and group['lr_in_momentum']: + buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr']) + p.data.add_(-buf) + else: + buf.mul_(group['momentum']).addcdiv_(grad, avg) + p.data.add_(buf, alpha=-group['lr']) + else: + p.data.addcdiv_(grad, avg, value=-group['lr']) + + return loss + + +class StepLR: + + def __init__( + self, + optimizer, + step_size, + gamma=1., + warmup_epochs=0, + warmup_lr_init=0 + ): + + self.optimizer = optimizer + self.step_size = step_size + self.gamma = gamma + self.warmup_epochs = warmup_epochs + self.warmup_lr_init = warmup_lr_init + + for group in self.optimizer.param_groups: + group.setdefault('initial_lr', group['lr']) + + self.base_lr_values = [group['initial_lr'] for group in self.optimizer.param_groups] + self.update_groups(self.base_lr_values) + + if self.warmup_epochs: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_epochs for v in self.base_lr_values] + self.update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_lr_values] + + def state_dict(self): + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) + + def step(self, epoch): + if epoch < self.warmup_epochs: + values = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] + else: + values = [base_lr * (self.gamma ** (epoch // self.step_size)) for base_lr in self.base_lr_values] + if values is not None: + self.update_groups(values) + + def update_groups(self, values): + if not isinstance(values, (list, tuple)): + values = [values] * len(self.optimizer.param_groups) + for param_group, value in zip(self.optimizer.param_groups, values): + param_group['lr'] = value diff --git a/weights/last.ckpt b/weights/last.ckpt deleted file mode 100644 index 004f200..0000000 Binary files a/weights/last.ckpt and /dev/null differ diff --git a/weights/last.pth b/weights/mobilenet_v3_large_half.pt similarity index 57% rename from weights/last.pth rename to weights/mobilenet_v3_large_half.pt index 509d276..31958ba 100644 Binary files a/weights/last.pth and b/weights/mobilenet_v3_large_half.pt differ diff --git a/weights/mobilenet_v3_small_half.pt b/weights/mobilenet_v3_small_half.pt new file mode 100644 index 0000000..7eacd8c Binary files /dev/null and b/weights/mobilenet_v3_small_half.pt differ