Skip to content

Commit 90dba44

Browse files
committed
Made refactoring, added readme.md
1 parent eda8e03 commit 90dba44

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1145
-3181
lines changed

.gitignore

+4-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ __pycache__/
33
*.py[cod]
44
*$py.class
55

6-
training_data
7-
test_models
6+
trained-models
7+
papers
8+
analysis
9+
runpod
810

911
# Distribution / packaging
1012
dist/

README.md

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Snake Diffusion model
2+
3+
It is an educational repo to build realtime snake game based on Diffusion model. It was inspired by great papers:
4+
* Doom Diffusion from Google ([paper](https://arxiv.org/html/2408.14837v1))
5+
* Oasis ([github](https://github.com/etched-ai/open-oasis))
6+
* Diamond ([paper](https://arxiv.org/pdf/2405.12399))
7+
8+
My goal was to build something similar and I have choosen Snake game for simple logic. It took near 2 months of different experiments to get a ready-to-play model.
9+
10+
If you don't have GPU you can use [runpod.io](runpod.io)(it is paid). Also I created a [Google colab](https://colab.research.google.com/drive/1OxneGBeb4B1U5dszVf_2UDYZLHshJQ5T?usp=sharing) for playing.
11+
12+
## Model scheme
13+
14+
After couple of experiments I chose EDM diffusion model, because it shows high performance on small sample steps. DDIM requires much more steps to generate the same quality.
15+
16+
![Model scheme](assets/scheme.png)
17+
18+
## Install requirements
19+
20+
```shell
21+
pip install -r requirements.txt
22+
```
23+
24+
## Training
25+
26+
To train a new model, you should have a dataset. You can download it running a script:
27+
```shell
28+
bash scripts/download-dataset.sh
29+
```
30+
31+
Or generate manually:
32+
33+
```shell
34+
python src/generate_dataset.py --model agent.pth --dataset training_data --record
35+
```
36+
37+
Then you can start training with command:
38+
```shell
39+
python src/train.py --model-type edm --output models --loader loader.pkl --gen-val-images --config Diffusion.yaml
40+
```
41+
42+
I trained my model on [runpod.io](runpod.io). It had 32 epochs, took ~27 hours and the cost was 10$.
43+
44+
## Inference
45+
46+
You can download my ready-to-use model:
47+
```shell
48+
git clone https://huggingface.co/juramoshkov/snake-diffusion models
49+
```
50+
Then run [Play.ipynb](src/play.ipynb), where you can play Snake with 1 fps 🤓.
51+
52+
Another way to play is to use [Google colab](https://colab.research.google.com/drive/1OxneGBeb4B1U5dszVf_2UDYZLHshJQ5T?usp=sharing)

assets/scheme.png

186 KB
Loading

config/Diffusion.yaml

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
training:
2+
epochs: 30
3+
batch_size: 4
4+
num_workers: 2
5+
save_every_epoch: 2
6+
7+
generation:
8+
image_size: 64
9+
input_channels: 3
10+
output_channels: 3
11+
context_length: 4
12+
actions_count: 5
13+
14+
edm:
15+
p_mean: -1.2
16+
p_std: 1.2
17+
sigma_data: 0.5
18+
sigma_min: 0.002
19+
sigma_max: 80
20+
rho: 7
21+
unet:
22+
__type__: models.gen.blocks.UNetConfig
23+
steps: [2, 2, 2, 2]
24+
channels: [64, 64, 64, 64]
25+
cond_channels: 256
26+
attn_step_indexes: [false, false, false, false]
27+
28+
ddpm:
29+
T: 1000
30+
unet:
31+
__type__: models.gen.blocks.UNetConfig
32+
steps: [2, 2, 2, 2]
33+
channels: [64, 64, 64, 64]
34+
cond_channels: 256
35+
attn_step_indexes: [false, false, false, false]

config/SnakeAgent.yaml

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
q_agent:
2+
max_memory: 100000
3+
batch_size: 1000
4+
lr: 1.0e-3
5+
hidden_state: 256
6+
value_for_end_game:
7+
__type__: q_agent.ValueForEndGame
8+
value: "last_action"
9+
iterations: 80000
10+
min_deaths_to_record: 60
11+
12+
env:
13+
__type__: game.snake.env.GameEnvironment
14+
game:
15+
__type__: game.snake.game.SnakeGame
16+
width: 64
17+
height: 64
18+
speed: 240
19+
block_size: 5

diffusion/check_time_render.py

-65
This file was deleted.

diffusion/ddpm/ddpm.py

-161
This file was deleted.

0 commit comments

Comments
 (0)