Skip to content

Commit eda17c1

Browse files
Merge pull request #272 from AbdullahKazi500/AbdullahKazi500-patch-3
2 parents 91c80de + 82cf184 commit eda17c1

File tree

1 file changed

+188
-0
lines changed

1 file changed

+188
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
import torch
2+
import torch.optim as optim
3+
import torchquantum as tq
4+
import torchquantum.functional as tqf
5+
import numpy as np
6+
7+
class QuantumPulseDemo(tq.QuantumModule):
8+
"""
9+
Quantum pulse demonstration module.
10+
11+
This module defines a parameterized quantum pulse and applies it to a quantum device.
12+
"""
13+
14+
def __init__(self):
15+
"""
16+
Initializes the QuantumPulseDemo module.
17+
18+
Args:
19+
n_wires (int): Number of quantum wires (qubits).
20+
n_steps (int): Number of steps for the quantum pulse.
21+
hamil (list): Hamiltonian for the quantum pulse.
22+
"""
23+
super().__init__()
24+
self.n_wires = 2
25+
26+
# Quantum encoder
27+
self.encoder = tq.GeneralEncoder([
28+
{'input_idx': [0], 'func': 'rx', 'wires': [0]},
29+
{'input_idx': [1], 'func': 'rx', 'wires': [1]}
30+
])
31+
32+
# Define parameterized quantum pulse
33+
self.pulse = tq.pulse.QuantumPulseDirect(n_steps=4, hamil=[[0, 1], [1, 0]])
34+
35+
def forward(self, x):
36+
"""
37+
Forward pass through the QuantumPulseDemo module.
38+
39+
Args:
40+
x (torch.Tensor): Input tensor.
41+
42+
Returns:
43+
torch.Tensor: Measurement result from the quantum device.
44+
"""
45+
qdev = tq.QuantumDevice(n_wires=self.n_wires, bsz=x.shape[0], device=x.device)
46+
self.encoder(qdev, x)
47+
self.apply_pulse(qdev)
48+
return tq.measure(qdev)
49+
50+
def apply_pulse(self, qdev):
51+
"""
52+
Applies the parameterized quantum pulse to the quantum device.
53+
54+
Args:
55+
qdev (tq.QuantumDevice): Quantum device to apply the pulse to.
56+
"""
57+
pulse_params = self.pulse.pulse_shape.detach().cpu().numpy()
58+
# Apply pulse to the quantum device (adjust based on actual pulse application logic)
59+
for params in pulse_params:
60+
tqf.rx(qdev, wires=0, params=params)
61+
tqf.rx(qdev, wires=1, params=params)
62+
63+
class OM_EOM_Simulation:
64+
"""
65+
Optical modulation with electro-optic modulator (EOM) simulation.
66+
67+
This class simulates a sequence of optical pulses with or without EOM modulation.
68+
"""
69+
70+
def __init__(self, pulse_duration, modulation_bandwidth=None, eom_mode=False):
71+
"""
72+
Initializes the OM_EOM_Simulation.
73+
74+
Args:
75+
pulse_duration (float): Duration of each pulse.
76+
modulation_bandwidth (float, optional): Bandwidth of modulation. Defaults to None.
77+
eom_mode (bool, optional): Whether to simulate EOM mode. Defaults to False.
78+
"""
79+
self.pulse_duration = pulse_duration
80+
self.modulation_bandwidth = modulation_bandwidth
81+
self.eom_mode = eom_mode
82+
83+
def simulate_sequence(self):
84+
"""
85+
Simulates a sequence of optical pulses with specified parameters.
86+
87+
Returns:
88+
list: Sequence of pulses and delays.
89+
"""
90+
# Initialize the sequence
91+
sequence = []
92+
93+
# Add pulses and delays to the sequence
94+
if self.modulation_bandwidth:
95+
# Apply modulation bandwidth effect
96+
sequence.append(('Delay', 0))
97+
sequence.append(('Pulse', 'NoisyChannel'))
98+
for _ in range(3):
99+
# Apply pulses with specified duration
100+
sequence.append(('Delay', self.pulse_duration))
101+
if self.eom_mode:
102+
# Apply EOM mode operation
103+
sequence.append(('Pulse', 'EOM'))
104+
else:
105+
# Apply regular pulse
106+
sequence.append(('Pulse', 'Regular'))
107+
# Apply a delay between pulses
108+
sequence.append(('Delay', 0))
109+
110+
return sequence
111+
112+
class QuantumPulseDemoRunner:
113+
"""
114+
Runner for training the QuantumPulseDemo model and simulating the OM_EOM_Simulation.
115+
"""
116+
117+
def __init__(self, pulse_duration, modulation_bandwidth=None, eom_mode=False):
118+
"""
119+
Initializes the QuantumPulseDemoRunner.
120+
121+
Args:
122+
pulse_duration (float): Duration of each pulse.
123+
modulation_bandwidth (float, optional): Bandwidth of modulation. Defaults to None.
124+
eom_mode (bool, optional): Whether to simulate EOM mode. Defaults to False.
125+
"""
126+
self.model = QuantumPulseDemo()
127+
self.optimizer = optim.Adam(params=self.model.pulse.parameters(), lr=5e-3)
128+
self.target_unitary = self._initialize_target_unitary()
129+
self.simulator = OM_EOM_Simulation(pulse_duration, modulation_bandwidth, eom_mode)
130+
131+
def _initialize_target_unitary(self):
132+
"""
133+
Initializes the target unitary matrix.
134+
135+
Returns:
136+
torch.Tensor: Target unitary matrix.
137+
"""
138+
theta = 0.6
139+
return torch.tensor(
140+
[
141+
[np.cos(theta / 2), -1j * np.sin(theta / 2)],
142+
[-1j * np.sin(theta / 2), np.cos(theta / 2)],
143+
],
144+
dtype=torch.complex64,
145+
)
146+
147+
def train(self, epochs=1000):
148+
"""
149+
Trains the QuantumPulseDemo model.
150+
151+
Args:
152+
epochs (int, optional): Number of training epochs. Defaults to 1000.
153+
"""
154+
for epoch in range(epochs):
155+
x = torch.tensor([[np.pi, np.pi]], dtype=torch.float32)
156+
157+
qdev = self.model(x)
158+
159+
loss = (
160+
1
161+
- (
162+
torch.trace(self.model.pulse.get_unitary() @ self.target_unitary)
163+
/ self.target_unitary.shape[0]
164+
).abs()
165+
** 2
166+
)
167+
168+
self.optimizer.zero_grad()
169+
loss.backward()
170+
self.optimizer.step()
171+
172+
if epoch % 100 == 0:
173+
print(f'Epoch {epoch}, Loss: {loss.item()}')
174+
print('Current Pulse Shape:', self.model.pulse.pulse_shape)
175+
print('Current Unitary:\n', self.model.pulse.get_unitary())
176+
177+
def run_simulation(self):
178+
"""
179+
Runs the OM_EOM_Simulation.
180+
"""
181+
sequence = self.simulator.simulate_sequence()
182+
for step in sequence:
183+
print(step)
184+
185+
# Example usage
186+
runner = QuantumPulseDemoRunner(pulse_duration=100, modulation_bandwidth=5, eom_mode=True)
187+
runner.train()
188+
runner.run_simulation()

0 commit comments

Comments
 (0)