commit 08bb28ced55c6ba29105cefd3c21cd6895937460 Author: kite Date: Sat Jun 6 19:25:08 2026 +0800 feat(llm): enable Anthropic prompt caching and fix token accounting Add cache_control ephemeral breakpoints on system prompt and tool definitions to activate Anthropic prompt caching for multi-turn agent loops. Fix token double-counting by removing the redundant totalTokensUsed accumulator and computing it from sub-items. Surface cache read/write statistics in both text and JSON summary output. diff --git a/cmd/opencodereview/output.go b/cmd/opencodereview/output.go index 603a7ea..a2ccc7b 100644 --- a/cmd/opencodereview/output.go +++ b/cmd/opencodereview/output.go @@ -163,12 +163,14 @@ func buildDiffLines(comment model.LlmComment) []suggestdiff.DiffLine { } type jsonSummary struct { - FilesReviewed int64 `json:"files_reviewed"` - Comments int64 `json:"comments"` - TotalTokens int64 `json:"total_tokens"` - InputTokens int64 `json:"input_tokens"` - OutputTokens int64 `json:"output_tokens"` - Elapsed string `json:"elapsed"` + FilesReviewed int64 `json:"files_reviewed"` + Comments int64 `json:"comments"` + TotalTokens int64 `json:"total_tokens"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CacheReadTokens int64 `json:"cache_read_tokens,omitempty"` + CacheWriteTokens int64 `json:"cache_write_tokens,omitempty"` + Elapsed string `json:"elapsed"` } type jsonOutput struct { @@ -193,17 +195,19 @@ func outputJSON(comments []model.LlmComment) error { } func outputJSONWithWarnings(comments []model.LlmComment, warnings []agent.AgentWarning, - filesReviewed, inputTokens, outputTokens, totalTokens int64, duration time.Duration) error { + filesReviewed, inputTokens, outputTokens, totalTokens, cacheReadTokens, cacheWriteTokens int64, duration time.Duration) error { out := jsonOutput{ Status: "success", Comments: comments, Summary: &jsonSummary{ - FilesReviewed: filesReviewed, - Comments: int64(len(comments)), - TotalTokens: totalTokens, - InputTokens: inputTokens, - OutputTokens: outputTokens, - Elapsed: duration.Round(time.Second).String(), + FilesReviewed: filesReviewed, + Comments: int64(len(comments)), + TotalTokens: totalTokens, + InputTokens: inputTokens, + OutputTokens: outputTokens, + CacheReadTokens: cacheReadTokens, + CacheWriteTokens: cacheWriteTokens, + Elapsed: duration.Round(time.Second).String(), }, } if len(comments) == 0 { diff --git a/cmd/opencodereview/review_cmd.go b/cmd/opencodereview/review_cmd.go index 17e2441..052bfc4 100644 --- a/cmd/opencodereview/review_cmd.go +++ b/cmd/opencodereview/review_cmd.go @@ -166,11 +166,11 @@ func runReview(args []string) error { } if opts.outputFormat != "json" { - telemetry.PrintTraceSummary(ag.FilesReviewed(), int64(len(comments)), ag.TotalInputTokens(), ag.TotalOutputTokens(), ag.TotalTokensUsed(), duration) + telemetry.PrintTraceSummary(ag.FilesReviewed(), int64(len(comments)), ag.TotalInputTokens(), ag.TotalOutputTokens(), ag.TotalTokensUsed(), ag.TotalCacheReadTokens(), ag.TotalCacheWriteTokens(), duration) } if opts.outputFormat == "json" { - return outputJSONWithWarnings(comments, ag.Warnings(), ag.FilesReviewed(), ag.TotalInputTokens(), ag.TotalOutputTokens(), ag.TotalTokensUsed(), duration) + return outputJSONWithWarnings(comments, ag.Warnings(), ag.FilesReviewed(), ag.TotalInputTokens(), ag.TotalOutputTokens(), ag.TotalTokensUsed(), ag.TotalCacheReadTokens(), ag.TotalCacheWriteTokens(), duration) } if opts.audience == "agent" { outputTextWithWarnings(comments, ag.Warnings()) diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 276ad90..8737b99 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -158,20 +158,21 @@ type compressionJob struct { // Agent orchestrates the AI-powered code review. type Agent struct { - args Args - diffs []model.Diff // parsed diffs - totalInsertions int64 - totalDeletions int64 - currentDate string - session *session.SessionHistory - totalTokensUsed int64 // accumulated total tokens from all LLM calls, accessed atomically - totalInputTokens int64 // accumulated input/prompt tokens, accessed atomically - totalOutputTokens int64 // accumulated completion tokens, accessed atomically - subtaskFailed int64 // count of failed subtasks, accessed atomically - warningsMu sync.Mutex - warnings []AgentWarning - compressionMu sync.Mutex - pendingJob *compressionJob + args Args + diffs []model.Diff // parsed diffs + totalInsertions int64 + totalDeletions int64 + currentDate string + session *session.SessionHistory + totalInputTokens int64 // accumulated input/prompt tokens, accessed atomically + totalOutputTokens int64 // accumulated completion tokens, accessed atomically + totalCacheReadTokens int64 // accumulated cache read tokens, accessed atomically + totalCacheWriteTokens int64 // accumulated cache write tokens, accessed atomically + subtaskFailed int64 // count of failed subtasks, accessed atomically + warningsMu sync.Mutex + warnings []AgentWarning + compressionMu sync.Mutex + pendingJob *compressionJob } // CommentWorkerPool manages a fixed-size pool of workers dedicated to @@ -311,9 +312,10 @@ func (a *Agent) Diffs() []model.Diff { return a.diffs } -// TotalTokensUsed returns the accumulated total tokens from all LLM calls. +// TotalTokensUsed returns PromptTokens + CompletionTokens across all LLM calls. +// For Anthropic, PromptTokens already includes cache read/write tokens. func (a *Agent) TotalTokensUsed() int64 { - return atomic.LoadInt64(&a.totalTokensUsed) + return atomic.LoadInt64(&a.totalInputTokens) + atomic.LoadInt64(&a.totalOutputTokens) } // TotalInputTokens returns the accumulated input/prompt tokens from all LLM calls. @@ -326,6 +328,16 @@ func (a *Agent) TotalOutputTokens() int64 { return atomic.LoadInt64(&a.totalOutputTokens) } +// TotalCacheReadTokens returns the accumulated cache read tokens from all LLM calls. +func (a *Agent) TotalCacheReadTokens() int64 { + return atomic.LoadInt64(&a.totalCacheReadTokens) +} + +// TotalCacheWriteTokens returns the accumulated cache write tokens from all LLM calls. +func (a *Agent) TotalCacheWriteTokens() int64 { + return atomic.LoadInt64(&a.totalCacheWriteTokens) +} + // Warnings returns a copy of non-fatal warnings recorded during review. func (a *Agent) Warnings() []AgentWarning { a.warningsMu.Lock() @@ -703,9 +715,10 @@ func (a *Agent) executePlanPhase(ctx context.Context, newPath, rawDiff, changeFi } rec.SetResponse(resp, time.Since(startTime)) if resp.Usage != nil { - atomic.AddInt64(&a.totalTokensUsed, int64(resp.Usage.TotalTokens)) - atomic.AddInt64(&a.totalInputTokens, int64(resp.Usage.PromptTokens+resp.Usage.CacheReadTokens)) - atomic.AddInt64(&a.totalOutputTokens, int64(resp.Usage.CompletionTokens+resp.Usage.CacheWriteTokens)) + atomic.AddInt64(&a.totalInputTokens, resp.Usage.PromptTokens) + atomic.AddInt64(&a.totalOutputTokens, resp.Usage.CompletionTokens) + atomic.AddInt64(&a.totalCacheReadTokens, resp.Usage.CacheReadTokens) + atomic.AddInt64(&a.totalCacheWriteTokens, resp.Usage.CacheWriteTokens) } fmt.Fprintf(stdout.Writer(), "[ocr] Plan completed for %s\n", newPath) return resp.Content(), nil @@ -787,11 +800,12 @@ func (a *Agent) performLlmCodeReview(ctx context.Context, messages []llm.Message totalTokens := int64(0) if resp.Usage != nil { totalTokens = resp.Usage.TotalTokens - atomic.AddInt64(&a.totalInputTokens, int64(resp.Usage.PromptTokens+resp.Usage.CacheReadTokens)) - atomic.AddInt64(&a.totalOutputTokens, int64(resp.Usage.CompletionTokens+resp.Usage.CacheWriteTokens)) + atomic.AddInt64(&a.totalInputTokens, resp.Usage.PromptTokens) + atomic.AddInt64(&a.totalOutputTokens, resp.Usage.CompletionTokens) + atomic.AddInt64(&a.totalCacheReadTokens, resp.Usage.CacheReadTokens) + atomic.AddInt64(&a.totalCacheWriteTokens, resp.Usage.CacheWriteTokens) } telemetry.RecordLLMRequest(ctx, a.args.Model, duration, totalTokens, "ok") - atomic.AddInt64(&a.totalTokensUsed, totalTokens) content := resp.Content() calls := resp.ToolCalls() @@ -920,9 +934,10 @@ func (a *Agent) executeToolCall(ctx context.Context, newPath string, call llm.To if resp != nil { rlRec.SetResponse(resp, time.Since(rlStart)) if resp.Usage != nil { - atomic.AddInt64(&a.totalTokensUsed, int64(resp.Usage.TotalTokens)) - atomic.AddInt64(&a.totalInputTokens, int64(resp.Usage.PromptTokens+resp.Usage.CacheReadTokens)) - atomic.AddInt64(&a.totalOutputTokens, int64(resp.Usage.CompletionTokens+resp.Usage.CacheWriteTokens)) + atomic.AddInt64(&a.totalInputTokens, resp.Usage.PromptTokens) + atomic.AddInt64(&a.totalOutputTokens, resp.Usage.CompletionTokens) + atomic.AddInt64(&a.totalCacheReadTokens, resp.Usage.CacheReadTokens) + atomic.AddInt64(&a.totalCacheWriteTokens, resp.Usage.CacheWriteTokens) } } else { rlRec.SetError(fmt.Errorf("re-location LLM call failed"), time.Since(rlStart)) @@ -1190,9 +1205,10 @@ func (a *Agent) runCompression(ctx context.Context, msgs []llm.Message, filePath } rec.SetResponse(resp, duration) if resp.Usage != nil { - atomic.AddInt64(&a.totalTokensUsed, int64(resp.Usage.TotalTokens)) - atomic.AddInt64(&a.totalInputTokens, int64(resp.Usage.PromptTokens+resp.Usage.CacheReadTokens)) - atomic.AddInt64(&a.totalOutputTokens, int64(resp.Usage.CompletionTokens+resp.Usage.CacheWriteTokens)) + atomic.AddInt64(&a.totalInputTokens, resp.Usage.PromptTokens) + atomic.AddInt64(&a.totalOutputTokens, resp.Usage.CompletionTokens) + atomic.AddInt64(&a.totalCacheReadTokens, resp.Usage.CacheReadTokens) + atomic.AddInt64(&a.totalCacheWriteTokens, resp.Usage.CacheWriteTokens) } rawSummary := stripMarkdownFences(resp.Content()) diff --git a/internal/llm/client.go b/internal/llm/client.go index dc70e37..914271c 100644 --- a/internal/llm/client.go +++ b/internal/llm/client.go @@ -612,9 +612,11 @@ func (c *AnthropicClient) buildAnthropicParams(model string, req ChatRequest) an } if len(systemBlocks) > 0 { + systemBlocks[len(systemBlocks)-1].CacheControl = anthropic.NewCacheControlEphemeralParam() params.System = systemBlocks } if len(tools) > 0 { + tools[len(tools)-1].OfTool.CacheControl = anthropic.NewCacheControlEphemeralParam() params.Tools = tools } if req.Temperature != nil { diff --git a/internal/llm/client_test.go b/internal/llm/client_test.go index e3adc81..6661ac4 100644 --- a/internal/llm/client_test.go +++ b/internal/llm/client_test.go @@ -2,6 +2,8 @@ package llm import ( "testing" + + anthropic "github.com/anthropics/anthropic-sdk-go" ) func TestNewOpenAIClient_URLNormalization(t *testing.T) { @@ -89,3 +91,119 @@ func TestNewAnthropicClient_URLNormalization(t *testing.T) { }) } } + +func TestBuildAnthropicParams_CacheControl(t *testing.T) { + client := NewAnthropicClient(ClientConfig{URL: "https://api.anthropic.com"}) + + req := ChatRequest{ + Messages: []Message{ + {Role: "system", Content: "You are a code reviewer."}, + {Role: "system", Content: "Be concise."}, + {Role: "user", Content: "Review this code."}, + }, + Tools: []ToolDef{ + {Type: "function", Function: FunctionDef{Name: "tool_a", Description: "first tool", Parameters: map[string]any{"type": "object"}}}, + {Type: "function", Function: FunctionDef{Name: "tool_b", Description: "second tool", Parameters: map[string]any{"type": "object"}}}, + }, + } + + params := client.buildAnthropicParams("claude-sonnet-4-20250514", req) + + t.Run("last system block has cache control", func(t *testing.T) { + if len(params.System) < 2 { + t.Fatalf("expected at least 2 system blocks, got %d", len(params.System)) + } + last := params.System[len(params.System)-1] + if last.CacheControl.Type != "ephemeral" { + t.Errorf("last system block CacheControl.Type = %q, want %q", last.CacheControl.Type, "ephemeral") + } + }) + + t.Run("non-last system block has no cache control", func(t *testing.T) { + first := params.System[0] + if first.CacheControl.Type != "" { + t.Errorf("first system block CacheControl.Type = %q, want empty", first.CacheControl.Type) + } + }) + + t.Run("last tool has cache control", func(t *testing.T) { + if len(params.Tools) < 2 { + t.Fatalf("expected at least 2 tools, got %d", len(params.Tools)) + } + last := params.Tools[len(params.Tools)-1] + if last.OfTool == nil { + t.Fatal("last tool OfTool is nil") + } + if last.OfTool.CacheControl.Type != "ephemeral" { + t.Errorf("last tool CacheControl.Type = %q, want %q", last.OfTool.CacheControl.Type, "ephemeral") + } + }) + + t.Run("non-last tool has no cache control", func(t *testing.T) { + first := params.Tools[0] + if first.OfTool == nil { + t.Fatal("first tool OfTool is nil") + } + if first.OfTool.CacheControl.Type != "" { + t.Errorf("first tool CacheControl.Type = %q, want empty", first.OfTool.CacheControl.Type) + } + }) + + t.Run("top-level CacheControl is not set", func(t *testing.T) { + if params.CacheControl.Type != "" { + t.Errorf("params.CacheControl.Type = %q, want empty", params.CacheControl.Type) + } + }) +} + +func TestBuildAnthropicParams_CacheControl_NoTools(t *testing.T) { + client := NewAnthropicClient(ClientConfig{URL: "https://api.anthropic.com"}) + + req := ChatRequest{ + Messages: []Message{ + {Role: "system", Content: "You are a planner."}, + {Role: "user", Content: "Plan the review."}, + }, + } + + params := client.buildAnthropicParams("claude-sonnet-4-20250514", req) + + if len(params.System) == 0 { + t.Fatal("expected system blocks") + } + last := params.System[len(params.System)-1] + if last.CacheControl.Type != "ephemeral" { + t.Errorf("system CacheControl.Type = %q, want %q", last.CacheControl.Type, "ephemeral") + } + if len(params.Tools) != 0 { + t.Errorf("expected no tools, got %d", len(params.Tools)) + } +} + +func TestBuildAnthropicParams_CacheControl_NoSystem(t *testing.T) { + client := NewAnthropicClient(ClientConfig{URL: "https://api.anthropic.com"}) + + req := ChatRequest{ + Messages: []Message{ + {Role: "user", Content: "Hello"}, + }, + Tools: []ToolDef{ + {Type: "function", Function: FunctionDef{Name: "tool_a", Description: "a tool", Parameters: map[string]any{"type": "object"}}}, + }, + } + + params := client.buildAnthropicParams("claude-sonnet-4-20250514", req) + + if len(params.System) != 0 { + t.Errorf("expected no system blocks, got %d", len(params.System)) + } + if len(params.Tools) == 0 { + t.Fatal("expected tools") + } + if params.Tools[0].OfTool.CacheControl.Type != "ephemeral" { + t.Errorf("tool CacheControl.Type = %q, want %q", params.Tools[0].OfTool.CacheControl.Type, "ephemeral") + } +} + +// Verify the SDK constant is accessible (compile-time check). +var _ anthropic.CacheControlEphemeralParam = anthropic.NewCacheControlEphemeralParam() diff --git a/internal/telemetry/events.go b/internal/telemetry/events.go index 854d958..9730c29 100644 --- a/internal/telemetry/events.go +++ b/internal/telemetry/events.go @@ -66,11 +66,15 @@ func FormatDuration(dur time.Duration) string { } // PrintTraceSummary prints a one-line summary of the review to stdout. -func PrintTraceSummary(filesReviewed, commentsGenerated int64, inputTokens, outputTokens, totalTokens int64, duration time.Duration) { +func PrintTraceSummary(filesReviewed, commentsGenerated int64, inputTokens, outputTokens, totalTokens int64, cacheReadTokens, cacheWriteTokens int64, duration time.Duration) { elapsed := duration.Round(time.Second).String() if inputTokens > 0 || outputTokens > 0 { - fmt.Fprintf(stdout.Writer(), "[ocr] Summary: %d file(s) reviewed, %d comment(s), ~%d token(s) used (input: ~%d, output: ~%d), %s elapsed\n", - filesReviewed, commentsGenerated, totalTokens, inputTokens, outputTokens, elapsed) + base := fmt.Sprintf("[ocr] Summary: %d file(s) reviewed, %d comment(s), ~%d token(s) used (input: ~%d, output: ~%d)", + filesReviewed, commentsGenerated, totalTokens, inputTokens, outputTokens) + if cacheReadTokens > 0 || cacheWriteTokens > 0 { + base += fmt.Sprintf(", cache(read: ~%d, write: ~%d)", cacheReadTokens, cacheWriteTokens) + } + fmt.Fprintf(stdout.Writer(), "%s, %s elapsed\n", base, elapsed) } else { fmt.Fprintf(stdout.Writer(), "[ocr] Summary: %d file(s) reviewed, %d comment(s), ~%d token(s) used, %s elapsed\n", filesReviewed, commentsGenerated, totalTokens, elapsed)