Skip to content

Commit

Permalink
Merge pull request #236 from symflower/232-handle-timeout
Browse files Browse the repository at this point in the history
Do not run "symflower fix" if the original response failed with a timeout, so that the model assessment and the fix assessment stay consistent.
  • Loading branch information
bauersimon authored Jul 4, 2024
2 parents 4365556 + 083c69b commit b2a9140
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
10 changes: 10 additions & 0 deletions evaluate/task/task-write-test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package task

import (
"context"
"errors"
"fmt"
"path/filepath"

Expand Down Expand Up @@ -82,6 +84,14 @@ func (t *TaskWriteTests) Run(ctx evaltask.Context) (repositoryAssessment map[eva
if err != nil {
problems = append(problems, pkgerrors.WithMessage(err, filePath))

// If there is an execution timeout, do not run "symflower fix" because the code itself is correct.
if errors.Is(err, context.DeadlineExceeded) {
modelAssessment.Add(modelAssessmentForFile)
withSymflowerAssessment.Add(withSymflowerAssessmentForFile)

continue
}

// Run "symflower fix" if the model response fails to execute.
if ctx.Language.ID() == "golang" { // Currently we only support Go for "symflower fix".
log.Print("model response alone failed execution, attempting to fix with \"symflower fix \"")
Expand Down
34 changes: 29 additions & 5 deletions evaluate/task/task-write-test_test.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
package task

import (
"context"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/symflower/eval-dev-quality/evaluate/metrics"
metricstesting "github.com/symflower/eval-dev-quality/evaluate/metrics/testing"
tasktesting "github.com/symflower/eval-dev-quality/evaluate/task/testing"
"github.com/symflower/eval-dev-quality/language"
"github.com/symflower/eval-dev-quality/language/golang"
languagetesting "github.com/symflower/eval-dev-quality/language/testing"
"github.com/symflower/eval-dev-quality/log"
modeltesting "github.com/symflower/eval-dev-quality/model/testing"
"github.com/symflower/eval-dev-quality/task"
evaltask "github.com/symflower/eval-dev-quality/task"
"github.com/zimmski/osutil"
"github.com/zimmski/osutil/bytesutil"
Expand Down Expand Up @@ -82,7 +87,7 @@ func TestTaskWriteTestsRun(t *testing.T) {

t.Run("Symflower Fix", func(t *testing.T) {
t.Run("Go", func(t *testing.T) {
validateGo := func(t *testing.T, testName string, testFileContent string, expectedAssessments map[evaltask.Identifier]metrics.Assessments, expectedProblems []string, assertTestsPass bool) {
validateGo := func(t *testing.T, testName string, language language.Language, testFileContent string, expectedAssessments map[evaltask.Identifier]metrics.Assessments, expectedProblems []string, assertTestsPass bool) {
temporaryDirectoryPath := t.TempDir()
repositoryPath := filepath.Join(temporaryDirectoryPath, "golang", "plain")
require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "golang", "plain"), repositoryPath))
Expand All @@ -94,7 +99,7 @@ func TestTaskWriteTestsRun(t *testing.T) {
Name: testName,

Model: modelMock,
Language: &golang.Language{},
Language: language,
TestDataPath: temporaryDirectoryPath,
RepositoryPath: filepath.Join("golang", "plain"),

Expand Down Expand Up @@ -123,7 +128,7 @@ func TestTaskWriteTestsRun(t *testing.T) {
metrics.AssessmentKeyCoverage: 10,
},
}
validateGo(t, "Model generated correct test", bytesutil.StringTrimIndentations(`
validateGo(t, "Model generated correct test", &golang.Language{}, bytesutil.StringTrimIndentations(`
package plain
import "testing"
Expand All @@ -147,7 +152,7 @@ func TestTaskWriteTestsRun(t *testing.T) {
expectedProblems := []string{
"imported and not used",
}
validateGo(t, "Model generated test with unused import", bytesutil.StringTrimIndentations(`
validateGo(t, "Model generated test with unused import", &golang.Language{}, bytesutil.StringTrimIndentations(`
package plain
import (
Expand All @@ -173,12 +178,31 @@ func TestTaskWriteTestsRun(t *testing.T) {
"expected declaration, found this",
"unable to format source code",
}
validateGo(t, "Model generated test that is unfixable", bytesutil.StringTrimIndentations(`
validateGo(t, "Model generated test that is unfixable", &golang.Language{}, bytesutil.StringTrimIndentations(`
package plain
this is not valid go code
`), expectedAssessments, expectedProblems, false)
}
{
expectedAssessments := map[task.Identifier]metrics.Assessments{
IdentifierWriteTests: metrics.Assessments{
metrics.AssessmentKeyResponseNoError: 1,
},
IdentifierWriteTestsSymflowerFix: metrics.Assessments{
metrics.AssessmentKeyResponseNoError: 1,
},
}
expectedProblems := []string{
"context deadline exceeded",
}

languageMock := languagetesting.NewMockLanguageNamed(t, "golang")
languageMock.On("Files", mock.Anything, mock.Anything).Return([]string{filepath.Join("golang", "plain")}, nil).Once()
languageMock.On("Execute", mock.Anything, mock.Anything).Return(uint64(0), nil, context.DeadlineExceeded).Once()

validateGo(t, "Execution timeout", languageMock, "", expectedAssessments, expectedProblems, false)
}
})
})
}

0 comments on commit b2a9140

Please sign in to comment.