Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rpc): TransactionTrace matches rpc v6 specs #2581

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
318 changes: 318 additions & 0 deletions rpc/v6/adapters.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
package rpcv6

import (
"errors"

"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/starknet"
"github.com/NethermindEth/juno/utils"
"github.com/NethermindEth/juno/vm"
)

/****************************************************
VM Adapters
*****************************************************/

func AdaptVMTransactionTrace(trace *vm.TransactionTrace) TransactionTrace {
var validateInvocation *FunctionInvocation
if trace.ValidateInvocation != nil {
validateInvocation = utils.Ptr(AdaptVMFunctionInvocation(trace.ValidateInvocation))
}

var executeInvocation *ExecuteInvocation
if trace.ExecuteInvocation != nil {
executeInvocation = utils.Ptr(AdaptVMExecuteInvocation(trace.ExecuteInvocation))
}

var feeTransferInvocation *FunctionInvocation
if trace.FeeTransferInvocation != nil {
feeTransferInvocation = utils.Ptr(AdaptVMFunctionInvocation(trace.FeeTransferInvocation))
}

var constructorInvocation *FunctionInvocation
if trace.ConstructorInvocation != nil {
constructorInvocation = utils.Ptr(AdaptVMFunctionInvocation(trace.ConstructorInvocation))
}

var functionInvocation *FunctionInvocation
if trace.FunctionInvocation != nil {
functionInvocation = utils.Ptr(AdaptVMFunctionInvocation(trace.FunctionInvocation))
}

var stateDiff *StateDiff
if trace.StateDiff != nil {
stateDiff = utils.Ptr(AdaptVMStateDiff(trace.StateDiff))
}

return TransactionTrace{
Type: TransactionType(trace.Type),
ValidateInvocation: validateInvocation,
ExecuteInvocation: executeInvocation,
FeeTransferInvocation: feeTransferInvocation,
ConstructorInvocation: constructorInvocation,
FunctionInvocation: functionInvocation,
StateDiff: stateDiff,
}
}

func AdaptVMExecuteInvocation(vmFnInvocation *vm.ExecuteInvocation) ExecuteInvocation {
var functionInvocation *FunctionInvocation
if vmFnInvocation.FunctionInvocation != nil {
functionInvocation = utils.Ptr(AdaptVMFunctionInvocation(vmFnInvocation.FunctionInvocation))
}

return ExecuteInvocation{
RevertReason: vmFnInvocation.RevertReason,
FunctionInvocation: functionInvocation,
}
}

func AdaptVMFunctionInvocation(vmFnInvocation *vm.FunctionInvocation) FunctionInvocation {
// Adapt inner calls
adaptedCalls := make([]FunctionInvocation, len(vmFnInvocation.Calls))
for index := range vmFnInvocation.Calls {
adaptedCalls[index] = AdaptVMFunctionInvocation(&vmFnInvocation.Calls[index])
}

// Adapt execution resources
var adaptedResources *ComputationResources
if r := vmFnInvocation.ExecutionResources; r != nil {
adaptedResources = &ComputationResources{
Steps: r.Steps,
MemoryHoles: r.MemoryHoles,
Pedersen: r.Pedersen,
RangeCheck: r.RangeCheck,
Bitwise: r.Bitwise,
Ecdsa: r.Ecdsa,
EcOp: r.EcOp,
Keccak: r.Keccak,
Poseidon: r.Poseidon,
SegmentArena: r.SegmentArena,
}
}

// Adapt events
adaptedEvents := make([]OrderedEvent, len(vmFnInvocation.Events))
for index := range vmFnInvocation.Events {
vmEvent := &vmFnInvocation.Events[index]

adaptedEvents[index] = OrderedEvent{
Order: vmEvent.Order,
Keys: vmEvent.Keys,
Data: vmEvent.Data,
}
}

return FunctionInvocation{
ContractAddress: vmFnInvocation.ContractAddress,
EntryPointSelector: vmFnInvocation.EntryPointSelector,
Calldata: vmFnInvocation.Calldata,
CallerAddress: vmFnInvocation.CallerAddress,
ClassHash: vmFnInvocation.ClassHash,
EntryPointType: vmFnInvocation.EntryPointType,
CallType: vmFnInvocation.CallType,
Result: vmFnInvocation.Result,
Calls: adaptedCalls,
Events: adaptedEvents,
Messages: vmFnInvocation.Messages,
ExecutionResources: adaptedResources,
}
}

