Skip to content

Commit e7eb81a

Browse files
Revert (#233) which don't generate multi times input for weight sharing case (#290)
1 parent dfc501a commit e7eb81a

File tree

1 file changed

+9
-11
lines changed
  • models/image_recognition/pytorch/common

1 file changed

+9
-11
lines changed

models/image_recognition/pytorch/common/main.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -494,23 +494,21 @@ def run_weights_sharing_model(m, tid, args):
494494
start_time = time.time()
495495
num_images = 0
496496
time_consume = 0
497+
x = torch.randn(args.batch_size, 3, 224, 224)
498+
if args.bf16:
499+
x = x.to(torch.bfloat16)
500+
if args.ipex:
501+
x = x.contiguous(memory_format=torch.channels_last)
502+
497503
with torch.no_grad():
498504
while num_images < steps:
499-
if args.bf16:
500-
for i in range(24):
501-
x = torch.randn(args.batch_size, 3, 224, 224).to(torch.bfloat16)
502-
else:
503-
for i in range(24):
504-
x = torch.randn(args.batch_size, 3, 224, 224)
505-
if args.ipex:
506-
x = x.contiguous(memory_format=torch.channels_last)
507505
start_time = time.time()
508506
if not args.jit and args.bf16:
509507
with torch.cpu.amp.autocast():
510508
y = m(x)
511509
else:
512510
y = m(x)
513-
511+
514512
end_time = time.time()
515513
if num_images > args.warmup_iterations:
516514
time_consume += end_time - start_time
@@ -585,7 +583,7 @@ def validate(val_loader, model, criterion, args):
585583
output = model(images)
586584
else:
587585
output = model(images)
588-
586+
589587
if i >= args.warmup_iterations:
590588
batch_time.update(time.time() - end)
591589

@@ -613,7 +611,7 @@ def validate(val_loader, model, criterion, args):
613611
output = model(images)
614612
else:
615613
output = model(images)
616-
614+
617615
# compute output
618616
batch_time.update(time.time() - end)
619617
#print(output)

0 commit comments

Comments
 (0)