Skip to content

Commit f6a074b

Browse files
Merge pull request #252 from nikhilkhatri/main
Allow fixed parameters in GeneralEncoder
2 parents 79ce996 + 69f0699 commit f6a074b

File tree

2 files changed

+98
-1
lines changed

2 files changed

+98
-1
lines changed

test/encoding/test_encodings.py

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""
2+
MIT License
3+
4+
Copyright (c) 2020-present TorchQuantum Authors
5+
6+
Permission is hereby granted, free of charge, to any person obtaining a copy
7+
of this software and associated documentation files (the "Software"), to deal
8+
in the Software without restriction, including without limitation the rights
9+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
copies of the Software, and to permit persons to whom the Software is
11+
furnished to do so, subject to the following conditions:
12+
13+
The above copyright notice and this permission notice shall be included in all
14+
copies or substantial portions of the Software.
15+
16+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
SOFTWARE.
23+
"""
24+
25+
# test the controlled unitary function
26+
27+
28+
import torchquantum as tq
29+
import torch
30+
from test.utils import check_all_close
31+
32+
33+
def test_GeneralEncoder():
34+
35+
parameterised_funclist = [
36+
{"input_idx": [0], "func": "crx", "wires": [1, 0]},
37+
{"input_idx": [1, 2, 3], "func": "u3", "wires": [1]},
38+
{"input_idx": [4], "func": "ry", "wires": [0]},
39+
{"input_idx": [5], "func": "ry", "wires": [1]},
40+
]
41+
42+
semiparam_funclist = [
43+
{"params": [0.2], "func": "crx", "wires": [1, 0]},
44+
{"params": [0.3, 0.4, 0.5], "func": "u3", "wires": [1]},
45+
{"input_idx": [0], "func": "ry", "wires": [0]},
46+
{"input_idx": [1], "func": "ry", "wires": [1]},
47+
]
48+
49+
expected_states = torch.complex(
50+
torch.Tensor(
51+
[[0.8423, 0.4474, 0.2605, 0.1384], [0.7649, 0.5103, 0.3234, 0.2157]]
52+
),
53+
torch.Tensor(
54+
[[-0.0191, 0.0522, -0.0059, 0.0162], [-0.0233, 0.0483, -0.0099, 0.0204]]
55+
),
56+
)
57+
58+
parameterised_enc = tq.GeneralEncoder(parameterised_funclist)
59+
semiparam_enc = tq.GeneralEncoder(semiparam_funclist)
60+
61+
param_vec = torch.Tensor(
62+
[[0.2, 0.3, 0.4, 0.5, 0.6, 0.7], [0.2, 0.3, 0.4, 0.5, 0.8, 0.9]]
63+
)
64+
semiparam_vec = torch.Tensor([[0.6, 0.7], [0.8, 0.9]])
65+
66+
qd = tq.QuantumDevice(n_wires=2)
67+
68+
qd.reset_states(bsz=2)
69+
parameterised_enc(qd, param_vec)
70+
state1 = qd.get_states_1d()
71+
72+
qd.reset_states(bsz=2)
73+
semiparam_enc(qd, semiparam_vec)
74+
state2 = qd.get_states_1d()
75+
76+
check_all_close(state1, state2)
77+
check_all_close(state1, expected_states)
78+
79+
80+
if __name__ == "__main__":
81+
test_GeneralEncoder()

torchquantum/encoding/encodings.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,18 @@ class GeneralEncoder(Encoder, metaclass=ABCMeta):
8181
{'input_idx': [12, 13, 14], 'func': 'u3', 'wires': [3]},
8282
{'input_idx': [15], 'func': 'u1', 'wires': [3]},
8383
]
84+
85+
Example 3:
86+
[
87+
{'params': [0.25], 'func': 'rx', 'wires': [0]},
88+
{'params': [0.25], 'func': 'rx', 'wires': [1]},
89+
{'params': [0.25], 'func': 'rx', 'wires': [2]},
90+
{'params': [0.25], 'func': 'rx', 'wires': [3]},
91+
{'input_idx': [0], 'func': 'ry', 'wires': [0]},
92+
{'input_idx': [1], 'func': 'ry', 'wires': [1]},
93+
{'input_idx': [2], 'func': 'ry', 'wires': [2]},
94+
{'input_idx': [3], 'func': 'ry', 'wires': [3]}
95+
]
8496
"""
8597

8698
def __init__(self, func_list):
@@ -91,7 +103,11 @@ def __init__(self, func_list):
91103
def forward(self, qdev: tq.QuantumDevice, x):
92104
for info in self.func_list:
93105
if tq.op_name_dict[info["func"]].num_params > 0:
94-
params = x[:, info["input_idx"]]
106+
# If params are provided in encoder, use those,
107+
# else use params from x
108+
params = (torch.Tensor(info["params"]).repeat(x.shape[0], 1)
109+
if info.get("params")
110+
else x[:, info["input_idx"]])
95111
else:
96112
params = None
97113
func_name_dict[info["func"]](

0 commit comments

Comments
 (0)