func AdaptVMStateDiff(vmStateDiff *vm.StateDiff) StateDiff {
// Adapt storage diffs
adaptedStorageDiffs := make([]StorageDiff, len(vmStateDiff.StorageDiffs))
for index := range vmStateDiff.StorageDiffs {
vmStorageDiff := &vmStateDiff.StorageDiffs[index]

// Adapt storage entries
adaptedEntries := make([]Entry, len(vmStorageDiff.StorageEntries))
for entryIndex := range vmStorageDiff.StorageEntries {
vmEntry := &vmStorageDiff.StorageEntries[entryIndex]

adaptedEntries[entryIndex] = Entry{
Key: vmEntry.Key,
Value: vmEntry.Value,
}
}

adaptedStorageDiffs[index] = StorageDiff{
Address: vmStorageDiff.Address,
StorageEntries: adaptedEntries,
}
}

// Adapt nonces
adaptedNonces := make([]Nonce, len(vmStateDiff.Nonces))
for index := range vmStateDiff.Nonces {
vmNonce := &vmStateDiff.Nonces[index]

adaptedNonces[index] = Nonce{
ContractAddress: vmNonce.ContractAddress,
Nonce: vmNonce.Nonce,
}
}

// Adapt deployed contracts
adaptedDeployedContracts := make([]DeployedContract, len(vmStateDiff.DeployedContracts))
for index := range vmStateDiff.DeployedContracts {
vmDeployedContract := &vmStateDiff.DeployedContracts[index]

adaptedDeployedContracts[index] = DeployedContract{
Address: vmDeployedContract.Address,
ClassHash: vmDeployedContract.ClassHash,
}
}

// Adapt declared classes
adaptedDeclaredClasses := make([]DeclaredClass, len(vmStateDiff.DeclaredClasses))
for index := range vmStateDiff.DeclaredClasses {
vmDeclaredClass := &vmStateDiff.DeclaredClasses[index]

adaptedDeclaredClasses[index] = DeclaredClass{
ClassHash: vmDeclaredClass.ClassHash,
CompiledClassHash: vmDeclaredClass.CompiledClassHash,
}
}

// Adapt replaced classes
adaptedReplacedClasses := make([]ReplacedClass, len(vmStateDiff.ReplacedClasses))
for index := range vmStateDiff.ReplacedClasses {
vmReplacedClass := &vmStateDiff.ReplacedClasses[index]

adaptedReplacedClasses[index] = ReplacedClass{
ContractAddress: vmReplacedClass.ContractAddress,
ClassHash: vmReplacedClass.ClassHash,
}
}

return StateDiff{
StorageDiffs: adaptedStorageDiffs,
Nonces: adaptedNonces,
DeployedContracts: adaptedDeployedContracts,
DeprecatedDeclaredClasses: vmStateDiff.DeprecatedDeclaredClasses,
DeclaredClasses: adaptedDeclaredClasses,
ReplacedClasses: adaptedReplacedClasses,
}
}

/****************************************************
Feeder Adapters
*****************************************************/

func AdaptFeederBlockTrace(block *BlockWithTxs, blockTrace *starknet.BlockTrace) ([]TracedBlockTransaction, error) {
if blockTrace == nil {
return nil, nil
}

if len(block.Transactions) != len(blockTrace.Traces) {
return nil, errors.New("mismatched number of txs and traces")
}

// Adapt every feeder block trace to rpc v6 trace
adaptedTraces := make([]TracedBlockTransaction, len(blockTrace.Traces))
for index := range blockTrace.Traces {
feederTrace := &blockTrace.Traces[index]

trace := TransactionTrace{
Type: block.Transactions[index].Type,
}

if fee := feederTrace.FeeTransferInvocation; fee != nil {
trace.FeeTransferInvocation = utils.Ptr(AdaptFeederFunctionInvocation(fee))
}

if val := feederTrace.ValidateInvocation; val != nil {
trace.ValidateInvocation = utils.Ptr(AdaptFeederFunctionInvocation(val))
}

if fct := feederTrace.FunctionInvocation; fct != nil {
fnInvocation := utils.Ptr(AdaptFeederFunctionInvocation(fct))

switch block.Transactions[index].Type {
case TxnDeploy, TxnDeployAccount:
trace.ConstructorInvocation = fnInvocation
case TxnInvoke:
trace.ExecuteInvocation = new(ExecuteInvocation)
if feederTrace.RevertError != "" {
trace.ExecuteInvocation.RevertReason = feederTrace.RevertError
} else {
trace.ExecuteInvocation.FunctionInvocation = fnInvocation
}
case TxnL1Handler:
trace.FunctionInvocation = fnInvocation
}
}

adaptedTraces[index] = TracedBlockTransaction{
TransactionHash: &feederTrace.TransactionHash,
TraceRoot: &trace,
}
}

return adaptedTraces, nil
}

