From f668c9fb169f6549676c8679f3631d6756d30fe9 Mon Sep 17 00:00:00 2001 From: Rui Azevedo Date: Tue, 2 Jul 2024 15:40:18 +0100 Subject: [PATCH] Do not run "symflower fix" if the original response failed with a timeout, so the model and the fix assessments are consistent Closes #232 --- evaluate/task/task-write-test.go | 10 ++++++++ evaluate/task/task-write-test_test.go | 33 +++++++++++++++++++++++---- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/evaluate/task/task-write-test.go b/evaluate/task/task-write-test.go index 9b7cb776f..c171bc175 100644 --- a/evaluate/task/task-write-test.go +++ b/evaluate/task/task-write-test.go @@ -1,6 +1,8 @@ package task import ( + "context" + "errors" "path/filepath" pkgerrors "github.com/pkg/errors" @@ -97,6 +99,14 @@ func (t *TaskWriteTests) Run(repository evaltask.Repository) (repositoryAssessme if err != nil { problems = append(problems, pkgerrors.WithMessage(err, filePath)) + // If there is an execution timeout do not run "symflower fix" because the code itself is correct. + if errors.Is(err, context.DeadlineExceeded) { + modelAssessment.Add(modelAssessmentForFile) + withSymflowerAssessment.Add(withSymflowerAssessmentForFile) + + continue + } + // Run "symflower fix" if the model response fails to execute. if t.Language.ID() == "golang" { // Currently we only support Go for "symflower fix". log.Print("model response alone failed execution, attempting to fix with \"symflower fix \"") diff --git a/evaluate/task/task-write-test_test.go b/evaluate/task/task-write-test_test.go index 03517352e..cfe2b8263 100644 --- a/evaluate/task/task-write-test_test.go +++ b/evaluate/task/task-write-test_test.go @@ -1,16 +1,20 @@ package task import ( + "context" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/symflower/eval-dev-quality/evaluate/metrics" metricstesting "github.com/symflower/eval-dev-quality/evaluate/metrics/testing" tasktesting "github.com/symflower/eval-dev-quality/evaluate/task/testing" + "github.com/symflower/eval-dev-quality/language" "github.com/symflower/eval-dev-quality/language/golang" + languagetesting "github.com/symflower/eval-dev-quality/language/testing" "github.com/symflower/eval-dev-quality/log" modeltesting "github.com/symflower/eval-dev-quality/model/testing" "github.com/symflower/eval-dev-quality/task" @@ -87,7 +91,7 @@ func TestTaskWriteTestsRun(t *testing.T) { t.Run("Symflower Fix", func(t *testing.T) { t.Run("Go", func(t *testing.T) { - validateGo := func(t *testing.T, testName string, testFileContent string, expectedAssessments map[task.Identifier]metrics.Assessments, expectedProblems []string, assertTestsPass bool) { + validateGo := func(t *testing.T, testName string, language language.Language, testFileContent string, expectedAssessments map[task.Identifier]metrics.Assessments, expectedProblems []string, assertTestsPass bool) { temporaryDirectoryPath := t.TempDir() repositoryPath := filepath.Join(temporaryDirectoryPath, "golang", "plain") require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "golang", "plain"), repositoryPath)) @@ -99,7 +103,7 @@ func TestTaskWriteTestsRun(t *testing.T) { Name: testName, Model: modelMock, - Language: &golang.Language{}, + Language: language, TestDataPath: temporaryDirectoryPath, RepositoryPath: filepath.Join("golang", "plain"), @@ -128,7 +132,7 @@ func TestTaskWriteTestsRun(t *testing.T) { metrics.AssessmentKeyCoverage: 10, }, } - validateGo(t, "Model generated correct test", bytesutil.StringTrimIndentations(` + validateGo(t, "Model generated correct test", &golang.Language{}, bytesutil.StringTrimIndentations(` package plain import "testing" @@ -152,7 +156,7 @@ func TestTaskWriteTestsRun(t *testing.T) { expectedProblems := []string{ "imported and not used", } - validateGo(t, "Model generated test with unused import", bytesutil.StringTrimIndentations(` + validateGo(t, "Model generated test with unused import", &golang.Language{}, bytesutil.StringTrimIndentations(` package plain import ( @@ -178,12 +182,31 @@ func TestTaskWriteTestsRun(t *testing.T) { "expected declaration, found this", "unable to format source code", } - validateGo(t, "Model generated test that is unfixable", bytesutil.StringTrimIndentations(` + validateGo(t, "Model generated test that is unfixable", &golang.Language{}, bytesutil.StringTrimIndentations(` package plain this is not valid go code `), expectedAssessments, expectedProblems, false) } + { + expectedAssessments := map[task.Identifier]metrics.Assessments{ + IdentifierWriteTests: metrics.Assessments{ + metrics.AssessmentKeyResponseNoError: 1, + }, + IdentifierWriteTestsSymflowerFix: metrics.Assessments{ + metrics.AssessmentKeyResponseNoError: 1, + }, + } + expectedProblems := []string{ + "context deadline exceeded", + } + + languageMock := languagetesting.NewMockLanguageNamed(t, "golang") + languageMock.On("Files", mock.Anything, mock.Anything).Return([]string{filepath.Join("golang", "plain")}, nil).Once() + languageMock.On("Execute", mock.Anything, mock.Anything).Return(uint64(0), nil, context.DeadlineExceeded).Once() + + validateGo(t, "Execution timeout", languageMock, "", expectedAssessments, expectedProblems, false) + } }) }) }