1
1
import re
2
2
from collections .abc import Sequence
3
+ from pathlib import PurePosixPath
4
+ from textwrap import indent
5
+
6
+ import numpy as np
7
+ import matplotlib
8
+ import matplotlib .pyplot as plt
9
+ from matplotlib .animation import Animation
10
+ from matplotlib .figure import Figure
11
+ from sphinx_gallery .scrapers import (
12
+ figure_rst , _anim_rst , _matplotlib_fig_titles , HLIST_HEADER ,
13
+ HLIST_IMAGE_MATPLOTLIB )
14
+ import plotly .graph_objects as go
15
+ try :
16
+ import kaleido # noqa: F401
17
+ except ImportError :
18
+ write_plotly_image = None
19
+ else :
20
+ from plotly .io import write_image as write_plotly_image
3
21
4
22
from stonesoup .base import Base
5
23
24
+
6
25
STONESOUP_TYPE_REGEX = re .compile (r'stonesoup\.(\w+\.)*' )
7
26
8
27
@@ -70,25 +89,7 @@ def setup(app):
70
89
app .connect ('autodoc-process-signature' , shorten_type_hints )
71
90
72
91
73
- import os
74
- import matplotlib
75
- import matplotlib .pyplot as plt
76
- from textwrap import indent
77
-
78
- from sphinx_gallery .scrapers import (
79
- figure_rst , _anim_rst , _matplotlib_fig_titles , HLIST_HEADER ,
80
- HLIST_IMAGE_MATPLOTLIB )
81
-
82
- import plotly .graph_objects as go
83
- try :
84
- import kaleido
85
- except ImportError :
86
- write_plotly_image = None
87
- else :
88
- from plotly .io import write_image as write_plotly_image
89
-
90
-
91
- class gallery_scraper ():
92
+ class GalleryScraper ():
92
93
def __init__ (self ):
93
94
self .plotted_figures = set ()
94
95
self .current_src_file = None
@@ -124,18 +125,15 @@ def __call__(self, block, block_vars, gallery_conf, **kwargs):
124
125
self .plotted_figures = set ()
125
126
self .current_src_file = block_vars ['src_file' ]
126
127
127
- from matplotlib .animation import Animation
128
- from matplotlib .figure import Figure
129
128
image_path_iterator = block_vars ['image_path_iterator' ]
130
129
image_rsts = []
131
130
132
131
# Check for animations
133
- anims = list ()
134
- if gallery_conf . get ( 'matplotlib_animations' , False ) :
132
+ anims = {}
133
+ if gallery_conf [ 'matplotlib_animations' ] :
135
134
for ani in block_vars ['example_globals' ].values ():
136
135
if isinstance (ani , Animation ):
137
- anims .append (ani )
138
-
136
+ anims [ani ._fig ] = ani
139
137
# Then standard images
140
138
new_figures = set (plt .get_fignums ()) - self .plotted_figures
141
139
last_line = block [1 ].strip ().split ('\n ' )[- 1 ]
@@ -149,30 +147,22 @@ def __call__(self, block, block_vars, gallery_conf, **kwargs):
149
147
else :
150
148
if isinstance (output , Figure ):
151
149
new_figures .add (output .number )
152
- elif isinstance (output , go .Figure ):
153
- if write_plotly_image is not None :
154
- image_path = next (image_path_iterator )
155
- if 'format' in kwargs :
156
- image_path = '%s.%s' % (os .path .splitext (image_path )[0 ],
157
- kwargs ['format' ])
158
- write_plotly_image (output , image_path , kwargs .get ('format' ))
150
+ elif isinstance (output , go .Figure ) and write_plotly_image is not None :
151
+ image_path = PurePosixPath (next (image_path_iterator ))
152
+ if "format" in kwargs :
153
+ image_path = image_path .with_suffix ("." + kwargs ["format" ])
154
+ write_plotly_image (output , str (image_path ), kwargs .get ('format' ))
159
155
160
156
for fig_num , image_path in zip (new_figures , image_path_iterator ):
161
- if 'format' in kwargs :
162
- image_path = '%s.%s' % (os .path .splitext (image_path )[0 ],
163
- kwargs ['format' ])
164
- # Set the fig_num figure as the current figure as we can't
165
- # save a figure that's not the current figure.
157
+ image_path = PurePosixPath (image_path )
158
+ if "format" in kwargs :
159
+ image_path = image_path .with_suffix ("." + kwargs ["format" ])
160
+ # Convert figure number to Figure.
166
161
fig = plt .figure (fig_num )
167
162
self .plotted_figures .add (fig_num )
168
163
# Deal with animations
169
- cont = False
170
- for anim in anims :
171
- if anim ._fig is fig :
172
- image_rsts .append (_anim_rst (anim , image_path , gallery_conf ))
173
- cont = True
174
- break
175
- if cont :
164
+ if anim := anims .get (fig ):
165
+ image_rsts .append (_anim_rst (anim , image_path , gallery_conf ))
176
166
continue
177
167
# get fig titles
178
168
fig_titles = _matplotlib_fig_titles (fig )
@@ -183,8 +173,7 @@ def __call__(self, block, block_vars, gallery_conf, **kwargs):
183
173
for attr in ['facecolor' , 'edgecolor' ]:
184
174
fig_attr = getattr (fig , 'get_' + attr )()
185
175
default_attr = matplotlib .rcParams ['figure.' + attr ]
186
- if to_rgba (fig_attr ) != to_rgba (default_attr ) and \
187
- attr not in kwargs :
176
+ if to_rgba (fig_attr ) != to_rgba (default_attr ) and attr not in kwargs :
188
177
these_kwargs [attr ] = fig_attr
189
178
these_kwargs ['bbox_inches' ] = "tight"
190
179
fig .savefig (image_path , ** these_kwargs )
@@ -194,24 +183,28 @@ def __call__(self, block, block_vars, gallery_conf, **kwargs):
194
183
if len (image_rsts ) == 1 :
195
184
rst = image_rsts [0 ]
196
185
elif len (image_rsts ) > 1 :
197
- image_rsts = [re .sub (r':class: sphx-glr-single-img' ,
198
- ':class: sphx-glr-multi-img' ,
199
- image ) for image in image_rsts ]
200
- image_rsts = [HLIST_IMAGE_MATPLOTLIB + indent (image , u' ' * 6 )
201
- for image in image_rsts ]
186
+ image_rsts = [
187
+ re .sub (r':class: sphx-glr-single-img' , ':class: sphx-glr-multi-img' , image )
188
+ for image in image_rsts ]
189
+ image_rsts = [
190
+ HLIST_IMAGE_MATPLOTLIB + indent (image , ' ' * 6 ) for image in image_rsts
191
+ ]
202
192
rst = HLIST_HEADER + '' .join (image_rsts )
203
193
return rst
204
194
205
195
206
- class reset_numpy_random_seed :
196
+ class ResetNumPyRandomSeed :
207
197
208
198
def __init__ (self ):
209
199
self .state = None
210
200
211
201
def __call__ (self , gallery_conf , fname , when ):
212
- import numpy as np
213
202
if when == 'before' :
214
203
self .state = np .random .get_state ()
215
204
elif when == 'after' :
216
205
# Set state attribute back to `None`
217
206
self .state = np .random .set_state (self .state )
207
+
208
+
209
+ gallery_scraper = GalleryScraper ()
210
+ reset_numpy_random_seed = ResetNumPyRandomSeed ()
0 commit comments