diff --git a/cmd/eval-dev-quality/cmd/evaluate.go b/cmd/eval-dev-quality/cmd/evaluate.go index 1d3f5460..07b01a41 100644 --- a/cmd/eval-dev-quality/cmd/evaluate.go +++ b/cmd/eval-dev-quality/cmd/evaluate.go @@ -156,7 +156,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate. } if command.Parallel == 0 { - command.logger.Panic("the 'parallel' parameter has to be greater then 0") + command.logger.Panic("the 'parallel' parameter has to be greater then zero") } if command.RuntimeImage == "" { @@ -446,13 +446,18 @@ func (command *Evaluate) evaluateLocal(evaluationContext *evaluate.Context) (err // evaluateDocker executes the evaluation for each model inside a docker container. func (command *Evaluate) evaluateDocker(ctx *evaluate.Context) (err error) { - // Filter all the args to pass them onto the container. - args := util.FilterArgs(os.Args[2:], []string{ - "--model", - "--parallel", - "--result-path", - "--runtime", - }) + availableFlags := util.Flags(command) + ignoredFlags := []string{ + "model", + "parallel", + "result-path", + "runtime", + } + + // Filter all the args to only contain flags which can be used. + args := util.FilterArgsKeep(os.Args[2:], availableFlags) + // Filter the args to remove all flags unsuited for running the container. + args = util.FilterArgsRemove(args, ignoredFlags) parallel := util.NewParallel(command.Parallel) @@ -521,13 +526,18 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) { return pkgerrors.Wrap(err, "could not create kubernetes job template") } - // Filter all the args to pass them onto the container. - args := util.FilterArgs(os.Args[2:], []string{ - "--model", - "--parallel", - "--result-path", - "--runtime", - }) + availableFlags := util.Flags(command) + ignoredFlags := []string{ + "model", + "parallel", + "result-path", + "runtime", + } + + // Filter all the args to only contain flags which can be used. 
+ args := util.FilterArgsKeep(os.Args[2:], availableFlags) + // Filter the args to remove all flags unsuited for running the container. + args = util.FilterArgsRemove(args, ignoredFlags) parallel := util.NewParallel(command.Parallel) diff --git a/cmd/eval-dev-quality/cmd/evaluate_test.go b/cmd/eval-dev-quality/cmd/evaluate_test.go index f941ecdd..51a7fb2b 100644 --- a/cmd/eval-dev-quality/cmd/evaluate_test.go +++ b/cmd/eval-dev-quality/cmd/evaluate_test.go @@ -836,6 +836,7 @@ func TestEvaluateInitialize(t *testing.T) { ValidateCommand func(t *testing.T, command *Evaluate) ValidateContext func(t *testing.T, context *evaluate.Context) ValidateResults func(t *testing.T, resultsPath string) + ValidatePanic string } validate := func(t *testing.T, tc *testCase) { @@ -854,6 +855,17 @@ func TestEvaluateInitialize(t *testing.T) { tc.Command.ResultPath = strings.ReplaceAll(tc.Command.ResultPath, "$TEMP_PATH", temporaryDirectory) var actualEvaluationContext *evaluate.Context + + if tc.ValidatePanic != "" { + assert.PanicsWithValue(t, tc.ValidatePanic, func() { + c, cleanup := tc.Command.Initialize([]string{}) + cleanup() + actualEvaluationContext = c + }) + + return + } + assert.NotPanics(t, func() { c, cleanup := tc.Command.Initialize([]string{}) cleanup() @@ -1013,4 +1025,59 @@ func TestEvaluateInitialize(t *testing.T) { } }, }) + validate(t, &testCase{ + Name: "Local runtime does not allow parallel parameter", + + Command: makeValidCommand(func(command *Evaluate) { + command.Runtime = "local" + command.Parallel = 2 + }), + + ValidatePanic: "the 'parallel' parameter can't be used with local execution", + }) + validate(t, &testCase{ + Name: "Attempts parameter hast to be greater then zero", + + Command: makeValidCommand(func(command *Evaluate) { + command.QueryAttempts = 0 + }), + + ValidatePanic: "number of configured query attempts must be greater than zero", + }) + validate(t, &testCase{ + Name: "Execution timeout parameter hast to be greater then zero", + + Command: 
makeValidCommand(func(command *Evaluate) { + command.ExecutionTimeout = 0 + }), + + ValidatePanic: "execution timeout for compilation and tests must be greater than zero", + }) + validate(t, &testCase{ + Name: "Runs parameter hast to be greater then zero", + + Command: makeValidCommand(func(command *Evaluate) { + command.Runs = 0 + }), + + ValidatePanic: "number of configured runs must be greater than zero", + }) + + t.Run("Docker", func(t *testing.T) { + if osutil.IsDarwin() { // The macOS runners on GitHub do not have "docker" in their path, which would break this test. + t.Skip("Unsupported OS") + } + + validate(t, &testCase{ + Name: "Parallel parameter hast to be greater then zero", + + Command: makeValidCommand(func(command *Evaluate) { + command.Runtime = "docker" + command.Parallel = 0 + }), + + ValidatePanic: "the 'parallel' parameter has to be greater then zero", + }) + }) + } diff --git a/util/exec.go b/util/exec.go index 30ed9bba..9b8ffee8 100644 --- a/util/exec.go +++ b/util/exec.go @@ -5,6 +5,7 @@ import ( "context" "io" "os/exec" + "reflect" "strings" "sync" "time" @@ -62,11 +63,37 @@ func CommandWithResult(ctx context.Context, logger *log.Logger, command *Command return writer.String(), nil } -// FilterArgs parses args and removes the ignored ones. -func FilterArgs(args []string, ignored []string) (filtered []string) { +// Flags returns the values of the `long` struct tags found on the fields of the command, or nil if the command is neither a struct nor a pointer to a struct. +func Flags(cmd any) (args []string) { + typ := reflect.TypeOf(cmd) + + // Dereference a pointer to reach the underlying struct type. + if typ.Kind() == reflect.Pointer { + typ = typ.Elem() + } + + if typ.Kind() != reflect.Struct { + return nil + } + + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + arg, ok := field.Tag.Lookup("long") + if !ok { + continue + } + + args = append(args, arg) + } + + return args +} + +// FilterArgs filters the arguments by either ignoring/allowing them in the result. 
+func FilterArgs(args []string, filter []string, ignore bool) (filtered []string) { filterMap := map[string]bool{} - for _, v := range ignored { - filterMap[v] = true + for _, v := range filter { + filterMap["--"+v] = true } // Resolve args with equals sign. @@ -79,13 +106,14 @@ func FilterArgs(args []string, ignored []string) (filtered []string) { } } - skip := false + skip := true for _, v := range resolvedArgs { - if skip && strings.HasPrefix(v, "--") { - skip = false - } - if filterMap[v] { - skip = true + if strings.HasPrefix(v, "--") { + if ignore { + skip = filterMap[v] + } else { + skip = !filterMap[v] + } } if skip { @@ -98,6 +126,16 @@ func FilterArgs(args []string, ignored []string) (filtered []string) { return filtered } +// FilterArgsKeep filters the given argument list and returns only the flags listed in "filter" together with their values. +func FilterArgsKeep(args []string, filter []string) (filtered []string) { + return FilterArgs(args, filter, false) +} + +// FilterArgsRemove filters the given argument list and returns it without the flags listed in "filter" and their values. +func FilterArgsRemove(args []string, filter []string) (filtered []string) { + return FilterArgs(args, filter, true) +} + // Parallel holds a buffered channel for limiting parallel executions. 
type Parallel struct { ch chan struct{} diff --git a/util/exec_test.go b/util/exec_test.go index 50a6083f..3fc0c0a8 100644 --- a/util/exec_test.go +++ b/util/exec_test.go @@ -38,15 +38,16 @@ func TestFilterArgs(t *testing.T) { type testCase struct { Name string - Args []string - Ignored []string + Args []string + Filter []string + Ignore bool ExpectedFiltered []string } validate := func(t *testing.T, tc *testCase) { t.Run(tc.Name, func(t *testing.T) { - actualFiltered := FilterArgs(tc.Args, tc.Ignored) + actualFiltered := FilterArgs(tc.Args, tc.Filter, tc.Ignore) assert.Equal(t, tc.ExpectedFiltered, actualFiltered) }) @@ -61,9 +62,10 @@ func TestFilterArgs(t *testing.T) { "--runs", "5", }, - Ignored: []string{ - "--runtime", + Filter: []string{ + "runtime", }, + Ignore: true, ExpectedFiltered: []string{ "--runs", @@ -80,9 +82,10 @@ func TestFilterArgs(t *testing.T) { "--foo", "bar", }, - Ignored: []string{ - "--runtime", + Filter: []string{ + "runtime", }, + Ignore: true, ExpectedFiltered: []string{ "--runs", @@ -91,6 +94,103 @@ func TestFilterArgs(t *testing.T) { "bar", }, }) + + validate(t, &testCase{ + Name: "Filter arguments with an allow list", + + Args: []string{ + "--runtime=abc", + "--runs=5", + "--foo", + "bar", + }, + + Filter: []string{ + "runtime", + }, + + ExpectedFiltered: []string{ + "--runtime", + "abc", + }, + }) +} + +func TestFilterArgsKeep(t *testing.T) { + type testCase struct { + Name string + + Args []string + Filter []string + + ExpectedFiltered []string + } + + validate := func(t *testing.T, tc *testCase) { + t.Run(tc.Name, func(t *testing.T) { + actualFiltered := FilterArgsKeep(tc.Args, tc.Filter) + + assert.Equal(t, tc.ExpectedFiltered, actualFiltered) + }) + } + + validate(t, &testCase{ + Name: "Keep arguments", + + Args: []string{ + "--runtime=abc", + "--runs=5", + "--foo", + "bar", + }, + + Filter: []string{ + "runtime", + }, + + ExpectedFiltered: []string{ + "--runtime", + "abc", + }, + }) +} + +func TestFilterArgsRemove(t *testing.T) 
{ + type testCase struct { + Name string + + Args []string + Filter []string + + ExpectedFiltered []string + } + + validate := func(t *testing.T, tc *testCase) { + t.Run(tc.Name, func(t *testing.T) { + actualFiltered := FilterArgsRemove(tc.Args, tc.Filter) + + assert.Equal(t, tc.ExpectedFiltered, actualFiltered) + }) + } + + validate(t, &testCase{ + Name: "Remove arguments", + + Args: []string{ + "--runtime", + "abc", + "--runs", + "5", + }, + Filter: []string{ + "runtime", + }, + + ExpectedFiltered: []string{ + "--runs", + "5", + }, + }) } func TestParallelExecute(t *testing.T) { @@ -145,3 +245,50 @@ func TestParallelExecute(t *testing.T) { }) } + +func TestFlags(t *testing.T) { + type testCase struct { + Name string + + Cmd any + + ExpectedArgs []string + } + + validate := func(t *testing.T, tc *testCase) { + t.Run(tc.Name, func(t *testing.T) { + actualArgs := Flags(tc.Cmd) + + assert.Equal(t, tc.ExpectedArgs, actualArgs) + }) + } + + validate(t, &testCase{ + Name: "Struct Command", + + Cmd: struct { + PropA string `long:"PropA"` + PropB string `long:"PropB"` + PropC string `short:"hey"` + }{}, + + ExpectedArgs: []string{ + "PropA", + "PropB", + }, + }) + + validate(t, &testCase{ + Name: "Pointer Struct Command", + + Cmd: &struct { + PropA string `long:"PropA"` + PropB string `long:"PropB"` + }{}, + + ExpectedArgs: []string{ + "PropA", + "PropB", + }, + }) +}