From db96595cc39dbd49a477c0dfbabb05ccb87e1250 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 19 Jan 2023 12:59:24 -0800 Subject: [PATCH] should be using 1 - probs for score --- phenaki_pytorch/phenaki_pytorch.py | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/phenaki_pytorch/phenaki_pytorch.py b/phenaki_pytorch/phenaki_pytorch.py index 5d87cbb..71c48ac 100644 --- a/phenaki_pytorch/phenaki_pytorch.py +++ b/phenaki_pytorch/phenaki_pytorch.py @@ -663,7 +663,8 @@ def sample( noise = noise_K * (uniform(scores.shape, device) - 0.5) * noise_multiplier scores = scores + noise else: - scores = logits.gather(2, rearrange(pred_video_ids, '... -> ... 1')) + probs = logits.softmax(dim = -1) + scores = probs.gather(2, rearrange(pred_video_ids, '... -> ... 1')) scores = 1 - rearrange(scores, '... 1 -> ...') scores = torch.where(mask, scores, -1e4) diff --git a/setup.py b/setup.py index 2bbe833..15b5fdc 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'phenaki-pytorch', packages = find_packages(exclude=[]), - version = '0.0.70', + version = '0.0.71', license='MIT', description = 'Phenaki - Pytorch', author = 'Phil Wang',