diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 1fe19dc..8773a8d 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -26,9 +26,7 @@ builds: goarch: arm64 ldflags: - -s -w - - -X github.com/omarluq/librecode/internal/vinfo.Version={{.Version}} - - -X github.com/omarluq/librecode/internal/vinfo.Commit={{.Commit}} - - -X github.com/omarluq/librecode/internal/vinfo.BuildDate={{.Date}} + - -X github.com/omarluq/librecode/internal/vinfo.version={{.Version}}|{{.Commit}}|{{.Date}} archives: - id: default diff --git a/Taskfile.yml b/Taskfile.yml index c72d062..d0e1a66 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -41,10 +41,9 @@ tasks: sh: git rev-parse --short HEAD 2>/dev/null || printf none BUILD_DATE: sh: date -u +"%Y-%m-%dT%H:%M:%SZ" + BUILD_METADATA: '{{.VERSION}}|{{.COMMIT}}|{{.BUILD_DATE}}' LDFLAGS: >- - -X github.com/omarluq/librecode/internal/vinfo.version={{.VERSION}} - - + -X github.com/omarluq/librecode/internal/vinfo.version={{.BUILD_METADATA}} cmds: - go build -ldflags "{{.LDFLAGS}}" -o {{.BUILD_DIR}}/{{.BINARY_NAME}} {{.MAIN_PACKAGE}} diff --git a/internal/core/messages.go b/internal/core/messages.go index 1122d4d..ee161a5 100644 --- a/internal/core/messages.go +++ b/internal/core/messages.go @@ -1,12 +1,16 @@ package core import ( + "encoding/json" + "fmt" "strconv" "strings" "time" ) const ( + legacyCanceledJSONKey = "cancel" + "led" + // CompactionSummaryPrefix wraps compacted conversation history. CompactionSummaryPrefix = "The conversation history before this point was compacted into the following summary:" + "\n\n\n" @@ -40,11 +44,42 @@ type BashExecutionMessage struct { Output string `json:"output"` FullOutputPath string `json:"full_output_path,omitempty"` Timestamp int64 `json:"timestamp"` - Canceled bool "json:\"cancel\u006ced\"" + Canceled bool `json:"canceled"` Truncated bool `json:"truncated"` ExcludeFromContext bool `json:"exclude_from_context,omitempty"` } +// UnmarshalJSON preserves compatibility with sessions written before the +// canonical American-English canceled key was fixed. +func (message *BashExecutionMessage) UnmarshalJSON(data []byte) error { + type bashExecutionMessageAlias BashExecutionMessage + + var decoded bashExecutionMessageAlias + if err := json.Unmarshal(data, &decoded); err != nil { + return fmt.Errorf("decode bash execution message: %w", err) + } + + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return fmt.Errorf("decode bash execution message keys: %w", err) + } + + if _, ok := raw["canceled"]; !ok { + if legacy, ok := raw[legacyCanceledJSONKey]; ok { + var legacyCanceled bool + if err := json.Unmarshal(legacy, &legacyCanceled); err != nil { + return fmt.Errorf("decode legacy bash execution canceled key: %w", err) + } + + decoded.Canceled = legacyCanceled + } + } + + *message = BashExecutionMessage(decoded) + + return nil +} + // CustomMessage is extension-injected context. type CustomMessage struct { Details any `json:"details,omitempty"` diff --git a/internal/core/messages_test.go b/internal/core/messages_test.go index 4e7f859..8980a10 100644 --- a/internal/core/messages_test.go +++ b/internal/core/messages_test.go @@ -1,6 +1,7 @@ package core_test import ( + "encoding/json" "testing" "github.com/stretchr/testify/assert" @@ -108,6 +109,57 @@ func TestMessageConstructorsAndLLMConversions(t *testing.T) { assert.Contains(t, compactionLLM.Content[0].Text, "compact summary") } +func TestBashExecutionMessageJSON(t *testing.T) { + t.Parallel() + + encoded, err := json.Marshal(core.BashExecutionMessage{ + ExitCode: nil, + Command: "sleep 10", + Output: "", + FullOutputPath: "", + Timestamp: 123, + Canceled: true, + Truncated: false, + ExcludeFromContext: false, + }) + require.NoError(t, err) + assert.Contains(t, string(encoded), `"canceled":true`) + assert.NotContains(t, string(encoded), "cancel"+"led") + + tests := []struct { + name string + raw string + want bool + }{ + { + name: "canonical canceled key", + raw: `{"command":"sleep 10","canceled":true}`, + want: true, + }, + { + name: "legacy British-English key", + raw: `{"command":"sleep 10","cancel` + `led":true}`, + want: true, + }, + { + name: "canonical key wins over legacy key", + raw: `{"command":"sleep 10","canceled":false,"cancel` + `led":true}`, + want: false, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + var decoded core.BashExecutionMessage + require.NoError(t, json.Unmarshal([]byte(testCase.raw), &decoded)) + + assert.Equal(t, testCase.want, decoded.Canceled) + assert.Equal(t, "sleep 10", decoded.Command) + }) + } +} + func TestBashExecutionToLLM(t *testing.T) { t.Parallel() diff --git a/internal/core/skills_activation.go b/internal/core/skills_activation.go index 2e21a55..524e97c 100644 --- a/internal/core/skills_activation.go +++ b/internal/core/skills_activation.go @@ -298,20 +298,21 @@ func normalizeSkillToken(token string) string { } func isSkillStopWord(token string) bool { - stopWords := map[string]bool{ - "about": true, "after": true, "agent": true, "also": true, "and": true, - "any": true, "apply": true, "are": true, "build": true, "can": true, - "code": true, "coding": true, "cover": true, "covers": true, "debug": true, - "designed": true, "especially": true, "for": true, "from": true, "guide": true, - "helps": true, "implement": true, "into": true, "not": true, "only": true, - "project": true, "provides": true, "review": true, "similar": true, "skill": true, - "task": true, "tasks": true, "that": true, "the": true, "their": true, "these": true, - "this": true, "tool": true, "tools": true, "trigger": true, "use": true, - "when": true, "whenever": true, "with": true, "work": true, "working": true, - "write": true, "you": true, + switch token { + case "about", "after", "agent", "also", "and", + "any", "apply", "are", "build", "can", + "code", "coding", "cover", "covers", "debug", + "designed", "especially", "for", "from", "guide", + "helps", "implement", "into", "not", "only", + "project", "provides", "review", "similar", "skill", + "task", "tasks", "that", "the", "their", "these", + "this", "tool", "tools", "trigger", "use", + "when", "whenever", "with", "work", "working", + "write", "you": + return true + default: + return false } - - return stopWords[token] } func truncateSkillContent(content string) (string, bool) { diff --git a/internal/core/skills_test.go b/internal/core/skills_test.go index ae81cda..f98cdd3 100644 --- a/internal/core/skills_test.go +++ b/internal/core/skills_test.go @@ -243,6 +243,25 @@ func TestAutoActivateSkillsSelectsMatchingSkill(t *testing.T) { assert.Contains(t, activated[0].Content, "Run tests") } +func TestAutoActivateSkillsIgnoresActivationStopWords(t *testing.T) { + cwd := t.TempDir() + home := t.TempDir() + t.Setenv("HOME", home) + writeTestFile(t, filepath.Join(cwd, core.ConfigDirName, "skills", "loud", "SKILL.md"), strings.Join([]string{ + frontmatterDelimiter, + "name: loud", + "description: Use when the agent and any task", + frontmatterDelimiter, + "Loud instructions.", + }, "\n")) + + result := core.LoadSkills(cwd, nil, true) + detail := core.AutoActivateSkillsDetailed("the agent and any task", result.Skills) + + require.Empty(t, detail.Diagnostics) + assert.Empty(t, detail.Activated) +} + func TestAutoActivateSkillsRequiresStrongIntent(t *testing.T) { cwd := t.TempDir() home := t.TempDir() diff --git a/internal/tool/fetch.go b/internal/tool/fetch.go index 9d1a0c7..77ef2fd 100644 --- a/internal/tool/fetch.go +++ b/internal/tool/fetch.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "mime" @@ -248,7 +249,10 @@ func (fetchTool *FetchTool) fetchURL( setFetchHeaders(request) - response, err := fetchTool.httpClientWithRedirectValidation(requestCtx).Do(request) + client, closeIdleConnections := fetchTool.httpClientWithRedirectValidation(requestCtx) + defer closeIdleConnections() + + response, err := client.Do(request) if err != nil { return nil, fetchResponseInfo{}, oops.In("tool").Code("fetch_request").Wrapf(err, "fetch url") } @@ -298,11 +302,17 @@ func (fetchTool *FetchTool) httpClient() *http.Client { return http.DefaultClient } -func (fetchTool *FetchTool) httpClientWithRedirectValidation(ctx context.Context) *http.Client { +func (fetchTool *FetchTool) httpClientWithRedirectValidation( + ctx context.Context, +) (client *http.Client, closeIdleConnections func()) { baseClient := fetchTool.httpClient() - client := *baseClient + clonedClient := *baseClient + + transport, closeIdleConnections := fetchTool.transportWithNetworkValidation(baseClient.Transport) + clonedClient.Transport = transport + baseCheckRedirect := baseClient.CheckRedirect - client.CheckRedirect = func(request *http.Request, via []*http.Request) error { + clonedClient.CheckRedirect = func(request *http.Request, via []*http.Request) error { if baseCheckRedirect != nil { if err := baseCheckRedirect(request, via); err != nil { return err @@ -314,7 +324,86 @@ func (fetchTool *FetchTool) httpClientWithRedirectValidation(ctx context.Context return fetchTool.validatePublicFetchURL(ctx, request.URL) } - return &client + return &clonedClient, closeIdleConnections +} + +func (fetchTool *FetchTool) transportWithNetworkValidation( + baseTransport http.RoundTripper, +) (roundTripper http.RoundTripper, closeIdleConnections func()) { + if fetchTool.allowPrivateNetworks { + return baseTransport, func() {} + } + + transport, ok := cloneFetchHTTPTransport(baseTransport) + if !ok { + return baseTransport, func() {} + } + + transport.Proxy = nil + transport.DialContext = validatingFetchDialContext(fetchDialContext(transport)) + + if transport.DialTLSContext != nil { + transport.DialTLSContext = validatingFetchDialContext(transport.DialTLSContext) + } + + return transport, transport.CloseIdleConnections +} + +func cloneFetchHTTPTransport(baseTransport http.RoundTripper) (*http.Transport, bool) { + if baseTransport == nil { + baseTransport = http.DefaultTransport + } + + transport, ok := baseTransport.(*http.Transport) + if !ok { + return nil, false + } + + return transport.Clone(), true +} + +func fetchDialContext(transport *http.Transport) func(context.Context, string, string) (net.Conn, error) { + if transport.DialContext != nil { + return transport.DialContext + } + + dialer := &net.Dialer{} + + return dialer.DialContext +} + +func validatingFetchDialContext( + dialContext func(context.Context, string, string) (net.Conn, error), +) func(context.Context, string, string) (net.Conn, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return validateFetchDialedConnection(dialContext(ctx, network, address)) + } +} + +func validateFetchDialedConnection(conn net.Conn, dialErr error) (net.Conn, error) { + if dialErr != nil { + return nil, dialErr + } + + if conn == nil { + return nil, oops.In("tool").Code("fetch_nil_connection").Errorf("fetch dial returned nil connection") + } + + if err := validatePublicFetchRemoteAddr(conn.RemoteAddr()); err != nil { + if closeErr := conn.Close(); closeErr != nil { + return nil, errors.Join( + err, + oops.In("tool").Code("fetch_close_rejected_connection").Wrapf( + closeErr, + "close rejected fetch connection", + ), + ) + } + + return nil, err + } + + return conn, nil } func (fetchTool *FetchTool) validatePublicFetchURL(ctx context.Context, requestURL *url.URL) error { @@ -390,6 +479,37 @@ func validatePublicFetchIP(ipAddress net.IP) error { return nil } +func validatePublicFetchRemoteAddr(remoteAddr net.Addr) error { + ipAddress := fetchRemoteAddrIP(remoteAddr) + if ipAddress == nil { + return oops.In("tool").Code("fetch_invalid_remote_address").Errorf("fetch remote address is not an IP") + } + + return validatePublicFetchIP(ipAddress) +} + +func fetchRemoteAddrIP(remoteAddr net.Addr) net.IP { + if remoteAddr == nil { + return nil + } + + switch addr := remoteAddr.(type) { + case *net.TCPAddr: + return addr.IP + case *net.UDPAddr: + return addr.IP + case *net.IPAddr: + return addr.IP + } + + host, _, err := net.SplitHostPort(remoteAddr.String()) + if err != nil { + host = remoteAddr.String() + } + + return parseFetchHostIP(normalizedFetchHost(host)) +} + func isPrivateFetchIP(ipAddress net.IP) bool { return ipAddress.IsLoopback() || ipAddress.IsPrivate() || diff --git a/internal/tool/fetch_internal_test.go b/internal/tool/fetch_internal_test.go index 6092fb9..dd7f133 100644 --- a/internal/tool/fetch_internal_test.go +++ b/internal/tool/fetch_internal_test.go @@ -21,6 +21,7 @@ import ( ) const ( + fetchTestExampleHost = "example.test" fetchTestExampleURL = "https://example.com" fetchTestIgnoredFooter = "Ignore footer" fetchTestIgnoredHeader = "Ignore header" @@ -445,10 +446,10 @@ func TestFetchTool_RejectsPrivateNetworkTargets(t *testing.T) { }, { lookups: map[string][]net.IPAddr{ - "example.test": {{IP: net.ParseIP("192.168.1.10")}}, + fetchTestExampleHost: {{IP: net.ParseIP("192.168.1.10")}}, }, name: "private dns result", - rawURL: "http://example.test", + rawURL: "http://" + fetchTestExampleHost, }, } @@ -469,7 +470,7 @@ func TestFetchTool_RejectsPrivateNetworkRedirect(t *testing.T) { t.Parallel() fetchTool := fetchTestLookupTool(map[string][]net.IPAddr{ - "example.test": {{IP: net.ParseIP("93.184.216.34")}}, + fetchTestExampleHost: {{IP: net.ParseIP("93.184.216.34")}}, }) fetchTool.client = &http.Client{Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) { response := fetchTestHTTPResponse(request, io.NopCloser(strings.NewReader("redirect"))) @@ -480,7 +481,37 @@ func TestFetchTool_RejectsPrivateNetworkRedirect(t *testing.T) { return response, nil })} - _, err := fetchTool.Fetch(context.Background(), fetchInputForTest("http://example.test/start", "")) + _, err := fetchTool.Fetch(context.Background(), fetchInputForTest("http://"+fetchTestExampleHost+"/start", "")) + + require.Error(t, err) + assert.Contains(t, err.Error(), "private or local networks") +} + +func TestFetchTool_RejectsPrivateDialedAddress(t *testing.T) { + t.Parallel() + + server := fetchTestPrivateNetworkServer(http.HandlerFunc(func(writer http.ResponseWriter, _ *http.Request) { + if _, err := writer.Write([]byte("unexpected private target")); err != nil { + return + } + })) + defer server.Close() + + serverURL, err := url.Parse(server.URL) + require.NoError(t, err) + + fetchTool := fetchTestLookupTool(map[string][]net.IPAddr{ + fetchTestExampleHost: {{IP: net.ParseIP("93.184.216.34")}}, + }) + fetchTool.client = &http.Client{Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) { + var dialer net.Dialer + + return dialer.DialContext(ctx, network, serverURL.Host) + }, + }} + + _, err = fetchTool.Fetch(context.Background(), fetchInputForTest("http://"+fetchTestExampleHost, "")) require.Error(t, err) assert.Contains(t, err.Error(), "private or local networks") @@ -799,7 +830,10 @@ func TestFetchTool_RedirectValidation(t *testing.T) { fetchTool := fetchTestPrivateNetworkTool() fetchTool.client = &http.Client{CheckRedirect: testCase.baseCheckRedirect} - client := fetchTool.httpClientWithRedirectValidation(context.Background()) + + client, closeIdleConnections := fetchTool.httpClientWithRedirectValidation(context.Background()) + defer closeIdleConnections() + request, err := http.NewRequestWithContext( context.Background(), http.MethodGet, diff --git a/internal/tool/ignore.go b/internal/tool/ignore.go index d140db5..b0d5c4d 100644 --- a/internal/tool/ignore.go +++ b/internal/tool/ignore.go @@ -18,26 +18,14 @@ const ( readIgnoreCacheCapacity = 16 ) -func defaultReadIgnorePatterns() []string { - return []string{ - ".git/", - "node_modules/", - ".env", - ".gocache/", - ".gomodcache/", - ".tmp/", - "bin/", - "/skills/", - } -} - type ignorePattern struct { pattern gitignore.Pattern source string } type readIgnoreCache struct { - patterns *hot.HotCache[string, repositoryIgnorePatterns] + patterns *hot.HotCache[string, repositoryIgnorePatterns] + defaultPatterns []ignorePattern } type repositoryIgnorePatterns struct { @@ -60,6 +48,7 @@ type ignorePathState struct { func newReadIgnoreCache() *readIgnoreCache { return &readIgnoreCache{ + defaultPatterns: newDefaultReadIgnorePatterns(), patterns: hot.NewHotCache[string, repositoryIgnorePatterns](hot.WTinyLFU, readIgnoreCacheCapacity). WithLoaders(func(workspaceRoots []string) (map[string]repositoryIgnorePatterns, error) { patterns := make(map[string]repositoryIgnorePatterns, len(workspaceRoots)) @@ -129,15 +118,11 @@ func pathEscapesRoot(relativePath string) bool { } func readIgnorePatterns(workspaceRoot string, cache *readIgnoreCache) []ignorePattern { - patterns := make([]ignorePattern, 0, len(defaultReadIgnorePatterns())) - for _, pattern := range defaultReadIgnorePatterns() { - patterns = append(patterns, ignorePattern{ - pattern: gitignore.ParsePattern(pattern, nil), - source: pattern, - }) - } - + defaults := readDefaultIgnorePatterns(cache) repositoryPatterns := cache.repositoryPatterns(workspaceRoot) + patterns := make([]ignorePattern, 0, len(defaults)+len(repositoryPatterns)) + patterns = append(patterns, defaults...) + for _, pattern := range repositoryPatterns { patterns = append(patterns, ignorePattern{ pattern: pattern, @@ -148,6 +133,37 @@ func readIgnorePatterns(workspaceRoot string, cache *readIgnoreCache) []ignorePa return patterns } +func readDefaultIgnorePatterns(cache *readIgnoreCache) []ignorePattern { + if cache != nil { + return cache.defaultPatterns + } + + return newDefaultReadIgnorePatterns() +} + +func newDefaultReadIgnorePatterns() []ignorePattern { + patternSources := [...]string{ + ".git/", + "node_modules/", + ".env", + ".gocache/", + ".gomodcache/", + ".tmp/", + "bin/", + "/skills/", + } + patterns := make([]ignorePattern, 0, len(patternSources)) + + for _, source := range patternSources { + patterns = append(patterns, ignorePattern{ + pattern: gitignore.ParsePattern(source, nil), + source: source, + }) + } + + return patterns +} + func (cache *readIgnoreCache) repositoryPatterns(workspaceRoot string) []gitignore.Pattern { if cache == nil { return readRepositoryIgnorePatterns(workspaceRoot).patterns diff --git a/internal/vinfo/version.go b/internal/vinfo/version.go index c67320a..1e259bf 100644 --- a/internal/vinfo/version.go +++ b/internal/vinfo/version.go @@ -7,35 +7,145 @@ import ( "strings" ) -const devVersion = "dev" +const ( + buildInfoModifiedKey = "vcs.modified" + buildInfoRevisionKey = "vcs.revision" + buildInfoTimeKey = "vcs.time" + devVersion = "dev" + dirtySuffix = "-dirty" + defaultCommit = "none" + defaultBuildDate = "unknown" + metadataSeparator = "|" + shortRevisionBytes = 8 + trueValue = "true" +) +// version is set by release/build -ldflags as "version|commit|buildDate". var version = devVersion -func commit() string { - return "none" -} - -func buildDate() string { - return "unknown" +type buildMetadata struct { + version string + commit string + buildDate string } // String returns a human-readable build version string. func String() string { - value := version + var info *debug.BuildInfo + if buildInfo, ok := debug.ReadBuildInfo(); ok { + info = buildInfo + } + + return stringFromVersion(version, info) +} + +func stringFromVersion(value string, info *debug.BuildInfo) string { + metadata := parseBuildMetadata(value) + if info != nil { + metadata = metadata.withBuildInfoFallback(info) + } + + return metadata.String() +} + +func (metadata buildMetadata) String() string { + return fmt.Sprintf("%s (commit=%s, built=%s)", metadata.version, metadata.commit, metadata.buildDate) +} + +func parseBuildMetadata(value string) buildMetadata { + versionValue, rest, hasMetadata := strings.Cut(value, metadataSeparator) + + metadata := buildMetadata{ + version: versionPart(versionValue), + commit: defaultCommit, + buildDate: defaultBuildDate, + } + if !hasMetadata { + return metadata + } + + commitValue, buildDateValue, hasBuildDate := strings.Cut(rest, metadataSeparator) + + metadata.commit = metadataPart(commitValue, defaultCommit) + if hasBuildDate { + metadata.buildDate = metadataPart(buildDateValue, defaultBuildDate) + } + + return metadata +} + +func (metadata buildMetadata) withBuildInfoFallback(info *debug.BuildInfo) buildMetadata { + if metadata.version == devVersion { + metadata.version = buildInfoVersion(info) + } + + if metadata.commit == defaultCommit { + revision := shortRevision(buildInfoSetting(info.Settings, buildInfoRevisionKey)) + metadata.commit = metadataPart(revision, defaultCommit) + } + + if metadata.buildDate == defaultBuildDate { + metadata.buildDate = metadataPart(buildInfoSetting(info.Settings, buildInfoTimeKey), defaultBuildDate) + } + + return metadata +} + +func buildInfoVersion(info *debug.BuildInfo) string { + moduleVersion := versionPart(info.Main.Version) + if moduleVersion != devVersion { + return moduleVersion + } + + return fallbackVCSVersion(info.Settings) +} + +func fallbackVCSVersion(settings []debug.BuildSetting) string { + revision := shortRevision(buildInfoSetting(settings, buildInfoRevisionKey)) + if revision == "" { + return devVersion + } - if value == devVersion { - if info, ok := debug.ReadBuildInfo(); ok { - value = fallbackVersion(info.Main.Version) + if buildInfoSetting(settings, buildInfoModifiedKey) == trueValue { + return revision + dirtySuffix + } + + return revision +} + +func buildInfoSetting(settings []debug.BuildSetting, key string) string { + for _, setting := range settings { + if setting.Key == key { + return strings.TrimSpace(setting.Value) } } - return fmt.Sprintf("%s (commit=%s, built=%s)", value, commit(), buildDate()) + return "" +} + +func shortRevision(revision string) string { + revision = strings.TrimSpace(revision) + if len(revision) <= shortRevisionBytes { + return revision + } + + return revision[:shortRevisionBytes] +} + +func metadataPart(value, fallback string) string { + value = strings.TrimSpace(value) + if value == "" { + return fallback + } + + return value } -func fallbackVersion(version string) string { - if version == "" || version == "(devel)" { +func versionPart(value string) string { + value = strings.TrimSpace(value) + if value == "" || value == "(devel)" { return devVersion } - return strings.TrimSpace(version) + return value } diff --git a/internal/vinfo/version_internal_test.go b/internal/vinfo/version_internal_test.go index 6f78578..ec45999 100644 --- a/internal/vinfo/version_internal_test.go +++ b/internal/vinfo/version_internal_test.go @@ -1,26 +1,164 @@ package vinfo import ( + "runtime/debug" "testing" "github.com/stretchr/testify/assert" ) +const ( + testBuildDate = "2026-06-24T00:00:00Z" + testCommit = "abc1234" + testDirtyVersion = testShortRevision + dirtySuffix + testFullRevision = "25e59c5c54787d963bda41fe594517598334ff27" + testShortRevision = "25e59c5c" + testUpdatedBuildDate = "2026-06-25T00:00:00Z" + testVersion = "1.2.3" + testVersionExpected = testVersion + " (commit=" + testCommit + ", built=" + testBuildDate + ")" +) + func TestStringUsesInjectedBuildMetadata(t *testing.T) { t.Parallel() - oldVersion := version + assert.Equal( + t, + testVersionExpected, + stringFromVersion(testVersion+metadataSeparator+testCommit+metadataSeparator+testBuildDate, nil), + ) +} + +func TestParseBuildMetadata(t *testing.T) { + t.Parallel() + + tests := []struct { + expected buildMetadata + name string + value string + }{ + { + name: "plain version", + value: testVersion, + expected: buildMetadata{ + version: testVersion, + commit: defaultCommit, + buildDate: defaultBuildDate, + }, + }, + { + name: "full metadata", + value: " " + testVersion + " | " + testCommit + " | " + testBuildDate + " ", + expected: buildMetadata{ + version: testVersion, + commit: testCommit, + buildDate: testBuildDate, + }, + }, + { + name: "missing metadata fields use defaults", + value: "| |", + expected: buildMetadata{ + version: devVersion, + commit: defaultCommit, + buildDate: defaultBuildDate, + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, testCase.expected, parseBuildMetadata(testCase.value)) + }) + } +} + +func TestBuildMetadataFallsBackToBuildInfo(t *testing.T) { + t.Parallel() - t.Cleanup(func() { - version = oldVersion + metadata := buildMetadata{ + version: devVersion, + commit: defaultCommit, + buildDate: defaultBuildDate, + }.withBuildInfoFallback(&debug.BuildInfo{ + Main: debug.Module{}, + Settings: []debug.BuildSetting{ + {Key: buildInfoRevisionKey, Value: testFullRevision}, + {Key: buildInfoModifiedKey, Value: trueValue}, + {Key: buildInfoTimeKey, Value: testBuildDate}, + }, }) - version = "1.2.3" + assert.Equal(t, buildMetadata{ + version: testDirtyVersion, + commit: testShortRevision, + buildDate: testBuildDate, + }, metadata) +} + +func TestBuildMetadataUsesModuleVersion(t *testing.T) { + t.Parallel() + + metadata := buildMetadata{ + version: devVersion, + commit: defaultCommit, + buildDate: defaultBuildDate, + }.withBuildInfoFallback(&debug.BuildInfo{ + Main: debug.Module{Version: testVersion}, + Settings: []debug.BuildSetting{ + {Key: buildInfoRevisionKey, Value: testFullRevision}, + {Key: buildInfoModifiedKey, Value: trueValue}, + {Key: buildInfoTimeKey, Value: testBuildDate}, + }, + }) + + assert.Equal(t, buildMetadata{ + version: testVersion, + commit: testShortRevision, + buildDate: testBuildDate, + }, metadata) +} + +func TestBuildMetadataPreservesInjectedValues(t *testing.T) { + t.Parallel() + + metadata := buildMetadata{ + version: testVersion, + commit: testCommit, + buildDate: testBuildDate, + }.withBuildInfoFallback(&debug.BuildInfo{ + Main: debug.Module{}, + Settings: []debug.BuildSetting{ + {Key: buildInfoRevisionKey, Value: testFullRevision}, + {Key: buildInfoModifiedKey, Value: trueValue}, + {Key: buildInfoTimeKey, Value: testUpdatedBuildDate}, + }, + }) + + assert.Equal(t, buildMetadata{ + version: testVersion, + commit: testCommit, + buildDate: testBuildDate, + }, metadata) +} + +func TestBuildInfoSetting(t *testing.T) { + t.Parallel() + + settings := []debug.BuildSetting{{Key: buildInfoRevisionKey, Value: " " + testCommit + " "}} + + assert.Equal(t, testCommit, buildInfoSetting(settings, buildInfoRevisionKey)) + assert.Empty(t, buildInfoSetting(settings, "missing")) +} + +func TestShortRevision(t *testing.T) { + t.Parallel() - assert.Equal(t, "1.2.3 (commit=none, built=unknown)", String()) + assert.Equal(t, testShortRevision, shortRevision(testFullRevision)) + assert.Equal(t, testCommit, shortRevision(" "+testCommit+" ")) } -func TestFallbackVersion(t *testing.T) { +func TestVersionPart(t *testing.T) { t.Parallel() tests := []struct { @@ -36,7 +174,7 @@ func TestFallbackVersion(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { t.Parallel() - assert.Equal(t, testCase.expected, fallbackVersion(testCase.version)) + assert.Equal(t, testCase.expected, versionPart(testCase.version)) }) } }