Skip to content

Commit 6b30997

Browse files
Merge pull request #270 from AbdullahKazi500/AbdullahKazi500-patch-2
2 parents eda17c1 + d433bbe commit 6b30997

File tree

5 files changed

+766
-0
lines changed

5 files changed

+766
-0
lines changed

examples/QuantumGan/ README.md

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Quantum Generative Adversarial Network (QGAN) Example
2+
3+
This repository contains an example implementation of a Quantum Generative Adversarial Network (QGAN) using PyTorch and TorchQuantum. The example is provided in a Jupyter Notebook for interactive exploration.
4+
5+
## Overview
6+
7+
A QGAN consists of two main components:
8+
9+
1. **Generator:** This network generates fake quantum data samples.
10+
2. **Discriminator:** This network tries to distinguish between real and fake quantum data samples.
11+
12+
The goal is to train the generator to produce quantum data that is indistinguishable from real data, according to the discriminator. This is achieved through an adversarial training process, where the generator and discriminator are trained simultaneously in a competitive manner.
13+
14+
## Repository Contents
15+
16+
- `qgan_notebook.ipynb`: Jupyter Notebook demonstrating the QGAN implementation.
17+
- `qgan_script.py`: Python script containing the QGAN model and a main function for initializing the model with command-line arguments.
18+
19+
## Installation
20+
21+
To run the examples, you need to have the following dependencies installed:
22+
23+
- Python 3
24+
- PyTorch
25+
- TorchQuantum
26+
- Jupyter Notebook
27+
- ipywidgets
28+
29+
You can install the required Python packages using pip:
30+
31+
```bash
32+
pip install torch torchquantum jupyter ipywidgets
33+
```
34+
35+
36+
Running the Examples
37+
Jupyter Notebook
38+
Open the qgan_notebook.ipynb file in Jupyter Notebook.
39+
Execute the notebook cells to see the QGAN model in action.
40+
Python Script
41+
You can also run the QGAN model using the Python script. The script uses argparse to handle command-line arguments.
42+
43+
bash
44+
Copy code
45+
python qgan_script.py <n_qubits> <latent_dim>
46+
Replace <n_qubits> and <latent_dim> with the desired number of qubits and latent dimensions.
47+
48+
Notebook Details
49+
The Jupyter Notebook is structured as follows:
50+
51+
Introduction: Provides an overview of the QGAN and its components.
52+
Import Libraries: Imports the necessary libraries, including PyTorch and TorchQuantum.
53+
Generator Class: Defines the quantum generator model.
54+
Discriminator Class: Defines the quantum discriminator model.
55+
QGAN Class: Combines the generator and discriminator into a single QGAN model.
56+
Main Function: Initializes the QGAN model and prints its structure.
57+
Interactive Model Creation: Uses ipywidgets to create an interactive interface for adjusting the number of qubits and latent dimensions.
58+
Understanding QGANs
59+
QGANs are a type of Generative Adversarial Network (GAN) that operate in the quantum domain. They leverage quantum circuits to generate and evaluate data samples. The adversarial training process involves two competing networks:
60+
61+
The Generator creates fake quantum data samples from a latent space.
62+
The Discriminator attempts to distinguish these fake samples from real quantum data.
63+
Through training, the generator improves its ability to create realistic quantum data, while the discriminator enhances its ability to identify fake data. This process results in a generator that can produce high-quality quantum data samples.
64+
65+
66+
## QGAN Implementation for CIFAR-10 Dataset
67+
This implementation trains a QGAN on the CIFAR-10 dataset to generate fake images. It follows a similar structure to the TorchQuantum QGAN, with the addition of data loading and processing specific to the CIFAR-10 dataset.
68+
Generated images can be seen in the folder
69+
70+
This `README.md` file explains the purpose of the repository, the structure of the notebook, and how to run the examples, along with a brief overview of the QGAN concept for those unfamiliar with it.
71+
72+
73+
## Reference
74+
- [ ] https://arxiv.org/abs/2312.09939

examples/QuantumGan/QGan.py

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import argparse
2+
import torch
3+
import torch.nn as nn
4+
import torch.optim as optim
5+
import torchquantum as tq
6+
7+
class Generator(nn.Module):
8+
def __init__(self, n_qubits: int, latent_dim: int):
9+
super().__init__()
10+
self.n_qubits = n_qubits
11+
self.latent_dim = latent_dim
12+
13+
# Quantum encoder
14+
self.encoder = tq.GeneralEncoder([
15+
{'input_idx': [i], 'func': 'rx', 'wires': [i]}
16+
for i in range(self.n_qubits)
17+
])
18+
19+
# RX gates
20+
self.rxs = nn.ModuleList([
21+
tq.RX(has_params=True, trainable=True) for _ in range(self.n_qubits)
22+
])
23+
24+
def forward(self, x):
25+
qdev = tq.QuantumDevice(n_wires=self.n_qubits, bsz=x.shape[0], device=x.device)
26+
self.encoder(qdev, x)
27+
28+
for i in range(self.n_qubits):
29+
self.rxs[i](qdev, wires=i)
30+
31+
return tq.measure(qdev)
32+
33+
class Discriminator(nn.Module):
34+
def __init__(self, n_qubits: int):
35+
super().__init__()
36+
self.n_qubits = n_qubits
37+
38+
# Quantum encoder
39+
self.encoder = tq.GeneralEncoder([
40+
{'input_idx': [i], 'func': 'rx', 'wires': [i]}
41+
for i in range(self.n_qubits)
42+
])
43+
44+
# RX gates
45+
self.rxs = nn.ModuleList([
46+
tq.RX(has_params=True, trainable=True) for _ in range(self.n_qubits)
47+
])
48+
49+
# Quantum measurement
50+
self.measure = tq.MeasureAll(tq.PauliZ)
51+
52+
def forward(self, x):
53+
qdev = tq.QuantumDevice(n_wires=self.n_qubits, bsz=x.shape[0], device=x.device)
54+
self.encoder(qdev, x)
55+
56+
for i in range(self.n_qubits):
57+
self.rxs[i](qdev, wires=i)
58+
59+
return self.measure(qdev)
60+
61+
class QGAN(nn.Module):
62+
def __init__(self, n_qubits: int, latent_dim: int):
63+
super().__init__()
64+
self.generator = Generator(n_qubits, latent_dim)
65+
self.discriminator = Discriminator(n_qubits)
66+
67+
def forward(self, z):
68+
fake_data = self.generator(z)
69+
fake_output = self.discriminator(fake_data)
70+
return fake_output
71+
72+
def main(n_qubits, latent_dim):
73+
model = QGAN(n_qubits, latent_dim)
74+
print(model)
75+
76+
if __name__ == "__main__":
77+
parser = argparse.ArgumentParser(description="Quantum Generative Adversarial Network (QGAN) Example")
78+
parser.add_argument('n_qubits', type=int, help='Number of qubits')
79+
parser.add_argument('latent_dim', type=int, help='Dimension of the latent space')
80+
81+
args = parser.parse_args()
82+
83+
main(args.n_qubits, args.latent_dim)
84+

0 commit comments

Comments
 (0)