Skip to content

Commit

Permalink
fix accuracy calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
enricobu96 committed Apr 13, 2024
1 parent f577cd5 commit 79fb198
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions mACHINE-LEARNINGS/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,11 @@ def execute(train_set_size, batch_size, lr, epochs, is_verbose, weight_decay):

if IS_VERBOSE:
print('Training: Epoch %d - Batch %d/%d: Loss: %.4f' %
(epoch+1, batch_num, len(dataloaders["train"]), train_loss / (batch_num + 1)))
(epoch+1, batch_num+1, len(dataloaders["train"]), train_loss / (batch_num + 1)))

n_batch = batch_num

print('EPOCH', epoch+1, 'ACCURACY:', (acc.item() / (train_size/BATCH_SIZE)))
print('EPOCH', epoch+1, 'ACCURACY:', (acc.item() / (n_batch+1)))

"""
TEST PHASE
Expand All @@ -103,7 +105,9 @@ def execute(train_set_size, batch_size, lr, epochs, is_verbose, weight_decay):
print('Evaluating: Batch %d/%d: Loss: %.4f' %
(batch_num, len(dataloaders["test"]), test_loss / (batch_num + 1)))

print('TEST ACCURACY:', (acc.item() / (test_size/BATCH_SIZE)))
n_batch = batch_num

print('TEST ACCURACY:', (acc.item() / (n_batch+1)))

print('Saving model...')
torch.save(model.state_dict(), 'models/trained/model.pt')
Expand Down

0 comments on commit 79fb198

Please sign in to comment.