From a53fdc599fb3a047594967d1d5d9b686712527a5 Mon Sep 17 00:00:00 2001 From: phqb Date: Thu, 23 May 2024 13:02:42 +0700 Subject: [PATCH] added TickDataProviderMsgp to encode/decode TickDataProvider interface --- entities/pool.go | 18 +-- entities/pool_gen.go | 4 +- entities/pool_test.go | 22 ++++ entities/tickdataprovider_msgpencode.go | 150 ++++++++++++++++++++++++ entities/ticklistdataprovider.go | 4 + go.mod | 5 +- go.sum | 7 +- 7 files changed, 189 insertions(+), 21 deletions(-) create mode 100644 entities/tickdataprovider_msgpencode.go diff --git a/entities/pool.go b/entities/pool.go index 2971f6f..939837d 100644 --- a/entities/pool.go +++ b/entities/pool.go @@ -42,7 +42,7 @@ type Pool struct { SqrtRatioX96 *big.Int Liquidity *big.Int TickCurrent int - TickDataProvider *TickListDataProvider + TickDataProvider *TickDataProviderWrapper token0Price *entities.Price token1Price *entities.Price @@ -113,16 +113,6 @@ func NewPool(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX token1 = tokenA } - var tickListDataProvider *TickListDataProvider - if ticks != nil { - switch ticks := ticks.(type) { - case *TickListDataProvider: - tickListDataProvider = ticks - default: - return nil, errors.New("unsupported TickDataProvider concrete type") - } - } - return &Pool{ Token0: token0, Token1: token1, @@ -130,7 +120,7 @@ func NewPool(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX SqrtRatioX96: sqrtRatioX96, Liquidity: liquidity, TickCurrent: tickCurrent, - TickDataProvider: tickListDataProvider, + TickDataProvider: NewTickDataProviderWrapper(ticks), }, nil } @@ -209,7 +199,7 @@ func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLi swapResult.sqrtRatioX96, swapResult.liquidity, swapResult.currentTick, - p.TickDataProvider, + p.TickDataProvider.Get(), ) if err != nil { return nil, err @@ -252,7 +242,7 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi swapResult.sqrtRatioX96, swapResult.liquidity, swapResult.currentTick, - p.TickDataProvider, + p.TickDataProvider.Get(), ) if err != nil { return nil, err diff --git a/entities/pool_gen.go b/entities/pool_gen.go index 6acde45..5a4a8da 100644 --- a/entities/pool_gen.go +++ b/entities/pool_gen.go @@ -114,7 +114,7 @@ func (z *Pool) DecodeMsg(dc *msgp.Reader) (err error) { z.TickDataProvider = nil } else { if z.TickDataProvider == nil { - z.TickDataProvider = new(TickListDataProvider) + z.TickDataProvider = new(TickDataProviderWrapper) } err = z.TickDataProvider.DecodeMsg(dc) if err != nil { @@ -428,7 +428,7 @@ func (z *Pool) UnmarshalMsg(bts []byte) (o []byte, err error) { z.TickDataProvider = nil } else { if z.TickDataProvider == nil { - z.TickDataProvider = new(TickListDataProvider) + z.TickDataProvider = new(TickDataProviderWrapper) } bts, err = z.TickDataProvider.UnmarshalMsg(bts) if err != nil { diff --git a/entities/pool_test.go b/entities/pool_test.go index 308eabf..4b03085 100644 --- a/entities/pool_test.go +++ b/entities/pool_test.go @@ -1,6 +1,7 @@ package entities import ( + "bytes" "math/big" "testing" @@ -9,6 +10,7 @@ import ( "github.com/daoleno/uniswapv3-sdk/utils" "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/assert" + "github.com/tinylib/msgp/msgp" ) var ( @@ -237,3 +239,23 @@ func TestGetInputAmount2(t *testing.T) { assert.Equal(t, getInputAmountResult.ReturnedAmount.Quotient(), big.NewInt(7074025631378098)) assert.Equal(t, getInputAmountResult.RemainingAmountOut.Quotient(), big.NewInt(-480436293)) } + +func TestPoolMsgpEndecode(t *testing.T) { + poolWithNilProvider := newTestPool() + poolWithNilProvider.TickDataProvider = nil + pools := []*Pool{ + newTestPool(), + poolWithNilProvider, + } + for _, pool := range pools { + encoded := new(bytes.Buffer) + err := msgp.Encode(encoded, pool) + assert.NoError(t, err) + + decoded := new(Pool) + err = msgp.Decode(encoded, decoded) + assert.NoError(t, err) + + assert.EqualValues(t, pool, decoded) + } +} diff --git a/entities/tickdataprovider_msgpencode.go b/entities/tickdataprovider_msgpencode.go new file mode 100644 index 0000000..31157b3 --- /dev/null +++ b/entities/tickdataprovider_msgpencode.go @@ -0,0 +1,150 @@ +package entities + +import ( + "fmt" + "reflect" + + "github.com/tinylib/msgp/msgp" +) + +var ( + // Mapping from string representation of a TickDataProvider concrete type to the concrete type itself. + // The string representation is used as type discriminator when encoding/decoding TickDataProvider. + tickDataProviderImplMap = map[string]reflect.Type{} +) + +// RegisterTickDataProviderImpl registers the concrete types of an TickDataProvider. +// This function is not thread-safe and should be only call in init(). +func RegisterTickDataProviderImpl(provider TickDataProvider) { + if _, ok := provider.(msgp.Encodable); !ok { + panic("expected provider to implement msgp.Encodable") + } + if _, ok := provider.(msgp.Decodable); !ok { + panic("expected provider to implement msgp.Decodable") + } + if _, ok := provider.(msgp.Marshaler); !ok { + panic("expected provider to implement msgp.Marshaler") + } + if _, ok := provider.(msgp.Unmarshaler); !ok { + panic("expected provider to implement msgp.Unmarshaler") + } + if _, ok := provider.(msgp.Sizer); !ok { + panic("expected provider to implement msgp.Sizer") + } + typ := reflect.ValueOf(provider).Elem().Type() + tickDataProviderImplMap[typ.String()] = typ +} + +// TickDataProviderWrapper is a wrapper of TickDataProvider and is implemented msgp.Encodable, msgp.Decodable, msgp.Marshaler, msgp.Unmarshaler, and msgp.Sizer +type TickDataProviderWrapper struct { + TickDataProvider +} + +func NewTickDataProviderWrapper(provider TickDataProvider) *TickDataProviderWrapper { + if provider == nil { + return nil + } + return &TickDataProviderWrapper{provider} +} + +// Get the inner TickDataProvider, return nil if TickDataProviderWrapper is nil +func (p *TickDataProviderWrapper) Get() TickDataProvider { + if p == nil { + return nil + } + return p.TickDataProvider +} + +// EncodeMsg implements msgp.Encodable +func (p *TickDataProviderWrapper) EncodeMsg(en *msgp.Writer) (err error) { + typ := reflect.ValueOf(p.TickDataProvider).Elem().Type() + err = en.WriteString(typ.String()) + if err != nil { + return + } + + if _, ok := tickDataProviderImplMap[typ.String()]; !ok { + err = fmt.Errorf("unregistered type %s", typ.String()) + return + } + + err = p.TickDataProvider.(msgp.Encodable).EncodeMsg(en) + if err != nil { + err = msgp.WrapError(err, "TickDataProvider") + return + } + return +} + +// DecodeMsg implements msgp.Decodable +func (p *TickDataProviderWrapper) DecodeMsg(dc *msgp.Reader) (err error) { + var typStr string + typStr, err = dc.ReadString() + if err != nil { + return + } + + typ, ok := tickDataProviderImplMap[typStr] + if !ok { + err = fmt.Errorf("unregistered type %s", typStr) + return + } + + providerVal := reflect.New(typ) + err = providerVal.Interface().(msgp.Decodable).DecodeMsg(dc) + if err != nil { + err = msgp.WrapError(err, "TickDataProviderMsgp") + return + } + p.TickDataProvider = providerVal.Interface().(TickDataProvider) + return +} + +// MarshalMsg implements msgp.Marshaler +func (p *TickDataProviderWrapper) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, p.Msgsize()) + typ := reflect.ValueOf(p.TickDataProvider).Elem().Type() + o = msgp.AppendString(o, typ.String()) + + if _, ok := tickDataProviderImplMap[typ.String()]; !ok { + err = fmt.Errorf("unregistered type %s", typ.String()) + return + } + + o, err = p.TickDataProvider.(msgp.Marshaler).MarshalMsg(o) + if err != nil { + err = msgp.WrapError(err, "TickDataProvider") + return + } + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (p *TickDataProviderWrapper) UnmarshalMsg(bts []byte) (o []byte, err error) { + var typStr string + typStr, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + return + } + + typ, ok := tickDataProviderImplMap[typStr] + if !ok { + err = fmt.Errorf("unregistered type %s", typStr) + return + } + + providerVal := reflect.New(typ) + bts, err = providerVal.Interface().(msgp.Unmarshaler).UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "TickDataProvider") + return + } + p.TickDataProvider = providerVal.Interface().(TickDataProvider) + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (p *TickDataProviderWrapper) Msgsize() int { + typ := reflect.ValueOf(p.TickDataProvider).Elem().Type() + return msgp.StringPrefixSize + len(typ.String()) + p.TickDataProvider.(msgp.Sizer).Msgsize() +} diff --git a/entities/ticklistdataprovider.go b/entities/ticklistdataprovider.go index ee6d927..e236ca3 100644 --- a/entities/ticklistdataprovider.go +++ b/entities/ticklistdataprovider.go @@ -3,6 +3,10 @@ package entities +func init() { + RegisterTickDataProviderImpl(&TickListDataProvider{}) +} + // A data provider for ticks that is backed by an in-memory array of ticks. type TickListDataProvider struct { ticks []Tick diff --git a/go.mod b/go.mod index ea659cb..c47a7ee 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/KyberNetwork/uniswapv3-sdk go 1.18 replace ( - github.com/daoleno/uniswap-sdk-core v0.1.5 => github.com/KyberNetwork/uniswap-sdk-core v0.1.8 + github.com/daoleno/uniswap-sdk-core v0.1.5 => github.com/KyberNetwork/uniswap-sdk-core v0.1.9 github.com/daoleno/uniswapv3-sdk v0.4.0 => github.com/KyberNetwork/uniswapv3-sdk v0.4.0 ) @@ -13,6 +13,7 @@ require ( github.com/ethereum/go-ethereum v1.10.21 github.com/shopspring/decimal v1.3.1 github.com/stretchr/testify v1.8.0 + github.com/tinylib/msgp v1.1.9 ) require ( @@ -29,12 +30,12 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rjeczalik/notify v0.9.2 // indirect github.com/shirou/gopsutil v3.21.11+incompatible // indirect - github.com/tinylib/msgp v1.1.9 // indirect github.com/tklauser/go-sysconf v0.3.10 // indirect github.com/tklauser/numcpus v0.5.0 // indirect github.com/yusufpapurcu/wmi v1.2.2 // indirect golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect golang.org/x/mod v0.13.0 // indirect + golang.org/x/sync v0.4.0 // indirect golang.org/x/sys v0.13.0 // indirect golang.org/x/tools v0.14.0 // indirect gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce // indirect diff --git a/go.sum b/go.sum index f701893..4309209 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/KyberNetwork/uniswap-sdk-core v0.1.8 h1:Fc5YkvqvVKAIYBeTDXe6N7GRJ/0lR1FVZ3GONYbxNL8= -github.com/KyberNetwork/uniswap-sdk-core v0.1.8/go.mod h1:ih9PJ/qgEbc1VbWOWWPeiH02a+inIBm2+XNtZmONcyY= +github.com/KyberNetwork/uniswap-sdk-core v0.1.9 h1:04GqfoYM1vTk55HGlIZfVJ3Wo1W8lsAelW6L+qkfv6c= +github.com/KyberNetwork/uniswap-sdk-core v0.1.9/go.mod h1:ih9PJ/qgEbc1VbWOWWPeiH02a+inIBm2+XNtZmONcyY= github.com/KyberNetwork/uniswapv3-sdk v0.4.0 h1:hbTeJBFgFqYqYTduGuEnb4JIvCtcmuvBTFuRARJIa1Y= github.com/KyberNetwork/uniswapv3-sdk v0.4.0/go.mod h1:K+cqy6zkitxxfShghmuoVwjGJWO16FTXAV+dvddXtgw= github.com/VictoriaMetrics/fastcache v1.6.0 h1:C/3Oi3EiBCqufydp1neRZkqcwmEiuRT9c3fqvvgKm5o= @@ -90,7 +90,8 @@ golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020 golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY= golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= +golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= +golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20180926160741-c2ed4eda69e7/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=