Skip to content

Commit

Permalink
Merge pull request #36 from symflower/assessment
Browse files Browse the repository at this point in the history
Assessment
  • Loading branch information
zimmski authored Apr 17, 2024
2 parents 6fa7212 + 94679d5 commit 5698aa8
Show file tree
Hide file tree
Showing 12 changed files with 190 additions and 50 deletions.
7 changes: 4 additions & 3 deletions cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"golang.org/x/exp/maps"

"github.com/symflower/eval-dev-quality/evaluate"
"github.com/symflower/eval-dev-quality/evaluate/metrics"
"github.com/symflower/eval-dev-quality/language"
"github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/provider"
Expand Down Expand Up @@ -99,12 +100,12 @@ func (command *Evaluate) Execute(args []string) (err error) {

// Check that models and languages can be evaluated by executing the "plain" repositories.
log.Printf("Checking that models and languages can be used for evaluation")
metricsPerModel := map[string]evaluate.Metrics{}
metricsPerModel := map[string]metrics.Metrics{}
problemsPerModel := map[string][]error{}
{
// Ensure we report metrics for every model even if they are excluded.
for _, modelID := range command.Models {
metricsPerModel[modelID] = evaluate.Metrics{}
metricsPerModel[modelID] = metrics.Metrics{}
}

for _, languageID := range command.Languages {
Expand Down Expand Up @@ -167,7 +168,7 @@ func (command *Evaluate) Execute(args []string) (err error) {
log.Printf("Evaluation score for %q: %s", modelID, metricsPerModel[modelID])
}

csv, err := evaluate.FormatStringCSV(metricsPerModel)
csv, err := metrics.FormatStringCSV(metricsPerModel)
if err != nil {
log.Fatalf("ERROR: could not create result summary: %s", err)
}
Expand Down
49 changes: 49 additions & 0 deletions evaluate/metrics/assessment.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package metrics

// AssessmentKey defines a key for a numerical key-value assessment pair.
type AssessmentKey string

var (
	// allAssessmentKeys holds all registered assessment keys in registration order.
	allAssessmentKeys []AssessmentKey
	// allAssessmentKeysStrings holds all registered assessment keys as strings in registration order.
	allAssessmentKeysStrings []string
)

// RegisterAssessmentKey registers a new assessment key and returns it.
// The key is appended to the package-level key lists without any synchronization or deduplication.
func RegisterAssessmentKey(key string) AssessmentKey {
	assessment := AssessmentKey(key)
	allAssessmentKeys = append(allAssessmentKeys, assessment)
	allAssessmentKeysStrings = append(allAssessmentKeysStrings, key)

	return assessment
}

var (
	// AssessmentKeyNoExcessResponse indicates that a model did not produce more content than requested.
	AssessmentKeyNoExcessResponse = RegisterAssessmentKey("no-excess-response")
)

// Assessments holds numerical assessment metrics.
type Assessments map[AssessmentKey]uint

// Merge combines two assessments into a new assessment and returns it.
// The values of both assessments are summed up for every registered key, so merging repeatedly accumulates counts instead of keeping only the values of the last operand.
// Neither the receiver nor the argument is modified, and a "nil" assessment is treated as empty.
// The result contains an entry for every registered key, keys that are not registered are not carried over.
func (a Assessments) Merge(o Assessments) Assessments {
	assessments := make(Assessments, len(allAssessmentKeys))

	for _, k := range allAssessmentKeys {
		// Reading a missing key, or reading from a "nil" map, yields zero, so no special handling of "nil" operands is needed.
		assessments[k] = a[k] + o[k]
	}

	return assessments
}
24 changes: 20 additions & 4 deletions evaluate/metrics.go → evaluate/metrics/metrics.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package evaluate
package metrics

import (
"encoding/csv"
Expand All @@ -13,6 +13,7 @@ import (
)

// Metrics holds numerical benchmarking metrics.
// TODO Move all metrics to assessment. https://github.com/symflower/eval-dev-quality/issues/34
type Metrics struct {
// Executed is the number of benchmarking candidates with successful execution.
Executed uint
Expand All @@ -23,6 +24,9 @@ type Metrics struct {

// Coverage holds the coverage of the benchmarking candidates.
Coverage []float64

// Assessments holds numerical assessments of a generation.
Assessments Assessments
}

// Add sums two metrics objects.
Expand All @@ -33,6 +37,8 @@ func (m Metrics) Add(o Metrics) Metrics {
Total: m.Total + o.Total,

Coverage: append(m.Coverage, o.Coverage...),

Assessments: m.Assessments.Merge(o.Assessments),
}
}

Expand Down Expand Up @@ -69,21 +75,31 @@ func (m Metrics) String() string {
}

// StringCSV returns a CSV row string representation of the metrics.
func (m Metrics) StringCSV() []string {
return []string{
func (m Metrics) StringCSV() (row []string) {
assessment := m.Assessments
if assessment == nil {
assessment = Assessments{}
}

row = []string{
fmt.Sprintf("%d", m.Total),
fmt.Sprintf("%d", m.Executed),
fmt.Sprintf("%d", m.Problems),
fmt.Sprintf("%.0f", m.AverageCoverage()),
}
for _, key := range allAssessmentKeys {
row = append(row, fmt.Sprintf("%d", assessment[key]))
}

return row
}

// FormatStringCSV formats the given metrics as CSV.
func FormatStringCSV(metricsPerModel map[string]Metrics) (string, error) {
var out strings.Builder
csv := csv.NewWriter(&out)

if err := csv.Write([]string{"model", "files-total", "files-executed", "files-problems", "coverage-statement"}); err != nil {
if err := csv.Write(append([]string{"model", "files-total", "files-executed", "files-problems", "coverage-statement"}, allAssessmentKeysStrings...)); err != nil {
return "", err
}
categories := maps.Keys(metricsPerModel)
Expand Down
18 changes: 12 additions & 6 deletions evaluate/metrics_test.go → evaluate/metrics/metrics_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package evaluate
package metrics

import (
"testing"
Expand Down Expand Up @@ -33,8 +33,8 @@ func TestFormatStringCSV(t *testing.T) {
},

ExpectedString: `
model,files-total,files-executed,files-problems,coverage-statement
Model,0,0,0,0
model,files-total,files-executed,files-problems,coverage-statement,no-excess-response
Model,0,0,0,0,0
`,
})
validate(t, &testCase{
Expand All @@ -46,19 +46,25 @@ func TestFormatStringCSV(t *testing.T) {
Executed: 3,
Problems: 2,
Coverage: []float64{100.0},
Assessments: Assessments{
AssessmentKeyNoExcessResponse: 3,
},
},
"ModelB": Metrics{
Total: 4,
Executed: 2,
Problems: 2,
Coverage: []float64{70.0},
Assessments: Assessments{
AssessmentKeyNoExcessResponse: 2,
},
},
},

ExpectedString: `
model,files-total,files-executed,files-problems,coverage-statement
ModelA,5,3,2,100
ModelB,4,2,2,70
model,files-total,files-executed,files-problems,coverage-statement,no-excess-response
ModelA,5,3,2,100,3
ModelB,4,2,2,70,2
`,
})
}
7 changes: 5 additions & 2 deletions evaluate/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ import (
pkgerrors "github.com/pkg/errors"
"github.com/zimmski/osutil"

"github.com/symflower/eval-dev-quality/evaluate/metrics"
"github.com/symflower/eval-dev-quality/language"
"github.com/symflower/eval-dev-quality/model"
)

// EvaluateRepository evaluates a repository with the given model and language.
func EvaluateRepository(model model.Model, language language.Language, repositoryPath string) (metrics Metrics, problems []error, err error) {
func EvaluateRepository(model model.Model, language language.Language, repositoryPath string) (metrics metrics.Metrics, problems []error, err error) {
log.Printf("Evaluating model %q using language %q and repository %q", model.ID(), language.ID(), repositoryPath)
defer func() {
log.Printf("Evaluated model %q using language %q and repository %q: encountered %d problems", model.ID(), language.ID(), repositoryPath, len(problems))
Expand Down Expand Up @@ -45,12 +46,14 @@ func EvaluateRepository(model model.Model, language language.Language, repositor

for _, filePath := range filePaths {
metrics.Total++
if err := model.GenerateTestsForFile(language, temporaryRepositoryPath, filePath); err != nil {
assessments, err := model.GenerateTestsForFile(language, temporaryRepositoryPath, filePath)
if err != nil {
problems = append(problems, pkgerrors.WithMessage(err, filePath))
metrics.Problems++

continue
}
metrics.Assessments = metrics.Assessments.Merge(assessments)

coverage, err := language.Execute(temporaryRepositoryPath)
if err != nil {
Expand Down
17 changes: 11 additions & 6 deletions model/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
pkgerrors "github.com/pkg/errors"
"github.com/zimmski/osutil/bytesutil"

"github.com/symflower/eval-dev-quality/evaluate/metrics"
"github.com/symflower/eval-dev-quality/language"
"github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/model/llm/prompt"
Expand Down Expand Up @@ -77,10 +78,10 @@ func (m *llm) ID() (id string) {
}

// GenerateTestsForFile generates test files for the given implementation file in a repository.
func (m *llm) GenerateTestsForFile(language language.Language, repositoryPath string, filePath string) (err error) {
func (m *llm) GenerateTestsForFile(language language.Language, repositoryPath string, filePath string) (assessment metrics.Assessments, err error) {
data, err := os.ReadFile(filepath.Join(repositoryPath, filePath))
if err != nil {
return err
return nil, err
}
fileContent := strings.TrimSpace(string(data))

Expand All @@ -94,20 +95,24 @@ func (m *llm) GenerateTestsForFile(language language.Language, repositoryPath st
ImportPath: importPath,
})
if err != nil {
return err
return nil, err
}

response, err := m.provider.Query(context.Background(), m.model, request)
if err != nil {
return err
return nil, err
}
log.Printf("Model %q responded to query %s with: %s", m.ID(), string(bytesutil.PrefixLines([]byte(request), []byte("\t"))), string(bytesutil.PrefixLines([]byte(response), []byte("\t"))))

testContent := prompt.ParseResponse(response)
assessment, testContent := prompt.ParseResponse(response)

// TODO Ask the model for the test file name or compute it in a more sophisticated manner.
fileExtension := filepath.Ext(filePath)
testFilePath := filepath.Join(repositoryPath, strings.TrimSuffix(filePath, fileExtension)+"_test"+fileExtension)

return os.WriteFile(testFilePath, []byte(testContent), 0644)
if err := os.WriteFile(testFilePath, []byte(testContent), 0644); err != nil {
return nil, pkgerrors.WithStack(err)
}

return assessment, nil
}
9 changes: 8 additions & 1 deletion model/llm/llm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/zimmski/osutil/bytesutil"

"github.com/symflower/eval-dev-quality/evaluate/metrics"
"github.com/symflower/eval-dev-quality/language"
providertesting "github.com/symflower/eval-dev-quality/provider/testing"
)
Expand All @@ -26,6 +27,7 @@ func TestModelLLMGenerateTestsForFile(t *testing.T) {
SourceFileContent string
SourceFilePath string

ExpectedAssessment metrics.Assessments
ExpectedTestFileContent string
ExpectedTestFilePath string
}
Expand All @@ -42,7 +44,9 @@ func TestModelLLMGenerateTestsForFile(t *testing.T) {
tc.SetupMock(mock)
llm := NewLLMModel(mock, tc.ModelID)

assert.NoError(t, llm.GenerateTestsForFile(tc.Language, temporaryPath, tc.SourceFilePath))
actualAssessment, actualError := llm.GenerateTestsForFile(tc.Language, temporaryPath, tc.SourceFilePath)
assert.NoError(t, actualError)
assert.Equal(t, tc.ExpectedAssessment, actualAssessment)

actualTestFileContent, err := os.ReadFile(filepath.Join(temporaryPath, tc.ExpectedTestFilePath))
assert.NoError(t, err)
Expand Down Expand Up @@ -83,6 +87,9 @@ func TestModelLLMGenerateTestsForFile(t *testing.T) {
SourceFileContent: sourceFileContent,
SourceFilePath: sourceFilePath,

ExpectedAssessment: metrics.Assessments{
metrics.AssessmentKeyNoExcessResponse: 1,
},
ExpectedTestFileContent: `
package native
Expand Down
32 changes: 24 additions & 8 deletions model/llm/prompt/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,41 @@ import (
"regexp"
"strings"

"github.com/symflower/eval-dev-quality/evaluate/metrics"
"github.com/zimmski/osutil/bytesutil"
)

var (
codeTagRe = regexp.MustCompile("(^|\n)\\s*```\\w*($|\n)")
codeTagDuplicatedRe = regexp.MustCompile("```(\\s|\n)*```")
codeTagMatch = regexp.MustCompile("(^|\n)\\s*```\\w*($|\n)")
codeTagDuplicatedMatch = regexp.MustCompile("```(\\s|\n)*```")
)

// ParseResponse parses code from a model's response.
func ParseResponse(response string) (code string) {
func ParseResponse(response string) (assessment metrics.Assessments, code string) {
assessment = metrics.Assessments{}

// Some models produce duplicated code tags, so unify them if needed.
response = codeTagDuplicatedRe.ReplaceAllString(response, "```")
response = codeTagDuplicatedMatch.ReplaceAllString(response, "```")

blocks := bytesutil.GuardedBlocks(response, codeTagMatch, codeTagMatch)

// When no code blocks are found, assume that just the code is returned.
if len(blocks) == 0 {
assessment[metrics.AssessmentKeyNoExcessResponse] = 1

blocks := bytesutil.GuardedBlocks(response, codeTagRe, codeTagRe)
if len(blocks) == 0 { // When no code blocks are found, assume that just the code is returned.
return strings.TrimSpace(response)
return assessment, strings.TrimSpace(response)
}

// Assume the first code block contains the response code fragment.
block := blocks[0]

return strings.TrimSpace(codeTagRe.ReplaceAllString(block, ""))
// Check if the response contained only that single code block.
responseWithoutBlock := strings.Replace(response, block, "", 1)
if len(strings.TrimSpace(responseWithoutBlock)) == 0 {
assessment[metrics.AssessmentKeyNoExcessResponse] = 1
} else {
assessment[metrics.AssessmentKeyNoExcessResponse] = 0
}

return assessment, strings.TrimSpace(codeTagMatch.ReplaceAllString(block, ""))
}
Loading

0 comments on commit 5698aa8

Please sign in to comment.