diff --git a/torchquantum/functional/gate_wrapper.py b/torchquantum/functional/gate_wrapper.py index f1383f2f..c446df6b 100644 --- a/torchquantum/functional/gate_wrapper.py +++ b/torchquantum/functional/gate_wrapper.py @@ -371,7 +371,7 @@ def gate_wrapper( params = params.unsqueeze(0) if params.dim() == 2 else params else: if params.dim() == 1: - params = params.unsqueeze(-1) + params = params.unsqueeze(0).unsqueeze(-1) elif params.dim() == 0: params = params.unsqueeze(-1).unsqueeze(-1) # params = params.unsqueeze(-1) if params.dim() == 1 else params