Skip to content

Commit

Permalink
Check all files generated by running the "report" command, to be more explicit about the file system changes
Browse files: browse the repository at this point in the history

Part of #205
  • Loading branch information
ruiAzevedo19 committed Jul 22, 2024
1 parent 6abdd1a commit 9496913
Showing 1 changed file with 69 additions and 9 deletions.
78 changes: 69 additions & 9 deletions cmd/eval-dev-quality/cmd/report_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/symflower/eval-dev-quality/evaluate/metrics"
"github.com/symflower/eval-dev-quality/evaluate/report"
"github.com/symflower/eval-dev-quality/log"
"github.com/zimmski/osutil"
Expand Down Expand Up @@ -44,8 +46,8 @@ func TestReportExecute(t *testing.T) {

Arguments func(workingDirectory string) []string

ExpectedEvaluationCSVFileContent string
ExpectedPanicContains string
ExpectedResultFiles map[string]func(t *testing.T, filePath string, data string)
ExpectedPanicContains string
}

validate := func(t *testing.T, tc *testCase) {
Expand Down Expand Up @@ -93,10 +95,26 @@ func TestReportExecute(t *testing.T) {
assert.Contains(t, recovered, tc.ExpectedPanicContains)
}

file, err := os.ReadFile(filepath.Join(resultPath, "evaluation.csv"))
actualResultFiles, err := osutil.FilesRecursive(temporaryPath)
require.NoError(t, err)

assert.Equal(t, tc.ExpectedEvaluationCSVFileContent, string(file))
for i, p := range actualResultFiles {
actualResultFiles[i], err = filepath.Rel(temporaryPath, p)
require.NoError(t, err)
}
sort.Strings(actualResultFiles)
expectedResultFiles := make([]string, 0, len(tc.ExpectedResultFiles))
for filePath, validate := range tc.ExpectedResultFiles {
expectedResultFiles = append(expectedResultFiles, filePath)

if validate != nil {
data, err := os.ReadFile(filepath.Join(temporaryPath, filePath))
if assert.NoError(t, err) {
validate(t, filePath, string(data))
}
}
}
sort.Strings(expectedResultFiles)
assert.Equal(t, expectedResultFiles, actualResultFiles)
})
}

Expand All @@ -109,6 +127,9 @@ func TestReportExecute(t *testing.T) {
}
},

ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
filepath.Join("result-directory", "evaluation.csv"): nil,
},
ExpectedPanicContains: `the path needs to end with "evaluation.csv"`,
})
validate(t, &testCase{
Expand All @@ -125,7 +146,13 @@ func TestReportExecute(t *testing.T) {
}
},

ExpectedEvaluationCSVFileContent: fmt.Sprintf("%s\n%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent),
ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
"evaluation.csv": nil,
filepath.Join("result-directory", "evaluation.csv"): func(t *testing.T, filePath, data string) {
expectedContent := fmt.Sprintf("%s\n%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent)
assert.Equal(t, expectedContent, data)
},
},
})
validate(t, &testCase{
Name: "Multiple files",
Expand All @@ -145,7 +172,15 @@ func TestReportExecute(t *testing.T) {
}
},

ExpectedEvaluationCSVFileContent: fmt.Sprintf("%s\n%s%s%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent, gemmaEvaluationCSVFileContent, gpt4EvaluationCSVFileContent),
ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
filepath.Join("docs", "v5", "claude", "evaluation.csv"): nil,
filepath.Join("docs", "v5", "gemma", "evaluation.csv"): nil,
filepath.Join("docs", "v5", "openrouter", "gpt4", "evaluation.csv"): nil,
filepath.Join("result-directory", "evaluation.csv"): func(t *testing.T, filePath, data string) {
expectedContent := fmt.Sprintf("%s\n%s%s%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent, gemmaEvaluationCSVFileContent, gpt4EvaluationCSVFileContent)
assert.Equal(t, expectedContent, data)
},
},
})
validate(t, &testCase{
Name: "Multiple files with glob pattern",
Expand All @@ -162,8 +197,15 @@ func TestReportExecute(t *testing.T) {
"--evaluation-path", filepath.Join(workingDirectory, "docs", "v5", "*", "evaluation.csv"),
}
},

ExpectedEvaluationCSVFileContent: fmt.Sprintf("%s\n%s%s%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent, gemmaEvaluationCSVFileContent, gpt4EvaluationCSVFileContent),
ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
filepath.Join("docs", "v5", "claude", "evaluation.csv"): nil,
filepath.Join("docs", "v5", "gemma", "evaluation.csv"): nil,
filepath.Join("docs", "v5", "gpt4", "evaluation.csv"): nil,
filepath.Join("result-directory", "evaluation.csv"): func(t *testing.T, filePath, data string) {
expectedContent := fmt.Sprintf("%s\n%s%s%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent, gemmaEvaluationCSVFileContent, gpt4EvaluationCSVFileContent)
assert.Equal(t, expectedContent, data)
},
},
})
}

Expand Down Expand Up @@ -251,3 +293,21 @@ func evaluationFileWithContent(t *testing.T, workingDirectory string, content st
require.NoError(t, os.MkdirAll(workingDirectory, 0700))
require.NoError(t, os.WriteFile(filepath.Join(workingDirectory, "evaluation.csv"), []byte(content), 0700))
}

// validateReportLinks asserts that the Markdown report in "data" contains relative links to the
// categories SVG, the evaluation CSV, the evaluation log, and one directory link per entry of "modelLogNames".
func validateReportLinks(t *testing.T, data string, modelLogNames []string) {
	// Fixed report artifacts come first so they are asserted before the model links.
	expectedLinks := []string{
		"](./categories.svg)",
		"](./evaluation.csv)",
		"](./evaluation.log)",
	}
	// Each model log name is expected as a relative directory link target.
	for _, modelLogName := range modelLogNames {
		expectedLinks = append(expectedLinks, fmt.Sprintf("](./%s/)", modelLogName))
	}

	for _, link := range expectedLinks {
		assert.Contains(t, data, link)
	}
}

// validateSVGContent asserts that the SVG in "data" contains a text element ending with the name of
// every given category, and a text element ending with the decimal value of "maxModelCount".
func validateSVGContent(t *testing.T, data string, categories []*metrics.AssessmentCategory, maxModelCount uint) {
	for i := range categories {
		// Matching on the closing tag ties the name to the end of a text element.
		categoryLabel := fmt.Sprintf("%s</text>", categories[i].Name)
		assert.Contains(t, data, categoryLabel)
	}

	// The maximal model count must show up as the content end of a text element as well.
	maxCountLabel := fmt.Sprintf("%d</text>", maxModelCount)
	assert.Contains(t, data, maxCountLabel)
}

0 comments on commit 9496913

Please sign in to comment.