func AdaptFeederFunctionInvocation(snFnInvocation *starknet.FunctionInvocation) FunctionInvocation {
// Adapt internal calls
adaptedCalls := make([]FunctionInvocation, len(snFnInvocation.InternalCalls))
for index := range snFnInvocation.InternalCalls {
adaptedCalls[index] = AdaptFeederFunctionInvocation(&snFnInvocation.InternalCalls[index])
}

// Adapt events
adaptedEvents := make([]OrderedEvent, len(snFnInvocation.Events))
for index := range snFnInvocation.Events {
snEvent := &snFnInvocation.Events[index]

adaptedEvents[index] = OrderedEvent{
Order: snEvent.Order,
Keys: utils.Map(snEvent.Keys, utils.Ptr[felt.Felt]),
Data: utils.Map(snEvent.Data, utils.Ptr[felt.Felt]),
}
}

// Adapt messages
adaptedMessages := make([]vm.OrderedL2toL1Message, len(snFnInvocation.Messages))
for index := range snFnInvocation.Messages {
snMessage := &snFnInvocation.Messages[index]

adaptedMessages[index] = vm.OrderedL2toL1Message{
Order: snMessage.Order,
To: snMessage.ToAddr,
Payload: utils.Map(snMessage.Payload, utils.Ptr[felt.Felt]),
}
}

return FunctionInvocation{
ContractAddress: snFnInvocation.ContractAddress,
EntryPointSelector: snFnInvocation.Selector,
Calldata: snFnInvocation.Calldata,
CallerAddress: snFnInvocation.CallerAddress,
ClassHash: snFnInvocation.ClassHash,
EntryPointType: snFnInvocation.EntryPointType,
CallType: snFnInvocation.CallType,
Result: snFnInvocation.Result,
Calls: adaptedCalls,
Events: adaptedEvents,
Messages: adaptedMessages,
ExecutionResources: utils.Ptr(adaptFeederExecutionResources(&snFnInvocation.ExecutionResources)),
}
}

func adaptFeederExecutionResources(resources *starknet.ExecutionResources) ComputationResources {
builtins := &resources.BuiltinInstanceCounter

return ComputationResources{
Steps: resources.Steps,
MemoryHoles: resources.MemoryHoles,
Pedersen: builtins.Pedersen,
RangeCheck: builtins.RangeCheck,
Bitwise: builtins.Bitwise,
Ecdsa: builtins.Ecsda,
EcOp: builtins.EcOp,
Keccak: builtins.Keccak,
Poseidon: builtins.Poseidon,
SegmentArena: builtins.SegmentArena,
}
}
2 changes: 1 addition & 1 deletion rpc/v6/estimate_fee.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type FeeEstimate struct {
func (h *Handler) EstimateFee(broadcastedTxns []BroadcastedTransaction,
simulationFlags []SimulationFlag, id BlockID,
) ([]FeeEstimate, *jsonrpc.Error) {
result, err := h.simulateTransactions(id, broadcastedTxns, append(simulationFlags, SkipFeeChargeFlag), true, true)
result, err := h.simulateTransactions(id, broadcastedTxns, append(simulationFlags, SkipFeeChargeFlag), true)
if err != nil {
return nil, err
}
Expand Down
3 changes: 1 addition & 2 deletions rpc/v6/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ import (
)

type traceCacheKey struct {
blockHash felt.Felt
v0_6Response bool
blockHash felt.Felt
}

type Handler struct {
Expand Down
2 changes: 1 addition & 1 deletion rpc/v6/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func adaptExecutionResources(resources *core.ExecutionResources, v0_6Response bo
Pedersen: resources.BuiltinInstanceCounter.Pedersen,
RangeCheck: resources.BuiltinInstanceCounter.RangeCheck,
Bitwise: resources.BuiltinInstanceCounter.Bitwise,
Ecsda: resources.BuiltinInstanceCounter.Ecsda,
Ecdsa: resources.BuiltinInstanceCounter.Ecsda,
EcOp: resources.BuiltinInstanceCounter.EcOp,
Keccak: resources.BuiltinInstanceCounter.Keccak,
Poseidon: resources.BuiltinInstanceCounter.Poseidon,
Expand Down
Loading
Loading