
Average firing rate is always very small when measuring firing rate with Monitor #600

Open
mengqiShen opened this issue Jan 16, 2025 · 0 comments
@mengqiShen
First of all, thank you for your excellent work.

When running spikformer, I use monitor.get_avg_firing_rate(all=True) to record the average spike firing rate, but I find that no matter how many epochs I run, the returned firing rate is always very small (around 0.04), even though the network's training and test results are correct. My code is below. Am I using the monitor correctly? (Apart from the monitor-related code, which I added myself, everything comes from the open-source spikformer code.)

```python
def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.eval()
    mon = Monitor(model, device='cuda:0', backend='torch')
    mon.enable()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
                input = input.cuda()
                target = target.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(input)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            functional.reset_net(model)

            acc1, acc5 = accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                _logger.info(
                    '{0}: [{1:>4d}/{2}]  '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        log_name, batch_idx, last_idx, batch_time=batch_time_m,
                        loss=losses_m, top1=top1_m, top5=top5_m))
        mon_fire_rate = mon.get_avg_firing_rate(all=True)
        mon_nonfire_rate = mon.get_nonfire_ratio(all=True)
        print('-----firing rate by monitor------', mon_fire_rate)
        print('-----non-firing ratio by monitor------', mon_nonfire_rate)
        mon.reset()
    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])

    return metrics
```
The output is:

[screenshots of the printed firing-rate and non-firing-ratio values]
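Independently of spikingjelly's Monitor, the averaged firing rate can be cross-checked with a plain PyTorch forward hook, counting the fraction of nonzero entries in a spiking layer's output. The sketch below is only an illustration, not spikformer code: `SpikeCounter` and `ToySpike` are hypothetical names, and the thresholding layer is a stand-in for a real LIF neuron.

```python
import torch
import torch.nn as nn

class SpikeCounter:
    """Accumulate spike counts across batches via a forward hook."""
    def __init__(self):
        self.spikes = 0.0
        self.numel = 0

    def hook(self, module, inp, out):
        # Spiking layers emit 0/1 tensors, so sum() counts spikes.
        self.spikes += out.detach().sum().item()
        self.numel += out.numel()

    @property
    def avg_firing_rate(self):
        # Fraction of neuron-timestep entries that fired.
        return self.spikes / max(self.numel, 1)

class ToySpike(nn.Module):
    """Toy stand-in for a spiking neuron: fires where input exceeds 1.0."""
    def forward(self, x):
        return (x > 1.0).float()

model = nn.Sequential(nn.Linear(8, 16), ToySpike())
counter = SpikeCounter()
handle = model[1].register_forward_hook(counter.hook)

with torch.no_grad():
    for _ in range(4):  # a few "batches" of random input
        model(torch.randn(32, 8))

handle.remove()
print(f"avg firing rate: {counter.avg_firing_rate:.4f}")
```

If a hook like this reports a rate close to the Monitor's 0.04, the small value is real network sparsity rather than a Monitor usage error; sparse firing is common in trained SNNs.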
