
Commit e38998c

Update
1 parent 538b0aa commit e38998c

13 files changed

+234
-241021
lines changed

.gitignore

+2-1
Original file line number | Diff line number | Diff line change
@@ -6,4 +6,5 @@ render_logs
66
my_example.py
77
pyproject.toml
88
.pypirc
9-
*.egg-info
9+
*.egg-info
10+
*.pkl

README.md

+55-46
@@ -48,6 +48,23 @@ df.sort_index(inplace= True)
4848
df.dropna(inplace= True)
4949
df.drop_duplicates(inplace=True)
5050
```
51+
**1.1 (Optional) Download data** : The package provides an easy way to download data (it works with CCXT and uses asyncio to keep downloads fast) :
52+
```python
53+
from gym_trading_env.downloader import download
54+
import datetime
55+
56+
download(
57+
exchange_names = ["binance", "bitfinex2", "huobi"],
58+
symbols= ["BTC/USDT", "ETH/USDT"],
59+
timeframe= "30m",
60+
dir = "test/data",
61+
since= datetime.datetime(year= 2019, month= 1, day=1),
62+
until = datetime.datetime(year= 2023, month= 1, day=1),
63+
)
64+
```
65+
This function uses the pickle format to save the OHLCV data. You will need to import the dataset with ```pd.read_pickle('... .pkl', ...)```. The function supports the exchange names ```binance```, ```bitfinex2``` (API v2) and ```huobi```.
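As a rough sketch of where the downloaded files end up: judging by `downloader.py` in this commit, each file appears to be written as `{dir}/{exchange_name}-{symbol without "/"}-{timeframe}.pkl`. The helper name `dataset_path` below is hypothetical and not part of the package; it only rebuilds that file name.

```python
from pathlib import Path

def dataset_path(dir, exchange_name, symbol, timeframe):
    # Hypothetical helper (not part of gym_trading_env): mirrors the file
    # naming pattern the downloader uses when it pickles each symbol.
    return Path(dir) / f"{exchange_name}-{symbol.replace('/', '')}-{timeframe}.pkl"

path = dataset_path("test/data", "binance", "BTC/USDT", "30m")
print(path.as_posix())  # test/data/binance-BTCUSDT-30m.pkl
```

The resulting path can then be passed to `pd.read_pickle(path)` to get the DataFrame back.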
66+
67+
5168
**2 - Create your feature**. Your RL-agent will need some good, preprocessed features. It is your job to make sure it has everything it needs.
5269
**The feature column names need to contain the keyword 'feature'. The environment will automatically detect them !**
5370

@@ -72,16 +89,16 @@ The history object is similar to a DataFrame. It uses timestep and/or columns to
7289

7390

7491
Accessible columns of history object :
75-
- ```step```: ...,# Step = t
76-
- ```date```: ...,# Date at step t, datetime
77-
- ```reward```: ..., # Reward at step t
78-
- ```position_index```: ..., # Index of the position at step t amoung your position argument
79-
- ```position``` : ..., # Portfolio position at step t
92+
- ```step``` : Step = t
93+
- ```date``` : Date at step t, datetime
94+
- ```reward``` : Reward at step t
95+
- ```position_index``` : Index of the position at step t among your positions argument
96+
- ```position``` : Portfolio position at step t
8097

8198
*It gathers every data (not used as features) from your DataFrame and labels them with 'data_{column}'. For example :*
82-
- ```data_close```: Close price,
83-
- ```data_open```: Open price,
84-
- ```data_high```:Hight price,
99+
- ```data_close``` : Close price,
100+
- ```data_open``` : Open price,
101+
- ```data_high``` : High price,
85102

86103
*......*
87104

@@ -91,13 +108,13 @@ Accessible columns of history object :
91108
*It stores the distribution of the portfolio :*
92109
- ```portfolio_distribution_asset``` : The amount of owned asset (stock),
93110
- ```portfolio_distribution_fiat``` : The amount of owned fiat currency,
94-
- ```portfolio_distribution_borrowed_asset```: The amount of borrowed asset,
95-
- ```portfolio_distribution_borrowed_fiat```: The amount of borrowed fiat currency,
96-
- ```portfolio_distribution_interest_asset```: The total of cumalated interest generated by the borrowed asset.
97-
- ```portfolio_distribution_interest_fiat```: The total of cumalated interest generated by the borrowed fiat currency.
111+
- ```portfolio_distribution_borrowed_asset``` : The amount of borrowed asset,
112+
- ```portfolio_distribution_borrowed_fiat``` : The amount of borrowed fiat currency,
113+
- ```portfolio_distribution_interest_asset``` : The total accumulated interest generated by the borrowed asset.
114+
- ```portfolio_distribution_interest_fiat``` : The total accumulated interest generated by the borrowed fiat currency.
98115

99116

100-
**4 - Create the environment**
117+
**4 - Initiate the environment**
101118

102119
```python
103120
env = TradingEnv(...)
@@ -140,60 +157,52 @@ First, you need to save your results at the end of every episode you want to ren
140157

141158
```python
142159
...
143-
# At the end of episode you want to render
160+
# At the end of the episode you want to render
144161
env.save_for_render(dir = "render_logs")
145162
```
146163

147164
Then, in a separate render script, you can import and instantiate a renderer object and run it as a localhost web app :
148165
```python
149166
from gym_trading_env.renderer import Renderer
150-
renderer = Renderer(render_dir="render_logs")
167+
renderer = Renderer(render_logs_dir="render_logs")
151168
renderer.run()
152169
```
153170

154171
You can add **metrics** and plot **lines** with :
155172
```python
156-
# Add lines
157-
# - Simple Moving Average - 10
158-
renderer.add_scatter(
159-
name = "sma10",
160-
function = lambda df : df["close"].rolling(10).mean(),
161-
scatter_args = {
162-
"line": {"color":'blue'}
163-
})
164-
# - Simple Moving Average - 40
165-
renderer.add_scatter(
166-
name = "sma40",
167-
function = lambda df : df["close"].rolling(40).mean(),
168-
scatter_args = {
169-
"line": {"color": "purple"}
170-
})
171-
172-
# Add metrics
173-
def max_drawdown(df):
174-
current_max = df["portfolio_valuation"].iloc[0]
175-
max_drawdown = 0
176-
for i in range(len(df)):
177-
current_max = max(df["portfolio_valuation"].iloc[i], current_max)
178-
max_drawdown = min(max_drawdown, (df["portfolio_valuation"].iloc[i] - current_max)/current_max)
179-
return f"{max_drawdown*100:0.2f}%"
180-
181-
renderer.add_metric("Max drawdown", max_drawdown)
173+
renderer = Renderer(render_logs_dir="render_logs")
174+
175+
# Add Custom Lines (Simple Moving Average)
176+
renderer.add_line( name= "sma10", function= lambda df : df["close"].rolling(10).mean(), line_options ={"width" : 1, "color": "purple"})
177+
renderer.add_line( name= "sma20", function= lambda df : df["close"].rolling(20).mean(), line_options ={"width" : 1, "color": "blue"})
178+
179+
# Add Custom Metrics (Annualized metrics)
180+
renderer.add_metric(
181+
name = "Annual Market Return",
182+
function = lambda df : f"{ ((df['close'].iloc[-1] / df['close'].iloc[0])**(pd.Timedelta(days=365)/(df.index.values[-1] - df.index.values[0]))-1)*100:0.2f}%"
183+
)
184+
185+
renderer.add_metric(
186+
name = "Annual Portfolio Return",
187+
function = lambda df : f"{((df['portfolio_valuation'].iloc[-1] / df['portfolio_valuation'].iloc[0])**(pd.Timedelta(days=365)/(df.index.values[-1] - df.index.values[0]))-1)*100:0.2f}%"
188+
)
189+
190+
renderer.run()
182191
```
183192

184193
<img alt="Render example" src ="https://github.com/ClementPerroud/Gym-Trading-Env/blob/main/readme_images/render_customization.gif?raw=true" width = "800"/>
185194

186195

187196

188-
```.add_scatter``` takes arguments :
189-
- ```name``` : The name of the scatter
190-
- ```function``` : The function used to compute the line. The function must take an argument ```df``` which is a DateFrame and return a Series, 1D-Array or list.
191-
- ```scatter_args``` : Paramaters added to the go.Scatter object during the process. It can be used to customize your plots. The [documentation of the go.Scatter object](https://plotly.com/python-api-reference/generated/plotly.graph_objects.Scatter.html) might help you.
197+
```.add_line``` takes arguments :
198+
- ```name``` (*required*): The name of the line
199+
- ```function``` (*required*): The function used to compute the line. The function must take an argument ```df``` which is a DataFrame and return a Series, 1D-Array or list.
200+
- ```line_options``` : Can contain a dict with keys ```color``` and ```width```
192201

193202

194203
```.add_metric``` takes arguments :
195204
- ```name``` : The name of the metric
196-
- ```function``` : The function used to compute the line. The function must take an argument ```df``` which is a DateFrame and return a scalar.
205+
- ```function``` : The function used to compute the metric. The function must take an argument ```df``` which is a DataFrame and return a **string** !
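As an illustration of a metric that returns a string, here is a minimal max-drawdown sketch, adapted from the example this commit removes from the README. It works on any iterable of valuations, so it does not need pandas on its own; the function name `max_drawdown` is only illustrative.

```python
def max_drawdown(valuations):
    # Largest peak-to-trough loss, returned as a formatted percentage string
    # because a metric function is expected to return a string.
    current_max = None
    worst = 0.0
    for value in valuations:
        if current_max is None or value > current_max:
            current_max = value
        worst = min(worst, (value - current_max) / current_max)
    return f"{worst * 100:0.2f}%"

print(max_drawdown([100, 120, 90, 110]))  # -25.00%
```

With the renderer shown above, it could presumably be registered with `renderer.add_metric(name = "Max drawdown", function = lambda df : max_drawdown(df["portfolio_valuation"]))`.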
197206

198207

199208
Enjoy :)

readme_images/render.PNG

15.7 KB

readme_images/render.gif

860 KB

render.html

-240,965
This file was deleted.
-8.7 MB
Binary file not shown.
-8.68 MB
Binary file not shown.

src/gym_trading_env/downloader.py

+96
@@ -0,0 +1,96 @@
1+
import asyncio
2+
import ccxt.async_support as ccxt
3+
import pandas as pd
4+
import datetime
5+
from tqdm import tqdm
6+
7+
8+
EXCHANGE_LIMIT_RATES = {
9+
"bitfinex2": {
10+
"limit":10_000,
11+
"pause_every": 1,
12+
"pause" : 3, #seconds
13+
},
14+
"binance": {
15+
"limit":1_000,
16+
"pause_every": 10,
17+
"pause" : 1, #seconds
18+
},
19+
"huobi": {
20+
"limit":1_000,
21+
"pause_every": 10,
22+
"pause" : 1, #seconds
23+
}
24+
}
25+
26+
async def _ohlcv(exchange, symbol, timeframe, limit, step_since, timedelta):
27+
result = await exchange.fetch_ohlcv(symbol = symbol, timeframe= timeframe, limit= limit, since=step_since)
28+
result_df = pd.DataFrame(result, columns=["timestamp_open", "open", "high", "low", "close", "volume"])
29+
result_df["date_open"] = pd.to_datetime(result_df["timestamp_open"], unit= "ms")
30+
result_df["date_close"] = pd.to_datetime(result_df["timestamp_open"] + timedelta, unit= "ms")
31+
return result_df
32+
33+
async def _download_symbol(exchange, symbol, timeframe = '5m', since = int(datetime.datetime(year=2020, month= 1, day= 1).timestamp()*1E3), until = int(datetime.datetime.now().timestamp()*1E3), limit = 1000, pause_every = 10, pause = 1):
34+
timedelta = int(pd.Timedelta(timeframe).to_timedelta64()/1E6)
35+
tasks = []
36+
results = []
37+
for step_since in range(since, until, limit * timedelta):
38+
tasks.append(
39+
asyncio.create_task(_ohlcv(exchange, symbol, timeframe, limit, step_since, timedelta))
40+
)
41+
if len(tasks) >= pause_every:
42+
results.extend(await asyncio.gather(*tasks))
43+
await asyncio.sleep(pause)
44+
tasks = []
45+
if len(tasks) > 0 :
46+
results.extend(await asyncio.gather(*tasks))
47+
final_df = pd.concat(results, ignore_index= True)
48+
final_df = final_df.loc[(since < final_df["timestamp_open"]) & (final_df["timestamp_open"] < until), :]
49+
final_df.set_index('date_open', drop=True, inplace=True)
50+
final_df.sort_index(inplace= True)
51+
final_df.dropna(inplace=True)
52+
final_df.drop_duplicates(inplace=True)
53+
return final_df
54+
55+
async def _download_symbols(exchange_name, symbols, dir, timeframe, **kwargs):
56+
exchange = getattr(ccxt, exchange_name)({ 'enableRateLimit': True })
57+
for symbol in symbols:
58+
df = await _download_symbol(exchange = exchange, symbol = symbol, timeframe= timeframe, **kwargs)
59+
df.to_pickle(f"{dir}/{exchange_name}-{symbol.replace('/', '')}-{timeframe}.pkl")
60+
await exchange.close()
61+
62+
async def _download(exchange_names, symbols, timeframe, dir, since : datetime.datetime, until : datetime.datetime = datetime.datetime.now()):
63+
tasks = []
64+
for exchange_name in exchange_names:
65+
66+
limit = EXCHANGE_LIMIT_RATES[exchange_name]["limit"]
67+
pause_every = EXCHANGE_LIMIT_RATES[exchange_name]["pause_every"]
68+
pause = EXCHANGE_LIMIT_RATES[exchange_name]["pause"]
69+
tasks.append(
70+
_download_symbols(
71+
exchange_name = exchange_name, symbols= symbols, timeframe= timeframe, dir = dir,
72+
limit = limit, pause_every = pause_every, pause = pause,
73+
since = int(since.timestamp()*1E3), until = int(until.timestamp()*1E3)
74+
)
75+
)
76+
await asyncio.gather(*tasks)
77+
def download(*args, **kwargs):
78+
loop = asyncio.get_event_loop()
79+
loop.run_until_complete(
80+
_download(*args, **kwargs)
81+
)
82+
83+
async def main():
84+
await _download(
85+
["binance", "bitfinex2", "huobi"],
86+
symbols= ["BTC/USDT", "ETH/USDT"],
87+
timeframe= "30m",
88+
dir = "test/data",
89+
since= datetime.datetime(year= 2019, month= 1, day=1),
90+
)
91+
92+
93+
94+
if __name__ == "__main__":
95+
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
96+
asyncio.run(main())
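The pagination in `_download_symbol` can be illustrated in isolation. The sketch below is not the package's code: `request_windows` and `candle_ms` are hypothetical names, and it assumes the candle length is already known in milliseconds. It also resolves a missing `until` at call time, since a `datetime.datetime.now()` default written in a function signature is evaluated only once, when the module is imported.

```python
import datetime

def request_windows(since, until=None, limit=1_000, candle_ms=30 * 60 * 1000):
    """Return the start timestamp (in ms) of each paginated OHLCV request."""
    if until is None:
        # Resolve "now" when the function is called, not at import time.
        until = datetime.datetime.now(datetime.timezone.utc)
    since_ms = int(since.timestamp() * 1000)
    until_ms = int(until.timestamp() * 1000)
    # Each request covers `limit` candles, so consecutive requests start
    # `limit * candle_ms` milliseconds apart.
    return list(range(since_ms, until_ms, limit * candle_ms))

start = datetime.datetime(2023, 1, 1, tzinfo=datetime.timezone.utc)
end = start + datetime.timedelta(minutes=30 * 2500)  # 2500 candles of 30 min
print(len(request_windows(start, end)))  # 3
```

Each returned timestamp would correspond to one `since` value handed to a fetch call; batching those calls and pausing between batches is what the `pause_every` and `pause` settings control in the diff above.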

src/gym_trading_env/environments.py

+16-7
@@ -4,6 +4,7 @@
44
import numpy as np
55
import datetime
66
import glob
7+
from pathlib import Path
78

89
from collections import Counter
910
from .utils.history import History
@@ -81,9 +82,8 @@ def _get_obs(self):
8182
"position" : self._position
8283
}
8384

84-
def reset(self, seed = None, df = None):
85+
def reset(self, seed = None):
8586
super().reset(seed = seed)
86-
if df is not None: self._set_df(df)
8787
self._step = 0
8888
self._limit_orders = {}
8989
if self.windows is not None: self._step = self.windows
@@ -178,19 +178,28 @@ def save_for_render(self, dir = "render_logs"):
178178
render_df.to_pickle(f"{dir}/{self.name}_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl")
179179

180180
class MultiDatasetTradingEnv(TradingEnv):
181-
def __init(self, dataset_dir, *args, **kwargs):
181+
def __init__(self, dataset_dir, preprocess, *args, **kwargs):
182182
self.dataset_dir = dataset_dir
183+
self.preprocess = preprocess
183184
self.dataset_pathes = glob.glob(self.dataset_dir)
184185
self.dataset_nb_uses = np.zeros(shape=(len(self.dataset_pathes), ))
185-
df = self.pick_dataset()
186-
super().__init__(df, *args, **kwargs)
186+
super().__init__(self.next_dataset(), *args, **kwargs)
187187

188-
def pick_dataset(self):
188+
def next_dataset(self):
189189
# Find the indexes of the less explored dataset
190190
potential_dataset_pathes = np.where(self.dataset_nb_uses == self.dataset_nb_uses.min())[0]
191191
# Pick one of them
192192
random_int = np.random.randint(potential_dataset_pathes.size)
193193
dataset_path = self.dataset_pathes[random_int]
194194
self.dataset_nb_uses[random_int] += 1 # Update nb use counts
195-
return pd.read_pickle(dataset_path)
195+
196+
self.name = Path(dataset_path).name
197+
return self.preprocess(pd.read_pickle(dataset_path))
198+
199+
def reset(self, seed=None):
200+
self._set_df(
201+
self.next_dataset()
202+
)
203+
print(f"Selected dataset {self.name} ...")
204+
return super().reset(seed)
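One thing worth double-checking in `next_dataset`: `random_int` is drawn from `range(potential_dataset_pathes.size)` and then used directly as an index into `dataset_pathes`, so it appears to be a position within the candidate array rather than a dataset index. A minimal sketch of the intended least-used selection, using plain lists and a hypothetical helper name `pick_least_used`, might look like this:

```python
import random

def pick_least_used(use_counts, rng=random):
    # Indexes of the datasets that have been used the fewest times so far.
    fewest = min(use_counts)
    candidates = [i for i, n in enumerate(use_counts) if n == fewest]
    # Draw one of those dataset indexes (not a position inside `candidates`).
    index = rng.choice(candidates)
    use_counts[index] += 1  # update the use count in place
    return index

counts = [1, 0, 1]
print(pick_least_used(counts), counts)  # 1 [1, 1, 1]
```

In the environment itself the same idea could be expressed with `np.random.choice(potential_dataset_pathes)`, which returns one of the candidate indexes directly.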
196205

src/gym_trading_env/templates/index.html

+3
@@ -27,6 +27,9 @@
2727
padding: 10px;
2828
border: 1px solid #ededed;
2929
border-radius: 5%;
30+
max-width: 100%;
31+
white-space:nowrap;
32+
text-overflow:ellipsis;
3033
}
3134
#metrics{
3235
display: flex;

test/example_download.py

+15
@@ -0,0 +1,15 @@
1+
2+
import sys
3+
sys.path.append("./src")
4+
5+
6+
from gym_trading_env.downloader import download
7+
import datetime
8+
9+
download(
10+
exchange_names = ["binance", "bitfinex2", "huobi"],
11+
symbols= ["BTC/USDT", "ETH/USDT"],
12+
timeframe= "30m",
13+
dir = "test/data",
14+
since= datetime.datetime(year= 2019, month= 1, day=1),
15+
)

test/example_environnement.py

+2-2
@@ -39,9 +39,9 @@ def reward_function(history):
3939
)
4040

4141
# Run the simulation
42-
truncated = False
42+
truncated, done = False, False
4343
observation, info = env.reset()
44-
while not truncated:
44+
while not truncated and not done:
4545
action = 5 #OR manually : action = int(input("Action : "))
4646
observation, reward, done, truncated, info = env.step(action)
4747

test/example_multi_environnement.py

+45
@@ -0,0 +1,45 @@
1+
import sys
2+
sys.path.append("./src")
3+
4+
import pandas as pd
5+
import numpy as np
6+
import time
7+
from gym_trading_env.environments import MultiDatasetTradingEnv
8+
9+
# Generating features
10+
# WARNING : the column names need to contain keyword 'feature' !
11+
def preprocess(df):
12+
df["feature_close"] = df["close"].pct_change()
13+
df["feature_open"] = df["open"]/df["close"]
14+
df["feature_high"] = df["high"]/df["close"]
15+
df["feature_low"] = df["low"]/df["close"]
16+
df["feature_volume"] = df["volume"] / df["volume"].rolling(7*24).max()
17+
df.dropna(inplace= True)
18+
return df
19+
20+
21+
# Create your own reward function with the history object
22+
def reward_function(history):
23+
return np.log(history["portfolio_valuation", -1] / history["portfolio_valuation", -2]) #log (p_t / p_t-1 )
24+
25+
env = MultiDatasetTradingEnv(
26+
dataset_dir= 'test/data/*.pkl',
27+
preprocess= preprocess,
28+
windows= 5,
29+
positions = [ -1, -0.5, 0, 0.5, 1, 1.5, 2], # From -1 (=full SHORT), to +1 (=full LONG) with 0 = no position
30+
initial_position = 0, #Initial position
31+
trading_fees = 0.01/100, # 0.01% per stock buy / sell
32+
borrow_interest_rate= 0.0003/100, #per timestep (= 1h here)
33+
reward_function = reward_function,
34+
portfolio_initial_value = 1000, # in FIAT (here, USD)
35+
)
36+
37+
# Run the simulation
38+
truncated, done = False, False
39+
observation, info = env.reset()
40+
while not truncated and not done:
41+
action = env.action_space.sample() #OR manually : action = int(input("Action : "))
42+
observation, reward, done, truncated, info = env.step(action)
43+
44+
# Render
45+
env.save_for_render()
