feat: image output token billing, channel-mapped billing source, credits balance precheck
- Parse candidatesTokensDetails from Gemini API to separate image/text output tokens
- Add image_output_tokens and image_output_cost to usage_log (migration 089)
- Support per-image-token pricing via output_cost_per_image_token from model pricing data
- Channel pricing ImageOutputPrice override works in token billing mode
- Auto-fill image_output_price in channel pricing form from model defaults
- Add "channel_mapped" billing model source as new default (migration 088)
- Bills by model name after channel mapping, before account mapping
- Fix channel cache error TTL sign error (115s → 5s)
- Fix Update channel only invalidating new groups, not removed groups
- Fix frontend model_mapping clearing sending undefined instead of {}
- Credits balance precheck via shared AccountUsageService cache before injection
- Skip credits injection for accounts with insufficient balance
- Don't mark credits exhausted for "exhausted your capacity on this model" 429s
This commit is contained in:
@ -139,11 +139,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
|
||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
|
||||
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache, accountUsageService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||
|
||||
@ -31,7 +31,7 @@ type createChannelRequest struct {
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
}
|
||||
|
||||
@ -42,7 +42,7 @@ type updateChannelRequest struct {
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||
RestrictModels *bool `json:"restrict_models"`
|
||||
}
|
||||
|
||||
@ -129,7 +129,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
}
|
||||
resp.BillingModelSource = ch.BillingModelSource
|
||||
if resp.BillingModelSource == "" {
|
||||
resp.BillingModelSource = "requested"
|
||||
resp.BillingModelSource = "channel_mapped"
|
||||
}
|
||||
if resp.GroupIDs == nil {
|
||||
resp.GroupIDs = []int64{}
|
||||
@ -388,10 +388,11 @@ func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"found": true,
|
||||
"input_price": pricing.InputPricePerToken,
|
||||
"output_price": pricing.OutputPricePerToken,
|
||||
"cache_write_price": pricing.CacheCreationPricePerToken,
|
||||
"cache_read_price": pricing.CacheReadPricePerToken,
|
||||
"found": true,
|
||||
"input_price": pricing.InputPricePerToken,
|
||||
"output_price": pricing.OutputPricePerToken,
|
||||
"cache_write_price": pricing.CacheCreationPricePerToken,
|
||||
"cache_read_price": pricing.CacheReadPricePerToken,
|
||||
"image_output_price": pricing.ImageOutputPricePerToken,
|
||||
})
|
||||
}
|
||||
|
||||
@ -36,7 +36,7 @@ func TestChannelToResponse_FullChannel(t *testing.T) {
|
||||
RestrictModels: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now.Add(time.Hour),
|
||||
GroupIDs: []int64{1, 2, 3},
|
||||
GroupIDs: []int64{1, 2, 3},
|
||||
ModelPricing: []service.ChannelModelPricing{
|
||||
{
|
||||
ID: 10,
|
||||
@ -94,8 +94,8 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
|
||||
BillingModelSource: "",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
GroupIDs: nil,
|
||||
ModelMapping: nil,
|
||||
GroupIDs: nil,
|
||||
ModelMapping: nil,
|
||||
ModelPricing: []service.ChannelModelPricing{
|
||||
{
|
||||
Platform: "",
|
||||
@ -106,7 +106,7 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
|
||||
}
|
||||
|
||||
resp := channelToResponse(ch)
|
||||
require.Equal(t, "requested", resp.BillingModelSource)
|
||||
require.Equal(t, "channel_mapped", resp.BillingModelSource)
|
||||
require.NotNil(t, resp.GroupIDs)
|
||||
require.Empty(t, resp.GroupIDs)
|
||||
require.NotNil(t, resp.ModelMapping)
|
||||
|
||||
@ -125,6 +125,7 @@ type ClaudeUsage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
|
||||
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeError Claude 错误响应
|
||||
|
||||
@ -149,13 +149,31 @@ type GeminiCandidate struct {
|
||||
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiTokenDetail is a per-modality token count entry (e.g. TEXT, IMAGE)
// as reported by the Gemini API.
type GeminiTokenDetail struct {
	Modality   string `json:"modality"`
	TokenCount int    `json:"tokenCount"`
}

// GeminiUsageMetadata is the usage-metadata block of a Gemini response.
type GeminiUsageMetadata struct {
	PromptTokenCount        int                 `json:"promptTokenCount,omitempty"`
	CandidatesTokenCount    int                 `json:"candidatesTokenCount,omitempty"`
	CachedContentTokenCount int                 `json:"cachedContentTokenCount,omitempty"`
	TotalTokenCount         int                 `json:"totalTokenCount,omitempty"`
	ThoughtsTokenCount      int                 `json:"thoughtsTokenCount,omitempty"` // thinking tokens, billed at the output price
	CandidatesTokensDetails []GeminiTokenDetail `json:"candidatesTokensDetails,omitempty"`
	PromptTokensDetails     []GeminiTokenDetail `json:"promptTokensDetails,omitempty"`
}

// ImageOutputTokens returns the number of IMAGE-modality output tokens
// reported in CandidatesTokensDetails, or 0 when no IMAGE entry is present.
//
// Defensive behavior:
//   - Safe on a nil receiver (returns 0), so callers need not guard a
//     missing usageMetadata themselves.
//   - Sums every IMAGE entry rather than returning only the first one; the
//     API normally reports one entry per modality, in which case the result
//     is identical, but duplicated entries are no longer silently dropped.
func (m *GeminiUsageMetadata) ImageOutputTokens() int {
	if m == nil {
		return 0
	}
	total := 0
	for _, d := range m.CandidatesTokensDetails {
		if d.Modality == "IMAGE" {
			total += d.TokenCount
		}
	}
	return total
}
|
||||
|
||||
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
|
||||
|
||||
@ -284,6 +284,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
|
||||
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
|
||||
usage.CacheReadInputTokens = cached
|
||||
usage.ImageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
|
||||
}
|
||||
|
||||
// 生成响应 ID
|
||||
|
||||
@ -32,9 +32,10 @@ type StreamingProcessor struct {
|
||||
groundingChunks []GeminiGroundingChunk
|
||||
|
||||
// 累计 usage
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
cacheReadTokens int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
cacheReadTokens int
|
||||
imageOutputTokens int
|
||||
}
|
||||
|
||||
// NewStreamingProcessor 创建流式响应处理器
|
||||
@ -45,6 +46,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
// The hook is stored on the processor; passing nil clears a previously installed hook.
// NOTE(review): the invocation site is outside this view — presumably applied when
// usage maps are serialized into SSE events; confirm against the emit path.
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
	p.usageMapHook = fn
}
|
||||
|
||||
func usageToMap(u ClaudeUsage) map[string]any {
|
||||
m := map[string]any{
|
||||
"input_tokens": u.InputTokens,
|
||||
"output_tokens": u.OutputTokens,
|
||||
}
|
||||
if u.CacheCreationInputTokens > 0 {
|
||||
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
|
||||
}
|
||||
if u.CacheReadInputTokens > 0 {
|
||||
m["cache_read_input_tokens"] = u.CacheReadInputTokens
|
||||
}
|
||||
if u.ImageOutputTokens > 0 {
|
||||
m["image_output_tokens"] = u.ImageOutputTokens
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
||||
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
line = strings.TrimSpace(line)
|
||||
@ -87,6 +110,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
|
||||
p.cacheReadTokens = cached
|
||||
p.imageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
|
||||
}
|
||||
|
||||
// 处理 parts
|
||||
@ -127,6 +151,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
ImageOutputTokens: p.imageOutputTokens,
|
||||
}
|
||||
|
||||
if !p.messageStartSent {
|
||||
@ -158,6 +183,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
||||
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
|
||||
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount
|
||||
usage.CacheReadInputTokens = cached
|
||||
usage.ImageOutputTokens = v1Resp.Response.UsageMetadata.ImageOutputTokens()
|
||||
}
|
||||
|
||||
responseID := v1Resp.ResponseID
|
||||
@ -485,6 +511,7 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
ImageOutputTokens: p.imageOutputTokens,
|
||||
}
|
||||
|
||||
deltaEvent := map[string]any{
|
||||
|
||||
@ -97,7 +97,7 @@ func TestUnmarshalModelMapping(t *testing.T) {
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "valid JSON",
|
||||
name: "valid JSON",
|
||||
input: []byte(`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`),
|
||||
want: map[string]map[string]string{
|
||||
"openai": {"gpt-4": "gpt-4-turbo"},
|
||||
|
||||
@ -28,7 +28,7 @@ import (
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
|
||||
|
||||
// usageLogInsertArgTypes must stay in the same order as:
|
||||
// 1. prepareUsageLogInsert().args
|
||||
@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{
|
||||
"integer", // cache_read_tokens
|
||||
"integer", // cache_creation_5m_tokens
|
||||
"integer", // cache_creation_1h_tokens
|
||||
"integer", // image_output_tokens
|
||||
"numeric", // image_output_cost
|
||||
"numeric", // input_cost
|
||||
"numeric", // output_cost
|
||||
"numeric", // cache_creation_cost
|
||||
@ -330,6 +332,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@ -363,9 +367,9 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
$1, $2, $3, $4, $5, $6, $7,
|
||||
$8, $9,
|
||||
$10, $11, $12, $13,
|
||||
$14, $15,
|
||||
$16, $17, $18, $19, $20, $21,
|
||||
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44
|
||||
$14, $15, $16, $17,
|
||||
$18, $19, $20, $21, $22, $23,
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@ -766,6 +770,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@ -797,7 +803,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(keys)*45)
|
||||
args := make([]any, 0, len(keys)*47)
|
||||
argPos := 1
|
||||
for idx, key := range keys {
|
||||
if idx > 0 {
|
||||
@ -841,6 +847,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@ -887,6 +895,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@ -973,6 +983,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@ -1004,7 +1016,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(preparedList)*44)
|
||||
args := make([]any, 0, len(preparedList)*46)
|
||||
argPos := 1
|
||||
for idx, prepared := range preparedList {
|
||||
if idx > 0 {
|
||||
@ -1045,6 +1057,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@ -1091,6 +1105,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@ -1145,6 +1161,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
image_output_tokens,
|
||||
image_output_cost,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
@ -1178,9 +1196,9 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
$1, $2, $3, $4, $5, $6, $7,
|
||||
$8, $9,
|
||||
$10, $11, $12, $13,
|
||||
$14, $15,
|
||||
$16, $17, $18, $19, $20, $21,
|
||||
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44
|
||||
$14, $15, $16, $17,
|
||||
$18, $19, $20, $21, $22, $23,
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
`, prepared.args...)
|
||||
@ -1248,6 +1266,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
log.CacheReadTokens,
|
||||
log.CacheCreation5mTokens,
|
||||
log.CacheCreation1hTokens,
|
||||
log.ImageOutputTokens,
|
||||
log.ImageOutputCost,
|
||||
log.InputCost,
|
||||
log.OutputCost,
|
||||
log.CacheCreationCost,
|
||||
@ -4011,6 +4031,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
cacheReadTokens int
|
||||
cacheCreation5m int
|
||||
cacheCreation1h int
|
||||
imageOutputTokens int
|
||||
imageOutputCost float64
|
||||
inputCost float64
|
||||
outputCost float64
|
||||
cacheCreationCost float64
|
||||
@ -4059,6 +4081,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&cacheReadTokens,
|
||||
&cacheCreation5m,
|
||||
&cacheCreation1h,
|
||||
&imageOutputTokens,
|
||||
&imageOutputCost,
|
||||
&inputCost,
|
||||
&outputCost,
|
||||
&cacheCreationCost,
|
||||
@ -4105,6 +4129,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
CacheReadTokens: cacheReadTokens,
|
||||
CacheCreation5mTokens: cacheCreation5m,
|
||||
CacheCreation1hTokens: cacheCreation1h,
|
||||
ImageOutputTokens: imageOutputTokens,
|
||||
ImageOutputCost: imageOutputCost,
|
||||
InputCost: inputCost,
|
||||
OutputCost: outputCost,
|
||||
CacheCreationCost: cacheCreationCost,
|
||||
|
||||
@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
log.CacheReadTokens,
|
||||
log.CacheCreation5mTokens,
|
||||
log.CacheCreation1hTokens,
|
||||
log.ImageOutputTokens,
|
||||
log.ImageOutputCost,
|
||||
log.InputCost,
|
||||
log.OutputCost,
|
||||
log.CacheCreationCost,
|
||||
@ -133,6 +135,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
log.CacheReadTokens,
|
||||
log.CacheCreation5mTokens,
|
||||
log.CacheCreation1hTokens,
|
||||
log.ImageOutputTokens,
|
||||
log.ImageOutputCost,
|
||||
log.InputCost,
|
||||
log.OutputCost,
|
||||
log.CacheCreationCost,
|
||||
@ -447,6 +451,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
4, // cache_read_tokens
|
||||
5, // cache_creation_5m_tokens
|
||||
6, // cache_creation_1h_tokens
|
||||
0, // image_output_tokens
|
||||
0.0, // image_output_cost
|
||||
0.1, // input_cost
|
||||
0.2, // output_cost
|
||||
0.3, // cache_creation_cost
|
||||
@ -499,6 +505,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
1, 2, 3, 4, 5, 6,
|
||||
0, 0.0, // image_output_tokens, image_output_cost
|
||||
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
|
||||
1.0,
|
||||
sql.NullFloat64{},
|
||||
@ -546,6 +553,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
1, 2, 3, 4, 5, 6,
|
||||
0, 0.0, // image_output_tokens, image_output_cost
|
||||
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
|
||||
1.0,
|
||||
sql.NullFloat64{},
|
||||
|
||||
@ -846,6 +846,15 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// GetAntigravityCredits returns the account's AI Credits information, reusing
// the cache maintained by getAntigravityUsage: a cached entry with sufficient
// TTL is returned directly, otherwise the data is refreshed automatically.
//
// Returns (nil, nil) when the account is nil or not on the Antigravity
// platform — callers must treat a nil *UsageInfo as "not applicable",
// not as an error.
func (s *AccountUsageService) GetAntigravityCredits(ctx context.Context, account *Account) (*UsageInfo, error) {
	if account == nil || account.Platform != PlatformAntigravity {
		return nil, nil
	}
	return s.getAntigravityUsage(ctx, account)
}
|
||||
|
||||
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
|
||||
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
|
||||
func recalcAntigravityRemainingSeconds(info *UsageInfo) {
|
||||
|
||||
@ -19,6 +19,54 @@ const (
|
||||
creditsExhaustedDuration = 5 * time.Hour
|
||||
)
|
||||
|
||||
// checkAccountCredits reports whether the account has enough AI Credits,
// using the shared AccountUsageService cache (refreshed from the Google
// loadCodeAssist API when the cached entry's TTL is insufficient).
// Returns true when credits are considered available.
//
// Fail-open policy: a missing usage service or a lookup error returns true so
// requests are never blocked by infrastructure problems; only a definitive
// "no credits" answer returns false.
//
// NOTE(review): accessToken and proxyURL are not used in this body — the
// usage service apparently resolves credentials itself; confirm whether the
// parameters can be dropped at the call sites.
func (s *AntigravityGatewayService) checkAccountCredits(
	ctx context.Context, account *Account, accessToken, proxyURL string,
) bool {
	// An unsaved account (ID 0) cannot be looked up; treat as no credits.
	if account == nil || account.ID == 0 {
		return false
	}

	if s.accountUsageService == nil {
		return true // no usage service wired in: do not block the request
	}

	usageInfo, err := s.accountUsageService.GetAntigravityCredits(ctx, account)
	if err != nil {
		logger.LegacyPrintf("service.antigravity_gateway",
			"check_credits: get_credits_failed account=%d err=%v", account.ID, err)
		return true // fail open: assume credits exist rather than blocking
	}

	// No credits field at all counts as "no credits available".
	if usageInfo == nil || len(usageInfo.AICredits) == 0 {
		logger.LegacyPrintf("service.antigravity_gateway",
			"check_credits: account=%d has_credits=false amount=0 (no credits field)",
			account.ID)
		return false
	}

	// Only the GOOGLE_ONE_AI credit bucket is consulted; the first matching
	// entry decides the outcome.
	for _, credit := range usageInfo.AICredits {
		if credit.CreditType == "GOOGLE_ONE_AI" {
			minimum := credit.MinimumBalance
			if minimum <= 0 {
				minimum = 5 // default threshold when the API reports no minimum
			}
			hasCredits := credit.Amount >= minimum
			logger.LegacyPrintf("service.antigravity_gateway",
				"check_credits: account=%d has_credits=%t amount=%.0f minimum=%.0f",
				account.ID, hasCredits, credit.Amount, minimum)
			return hasCredits
		}
	}

	// Credits exist but none of type GOOGLE_ONE_AI: treat as unavailable.
	logger.LegacyPrintf("service.antigravity_gateway",
		"check_credits: account=%d has_credits=false (no GOOGLE_ONE_AI credit)",
		account.ID)
	return false
}
|
||||
|
||||
type antigravity429Category string
|
||||
|
||||
const (
|
||||
@ -141,6 +189,13 @@ func resolveCreditsOveragesModelKey(ctx context.Context, account *Account, upstr
|
||||
}
|
||||
|
||||
// shouldMarkCreditsExhausted 判断一次 credits 请求失败是否应标记为 credits 耗尽。
|
||||
// 此函数在积分注入后失败时调用(预检查注入 + attemptCreditsOveragesRetry 两条路径)。
|
||||
// - 429 + 非单模型限流:积分注入后仍 429 → 标记耗尽。
|
||||
// - 429 + 单模型限流("exhausted your capacity on this model"):该模型免费配额用完,
|
||||
// 积分注入对此无效,但账号积分对其他模型可能仍可用 → 不标记积分耗尽。
|
||||
// - 403 等其他 4xx:检查 body 是否包含积分不足的关键词。
|
||||
//
|
||||
// clearCreditsExhausted 会在后续成功时自动清除。
|
||||
func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool {
|
||||
if reqErr != nil || resp == nil {
|
||||
return false
|
||||
@ -148,13 +203,16 @@ func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr err
|
||||
if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout {
|
||||
return false
|
||||
}
|
||||
// 注意:不再检查 isURLLevelRateLimit。此函数仅在积分重试失败后调用,
|
||||
// 如果注入 enabledCreditTypes 后仍返回 "Resource has been exhausted",
|
||||
// 说明积分也已耗尽,应该标记。clearCreditsExhausted 会在后续成功时自动清除。
|
||||
if info := parseAntigravitySmartRetryInfo(respBody); info != nil {
|
||||
return false
|
||||
}
|
||||
bodyLower := strings.ToLower(string(respBody))
|
||||
// 积分注入后仍 429
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
// 单模型配额耗尽:积分注入对此无效,不标记整个账号积分耗尽
|
||||
if strings.Contains(bodyLower, "exhausted your capacity on this model") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
// 其他 4xx:关键词匹配(如 403 + "Insufficient credits")
|
||||
for _, keyword := range creditsExhaustedKeywords {
|
||||
if strings.Contains(bodyLower, keyword) {
|
||||
return true
|
||||
@ -181,6 +239,16 @@ func (s *AntigravityGatewayService) attemptCreditsOveragesRetry(
|
||||
if creditsBody == nil {
|
||||
return &creditsOveragesRetryResult{handled: false}
|
||||
}
|
||||
|
||||
// Check actual credits balance before attempting retry
|
||||
if !s.checkAccountCredits(p.ctx, p.account, p.accessToken, p.proxyURL) {
|
||||
s.setCreditsExhausted(p.ctx, p.account)
|
||||
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_no_credits model=%s account=%d (skipping credits retry)",
|
||||
p.prefix, modelKey, p.account.ID)
|
||||
return &creditsOveragesRetryResult{handled: true}
|
||||
}
|
||||
|
||||
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 credit_overages_retry model=%s account=%d (injecting enabledCreditTypes)",
|
||||
p.prefix, modelKey, p.account.ID)
|
||||
|
||||
@ -418,7 +418,13 @@ func TestShouldMarkCreditsExhausted(t *testing.T) {
|
||||
require.True(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||
})
|
||||
|
||||
t.Run("结构化限流不标记", func(t *testing.T) {
|
||||
t.Run("单模型配额耗尽不标记(积分对此无效)", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
|
||||
body := []byte(`{"error":{"code":429,"message":"You have exhausted your capacity on this model. Your quota will reset after 146h11m17s.","status":"RESOURCE_EXHAUSTED"}}`)
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||
})
|
||||
|
||||
t.Run("429 结构化限流也标记(积分注入后仍 429 即为耗尽)", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
|
||||
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`)
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||
|
||||
@ -557,7 +557,13 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
|
||||
if p.requestedModel != "" && p.account.Platform == PlatformAntigravity &&
|
||||
p.account.IsOveragesEnabled() && !p.account.isCreditsExhausted() &&
|
||||
p.account.isModelRateLimitedWithContext(p.ctx, p.requestedModel) {
|
||||
if creditsBody := injectEnabledCreditTypes(p.body); creditsBody != nil {
|
||||
// Check actual credits balance before injection
|
||||
if !s.checkAccountCredits(p.ctx, p.account, p.accessToken, p.proxyURL) {
|
||||
// No credits available - mark as exhausted and skip injection
|
||||
s.setCreditsExhausted(p.ctx, p.account)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: no_credits_available account=%d (skipping credits injection)",
|
||||
p.prefix, p.account.ID)
|
||||
} else if creditsBody := injectEnabledCreditTypes(p.body); creditsBody != nil {
|
||||
p.body = creditsBody
|
||||
overagesInjected = true
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: model_rate_limited_credits_inject model=%s account=%d (injecting enabledCreditTypes)",
|
||||
@ -870,14 +876,15 @@ func logPrefix(sessionID, accountName string) string {
|
||||
|
||||
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
|
||||
type AntigravityGatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
tokenProvider *AntigravityTokenProvider
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
settingService *SettingService
|
||||
cache GatewayCache // 用于模型级限流时清除粘性会话绑定
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
internal500Cache Internal500CounterCache // INTERNAL 500 渐进惩罚计数器
|
||||
accountRepo AccountRepository
|
||||
tokenProvider *AntigravityTokenProvider
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
settingService *SettingService
|
||||
cache GatewayCache // 用于模型级限流时清除粘性会话绑定
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
internal500Cache Internal500CounterCache // INTERNAL 500 渐进惩罚计数器
|
||||
accountUsageService *AccountUsageService // 共享 usage 缓存,用于积分余额检查
|
||||
}
|
||||
|
||||
func NewAntigravityGatewayService(
|
||||
@ -889,16 +896,18 @@ func NewAntigravityGatewayService(
|
||||
httpUpstream HTTPUpstream,
|
||||
settingService *SettingService,
|
||||
internal500Cache Internal500CounterCache,
|
||||
accountUsageService *AccountUsageService,
|
||||
) *AntigravityGatewayService {
|
||||
return &AntigravityGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
tokenProvider: tokenProvider,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
settingService: settingService,
|
||||
cache: cache,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
internal500Cache: internal500Cache,
|
||||
accountRepo: accountRepo,
|
||||
tokenProvider: tokenProvider,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
settingService: settingService,
|
||||
cache: cache,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
internal500Cache: internal500Cache,
|
||||
accountUsageService: accountUsageService,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -56,6 +56,7 @@ type ModelPricing struct {
|
||||
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
|
||||
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
|
||||
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
|
||||
ImageOutputPricePerToken float64 // 图片输出 token 价格 (USD)
|
||||
}
|
||||
|
||||
const (
|
||||
@ -94,12 +95,14 @@ type UsageTokens struct {
|
||||
CacheReadTokens int
|
||||
CacheCreation5mTokens int
|
||||
CacheCreation1hTokens int
|
||||
ImageOutputTokens int
|
||||
}
|
||||
|
||||
// CostBreakdown 费用明细
|
||||
type CostBreakdown struct {
|
||||
InputCost float64
|
||||
OutputCost float64
|
||||
ImageOutputCost float64
|
||||
CacheCreationCost float64
|
||||
CacheReadCost float64
|
||||
TotalCost float64
|
||||
@ -358,6 +361,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
|
||||
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
|
||||
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
|
||||
ImageOutputPricePerToken: litellmPricing.OutputCostPerImageToken,
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
@ -399,6 +403,9 @@ func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing
|
||||
pricing.CacheReadPricePerToken = *channelPricing.CacheReadPrice
|
||||
pricing.CacheReadPricePerTokenPriority = *channelPricing.CacheReadPrice
|
||||
}
|
||||
if channelPricing.ImageOutputPrice != nil {
|
||||
pricing.ImageOutputPricePerToken = *channelPricing.ImageOutputPrice
|
||||
}
|
||||
return pricing, nil
|
||||
}
|
||||
|
||||
@ -489,7 +496,22 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos
|
||||
}
|
||||
|
||||
breakdown.InputCost = float64(input.Tokens.InputTokens) * inputPricePerToken
|
||||
breakdown.OutputCost = float64(input.Tokens.OutputTokens) * outputPricePerToken
|
||||
|
||||
// Separate image output tokens from text output tokens
|
||||
textOutputTokens := input.Tokens.OutputTokens - input.Tokens.ImageOutputTokens
|
||||
if textOutputTokens < 0 {
|
||||
textOutputTokens = 0
|
||||
}
|
||||
breakdown.OutputCost = float64(textOutputTokens) * outputPricePerToken
|
||||
|
||||
// Image output tokens cost (separate rate from text output)
|
||||
if input.Tokens.ImageOutputTokens > 0 {
|
||||
imageOutputPrice := pricing.ImageOutputPricePerToken
|
||||
if imageOutputPrice == 0 {
|
||||
imageOutputPrice = outputPricePerToken // fallback to regular output price
|
||||
}
|
||||
breakdown.ImageOutputCost = float64(input.Tokens.ImageOutputTokens) * imageOutputPrice
|
||||
}
|
||||
|
||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||
if input.Tokens.CacheCreation5mTokens == 0 && input.Tokens.CacheCreation1hTokens == 0 && input.Tokens.CacheCreationTokens > 0 {
|
||||
@ -507,11 +529,12 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos
|
||||
if tierMultiplier != 1.0 {
|
||||
breakdown.InputCost *= tierMultiplier
|
||||
breakdown.OutputCost *= tierMultiplier
|
||||
breakdown.ImageOutputCost *= tierMultiplier
|
||||
breakdown.CacheCreationCost *= tierMultiplier
|
||||
breakdown.CacheReadCost *= tierMultiplier
|
||||
}
|
||||
|
||||
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
|
||||
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + breakdown.ImageOutputCost +
|
||||
breakdown.CacheCreationCost + breakdown.CacheReadCost
|
||||
breakdown.ActualCost = breakdown.TotalCost * input.RateMultiplier
|
||||
|
||||
@ -597,8 +620,21 @@ func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens,
|
||||
// 计算输入token费用(使用per-token价格)
|
||||
breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken
|
||||
|
||||
// 计算输出token费用
|
||||
breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken
|
||||
// 计算输出token费用(分离图片输出token)
|
||||
textOutputTokens := tokens.OutputTokens - tokens.ImageOutputTokens
|
||||
if textOutputTokens < 0 {
|
||||
textOutputTokens = 0
|
||||
}
|
||||
breakdown.OutputCost = float64(textOutputTokens) * outputPricePerToken
|
||||
|
||||
// 图片输出 token 费用
|
||||
if tokens.ImageOutputTokens > 0 {
|
||||
imageOutputPrice := pricing.ImageOutputPricePerToken
|
||||
if imageOutputPrice == 0 {
|
||||
imageOutputPrice = outputPricePerToken
|
||||
}
|
||||
breakdown.ImageOutputCost = float64(tokens.ImageOutputTokens) * imageOutputPrice
|
||||
}
|
||||
|
||||
// 计算缓存费用
|
||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||
@ -620,12 +656,13 @@ func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens,
|
||||
if tierMultiplier != 1.0 {
|
||||
breakdown.InputCost *= tierMultiplier
|
||||
breakdown.OutputCost *= tierMultiplier
|
||||
breakdown.ImageOutputCost *= tierMultiplier
|
||||
breakdown.CacheCreationCost *= tierMultiplier
|
||||
breakdown.CacheReadCost *= tierMultiplier
|
||||
}
|
||||
|
||||
// 计算总费用
|
||||
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
|
||||
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + breakdown.ImageOutputCost +
|
||||
breakdown.CacheCreationCost + breakdown.CacheReadCost
|
||||
|
||||
// 应用倍率计算实际费用
|
||||
@ -730,6 +767,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
|
||||
CacheReadTokens: inRangeCacheTokens,
|
||||
CacheCreation5mTokens: tokens.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: tokens.CacheCreation1hTokens,
|
||||
ImageOutputTokens: tokens.ImageOutputTokens,
|
||||
}
|
||||
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
|
||||
if err != nil {
|
||||
@ -750,6 +788,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
|
||||
return &CostBreakdown{
|
||||
InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
|
||||
OutputCost: inRangeCost.OutputCost,
|
||||
ImageOutputCost: inRangeCost.ImageOutputCost,
|
||||
CacheCreationCost: inRangeCost.CacheCreationCost,
|
||||
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
|
||||
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
|
||||
|
||||
@ -24,8 +24,9 @@ func (m BillingMode) IsValid() bool {
|
||||
}
|
||||
|
||||
const (
|
||||
BillingModelSourceRequested = "requested"
|
||||
BillingModelSourceUpstream = "upstream"
|
||||
BillingModelSourceRequested = "requested"
|
||||
BillingModelSourceUpstream = "upstream"
|
||||
BillingModelSourceChannelMapped = "channel_mapped"
|
||||
)
|
||||
|
||||
// Channel 渠道实体
|
||||
@ -34,7 +35,7 @@ type Channel struct {
|
||||
Name string
|
||||
Description string
|
||||
Status string
|
||||
BillingModelSource string // "requested" or "upstream"
|
||||
BillingModelSource string // "requested", "upstream", or "channel_mapped"
|
||||
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
@ -180,6 +181,7 @@ func (c *Channel) Clone() *Channel {
|
||||
type ChannelUsageFields struct {
|
||||
ChannelID int64 // 渠道 ID(0 = 无渠道)
|
||||
OriginalModel string // 用户原始请求模型(渠道映射前)
|
||||
BillingModelSource string // 计费模型来源:"requested" / "upstream"
|
||||
ChannelMappedModel string // 渠道映射后的模型名(无映射时等于 OriginalModel)
|
||||
BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped"
|
||||
ModelMappingChain string // 映射链描述,如 "a→b→c"
|
||||
}
|
||||
|
||||
@ -97,7 +97,7 @@ type ChannelMappingResult struct {
|
||||
MappedModel string // 映射后的模型名(无映射时等于原始模型名)
|
||||
ChannelID int64 // 渠道 ID(0 = 无渠道关联)
|
||||
Mapped bool // 是否发生了映射
|
||||
BillingModelSource string // 计费模型来源("requested" / "upstream")
|
||||
BillingModelSource string // 计费模型来源("requested" / "upstream" / "channel_mapped")
|
||||
}
|
||||
|
||||
// BuildModelMappingChain 根据映射结果和上游实际模型构建映射链描述。
|
||||
@ -119,9 +119,14 @@ func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel str
|
||||
|
||||
// ToUsageFields 将渠道映射结果转为使用记录字段
|
||||
func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields {
|
||||
channelMappedModel := reqModel
|
||||
if r.Mapped {
|
||||
channelMappedModel = r.MappedModel
|
||||
}
|
||||
return ChannelUsageFields{
|
||||
ChannelID: r.ChannelID,
|
||||
OriginalModel: reqModel,
|
||||
ChannelMappedModel: channelMappedModel,
|
||||
BillingModelSource: r.BillingModelSource,
|
||||
ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel),
|
||||
}
|
||||
@ -193,7 +198,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: make(map[int64]string),
|
||||
byID: make(map[int64]*Channel),
|
||||
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
||||
loadedAt: time.Now().Add(-(channelCacheTTL - channelErrorTTL)), // 使剩余 TTL = errorTTL
|
||||
}
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("list all channels: %w", err)
|
||||
@ -374,7 +379,7 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
|
||||
BillingModelSource: ch.BillingModelSource,
|
||||
}
|
||||
if result.BillingModelSource == "" {
|
||||
result.BillingModelSource = BillingModelSourceRequested
|
||||
result.BillingModelSource = BillingModelSourceChannelMapped
|
||||
}
|
||||
|
||||
platform := cache.groupPlatform[groupID]
|
||||
@ -481,7 +486,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
ModelMapping: input.ModelMapping,
|
||||
}
|
||||
if channel.BillingModelSource == "" {
|
||||
channel.BillingModelSource = BillingModelSourceRequested
|
||||
channel.BillingModelSource = BillingModelSourceChannelMapped
|
||||
}
|
||||
|
||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||
@ -565,20 +570,36 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 先获取旧分组,Update 后旧分组关联已删除,无法再查到
|
||||
var oldGroupIDs []int64
|
||||
if s.authCacheInvalidator != nil {
|
||||
var err2 error
|
||||
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
|
||||
if err2 != nil {
|
||||
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.repo.Update(ctx, channel); err != nil {
|
||||
return nil, fmt.Errorf("update channel: %w", err)
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
|
||||
// 失效关联分组的 auth 缓存
|
||||
// 失效新旧分组的 auth 缓存
|
||||
if s.authCacheInvalidator != nil {
|
||||
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get group IDs for cache invalidation", "channel_id", id, "error", err)
|
||||
seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
|
||||
for _, gid := range oldGroupIDs {
|
||||
if _, ok := seen[gid]; !ok {
|
||||
seen[gid] = struct{}{}
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
for _, gid := range groupIDs {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
for _, gid := range channel.GroupIDs {
|
||||
if _, ok := seen[gid]; !ok {
|
||||
seen[gid] = struct{}{}
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -16,24 +16,24 @@ import (
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type mockChannelRepository struct {
|
||||
listAllFn func(ctx context.Context) ([]Channel, error)
|
||||
getGroupPlatformsFn func(ctx context.Context, groupIDs []int64) (map[int64]string, error)
|
||||
createFn func(ctx context.Context, channel *Channel) error
|
||||
getByIDFn func(ctx context.Context, id int64) (*Channel, error)
|
||||
updateFn func(ctx context.Context, channel *Channel) error
|
||||
deleteFn func(ctx context.Context, id int64) error
|
||||
listFn func(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error)
|
||||
existsByNameFn func(ctx context.Context, name string) (bool, error)
|
||||
existsByNameExcludingFn func(ctx context.Context, name string, excludeID int64) (bool, error)
|
||||
getGroupIDsFn func(ctx context.Context, channelID int64) ([]int64, error)
|
||||
setGroupIDsFn func(ctx context.Context, channelID int64, groupIDs []int64) error
|
||||
getChannelIDByGroupIDFn func(ctx context.Context, groupID int64) (int64, error)
|
||||
listAllFn func(ctx context.Context) ([]Channel, error)
|
||||
getGroupPlatformsFn func(ctx context.Context, groupIDs []int64) (map[int64]string, error)
|
||||
createFn func(ctx context.Context, channel *Channel) error
|
||||
getByIDFn func(ctx context.Context, id int64) (*Channel, error)
|
||||
updateFn func(ctx context.Context, channel *Channel) error
|
||||
deleteFn func(ctx context.Context, id int64) error
|
||||
listFn func(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error)
|
||||
existsByNameFn func(ctx context.Context, name string) (bool, error)
|
||||
existsByNameExcludingFn func(ctx context.Context, name string, excludeID int64) (bool, error)
|
||||
getGroupIDsFn func(ctx context.Context, channelID int64) ([]int64, error)
|
||||
setGroupIDsFn func(ctx context.Context, channelID int64, groupIDs []int64) error
|
||||
getChannelIDByGroupIDFn func(ctx context.Context, groupID int64) (int64, error)
|
||||
getGroupsInOtherChannelsFn func(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error)
|
||||
listModelPricingFn func(ctx context.Context, channelID int64) ([]ChannelModelPricing, error)
|
||||
createModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error
|
||||
updateModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error
|
||||
deleteModelPricingFn func(ctx context.Context, id int64) error
|
||||
replaceModelPricingFn func(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
|
||||
listModelPricingFn func(ctx context.Context, channelID int64) ([]ChannelModelPricing, error)
|
||||
createModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error
|
||||
updateModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error
|
||||
deleteModelPricingFn func(ctx context.Context, id int64) error
|
||||
replaceModelPricingFn func(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
|
||||
}
|
||||
|
||||
func (m *mockChannelRepository) Create(ctx context.Context, channel *Channel) error {
|
||||
@ -196,7 +196,6 @@ func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChanne
|
||||
return NewChannelService(repo, auth)
|
||||
}
|
||||
|
||||
|
||||
// makeStandardRepo returns a repo that serves one active channel with anthropic pricing
|
||||
// for group 1, with the given model pricing and model mapping.
|
||||
func makeStandardRepo(ch Channel, groupPlatforms map[int64]string) *mockChannelRepository {
|
||||
@ -907,21 +906,21 @@ func TestResolveChannelMapping_DefaultBillingModelSource(t *testing.T) {
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
GroupIDs: []int64{10},
|
||||
BillingModelSource: "", // empty
|
||||
}
|
||||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||||
svc := newTestChannelService(repo)
|
||||
|
||||
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
|
||||
require.Equal(t, BillingModelSourceRequested, result.BillingModelSource)
|
||||
require.Equal(t, BillingModelSourceChannelMapped, result.BillingModelSource)
|
||||
}
|
||||
|
||||
func TestResolveChannelMapping_UpstreamBillingModelSource(t *testing.T) {
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
GroupIDs: []int64{10},
|
||||
BillingModelSource: BillingModelSourceUpstream,
|
||||
}
|
||||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||||
@ -957,7 +956,7 @@ func TestIsModelRestricted_NoChannel(t *testing.T) {
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
}
|
||||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||||
@ -972,7 +971,7 @@ func TestIsModelRestricted_RestrictDisabled(t *testing.T) {
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: false,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||||
@ -990,7 +989,7 @@ func TestIsModelRestricted_InactiveChannel(t *testing.T) {
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusDisabled,
|
||||
GroupIDs: []int64{10},
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
}
|
||||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||||
@ -1004,7 +1003,7 @@ func TestIsModelRestricted_ModelInPricing(t *testing.T) {
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-opus-4", "claude-sonnet-4"}},
|
||||
@ -1021,7 +1020,7 @@ func TestIsModelRestricted_ModelInWildcard(t *testing.T) {
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-*"}},
|
||||
@ -1038,7 +1037,7 @@ func TestIsModelRestricted_ModelNotFound(t *testing.T) {
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||||
@ -1055,7 +1054,7 @@ func TestIsModelRestricted_CaseInsensitive(t *testing.T) {
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-opus-4"}},
|
||||
@ -1088,7 +1087,7 @@ func TestResolveChannelMappingAndRestrict_ModelInPricing_WithMapping(t *testing.
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
|
||||
@ -1117,7 +1116,7 @@ func TestResolveChannelMappingAndRestrict_ModelNotInPricing_WithMapping(t *testi
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
|
||||
@ -1142,7 +1141,7 @@ func TestResolveChannelMappingAndRestrict_ModelNotInPricing_NoMapping(t *testing
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
|
||||
@ -1451,11 +1450,11 @@ func TestCreate_DefaultBillingModelSource(t *testing.T) {
|
||||
|
||||
result, err := svc.Create(context.Background(), &CreateChannelInput{
|
||||
Name: "new-channel",
|
||||
BillingModelSource: "", // empty, should default to "requested"
|
||||
BillingModelSource: "", // empty, should default to "channel_mapped"
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, BillingModelSourceRequested, result.BillingModelSource)
|
||||
require.Equal(t, BillingModelSourceChannelMapped, result.BillingModelSource)
|
||||
}
|
||||
|
||||
func TestCreate_InvalidatesCache(t *testing.T) {
|
||||
|
||||
@ -483,6 +483,7 @@ type ClaudeUsage struct {
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象)
|
||||
CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象)
|
||||
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ForwardResult 转发结果
|
||||
@ -7729,6 +7730,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
var cost *CostBreakdown
|
||||
// 确定计费模型
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
|
||||
billingModel = input.ChannelMappedModel
|
||||
}
|
||||
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
|
||||
billingModel = input.OriginalModel
|
||||
}
|
||||
@ -7777,6 +7781,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
}
|
||||
var err error
|
||||
if s.resolver != nil && apiKey.Group != nil {
|
||||
@ -7836,8 +7841,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
ImageOutputCost: cost.ImageOutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
@ -7976,6 +7983,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
var cost *CostBreakdown
|
||||
// 确定计费模型
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
|
||||
billingModel = input.ChannelMappedModel
|
||||
}
|
||||
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
|
||||
billingModel = input.OriginalModel
|
||||
}
|
||||
@ -8007,6 +8017,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
}
|
||||
var err error
|
||||
// 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用)
|
||||
@ -8073,8 +8084,10 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
ImageOutputCost: cost.ImageOutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
|
||||
@ -2692,12 +2692,27 @@ func extractGeminiUsage(data []byte) *ClaudeUsage {
|
||||
cand := int(usage.Get("candidatesTokenCount").Int())
|
||||
cached := int(usage.Get("cachedContentTokenCount").Int())
|
||||
thoughts := int(usage.Get("thoughtsTokenCount").Int())
|
||||
|
||||
// 从 candidatesTokensDetails 提取 IMAGE 模态 token 数
|
||||
imageTokens := 0
|
||||
candidateDetails := usage.Get("candidatesTokensDetails")
|
||||
if candidateDetails.Exists() {
|
||||
candidateDetails.ForEach(func(_, detail gjson.Result) bool {
|
||||
if detail.Get("modality").String() == "IMAGE" {
|
||||
imageTokens = int(detail.Get("tokenCount").Int())
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
||||
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
||||
return &ClaudeUsage{
|
||||
InputTokens: prompt - cached,
|
||||
OutputTokens: cand + thoughts,
|
||||
CacheReadInputTokens: cached,
|
||||
ImageOutputTokens: imageTokens,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -134,6 +134,9 @@ func (r *ModelPricingResolver) applyTokenOverrides(chPricing *ChannelModelPricin
|
||||
resolved.BasePricing.CacheReadPricePerToken = *chPricing.CacheReadPrice
|
||||
resolved.BasePricing.CacheReadPricePerTokenPriority = *chPricing.CacheReadPrice
|
||||
}
|
||||
if chPricing.ImageOutputPrice != nil {
|
||||
resolved.BasePricing.ImageOutputPricePerToken = *chPricing.ImageOutputPrice
|
||||
}
|
||||
}
|
||||
|
||||
// applyRequestTierOverrides 应用按次/图片模式的渠道覆盖
|
||||
|
||||
@ -204,6 +204,7 @@ type OpenAIUsage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
|
||||
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIForwardResult represents the result of forwarding
|
||||
@ -4177,6 +4178,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
}
|
||||
|
||||
// Get rate multiplier
|
||||
@ -4195,6 +4197,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
if result.BillingModel != "" {
|
||||
billingModel = strings.TrimSpace(result.BillingModel)
|
||||
}
|
||||
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
|
||||
billingModel = input.ChannelMappedModel
|
||||
}
|
||||
if input.BillingModelSource == "requested" && input.OriginalModel != "" {
|
||||
billingModel = input.OriginalModel
|
||||
}
|
||||
@ -4255,8 +4260,10 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
ImageOutputCost: cost.ImageOutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
|
||||
@ -70,7 +70,8 @@ type LiteLLMModelPricing struct {
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
|
||||
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
|
||||
OutputCostPerImageToken float64 `json:"output_cost_per_image_token"` // 图片输出 token 价格
|
||||
}
|
||||
|
||||
// PricingRemoteClient 远程价格数据获取接口
|
||||
@ -94,6 +95,7 @@ type LiteLLMRawEntry struct {
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
OutputCostPerImage *float64 `json:"output_cost_per_image"`
|
||||
OutputCostPerImageToken *float64 `json:"output_cost_per_image_token"`
|
||||
}
|
||||
|
||||
// PricingService 动态价格服务
|
||||
@ -408,6 +410,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
|
||||
if entry.OutputCostPerImage != nil {
|
||||
pricing.OutputCostPerImage = *entry.OutputCostPerImage
|
||||
}
|
||||
if entry.OutputCostPerImageToken != nil {
|
||||
pricing.OutputCostPerImageToken = *entry.OutputCostPerImageToken
|
||||
}
|
||||
|
||||
result[modelName] = pricing
|
||||
}
|
||||
|
||||
@ -134,6 +134,9 @@ type UsageLog struct {
|
||||
CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"`
|
||||
CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"`
|
||||
|
||||
ImageOutputTokens int
|
||||
ImageOutputCost float64
|
||||
|
||||
InputCost float64
|
||||
OutputCost float64
|
||||
CacheCreationCost float64
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
-- Change default billing_model_source for new channels to 'channel_mapped'
|
||||
-- Existing channels keep their current setting (no UPDATE on existing rows)
|
||||
ALTER TABLE channels ALTER COLUMN billing_model_source SET DEFAULT 'channel_mapped';
|
||||
2
backend/migrations/089_usage_log_image_output_tokens.sql
Normal file
2
backend/migrations/089_usage_log_image_output_tokens.sql
Normal file
@ -0,0 +1,2 @@
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS image_output_tokens INTEGER NOT NULL DEFAULT 0;
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS image_output_cost DECIMAL(20, 10) NOT NULL DEFAULT 0;
|
||||
@ -134,6 +134,7 @@ export interface ModelDefaultPricing {
|
||||
output_price?: number
|
||||
cache_write_price?: number
|
||||
cache_read_price?: number
|
||||
image_output_price?: number
|
||||
}
|
||||
|
||||
export async function getModelDefaultPricing(model: string): Promise<ModelDefaultPricing> {
|
||||
|
||||
@ -328,6 +328,7 @@ async function onModelsUpdate(newModels: string[]) {
|
||||
output_price: perTokenToMTok(result.output_price ?? null),
|
||||
cache_write_price: perTokenToMTok(result.cache_write_price ?? null),
|
||||
cache_read_price: perTokenToMTok(result.cache_read_price ?? null),
|
||||
image_output_price: perTokenToMTok(result.image_output_price ?? null),
|
||||
})
|
||||
}
|
||||
} catch {
|
||||
|
||||
@ -1800,6 +1800,7 @@ export default {
|
||||
mappingSource: 'Source model',
|
||||
mappingTarget: 'Target model',
|
||||
billingModelSource: 'Billing Model',
|
||||
billingModelSourceChannelMapped: 'Bill by channel-mapped model',
|
||||
billingModelSourceRequested: 'Bill by requested model',
|
||||
billingModelSourceUpstream: 'Bill by final upstream model',
|
||||
billingModelSourceHint: 'Controls which model name is used for pricing lookup',
|
||||
|
||||
@ -1880,6 +1880,7 @@ export default {
|
||||
mappingSource: '源模型',
|
||||
mappingTarget: '目标模型',
|
||||
billingModelSource: '计费基准',
|
||||
billingModelSourceChannelMapped: '以渠道映射后的模型计费',
|
||||
billingModelSourceRequested: '以请求模型计费',
|
||||
billingModelSourceUpstream: '以最终模型计费',
|
||||
billingModelSourceHint: '控制使用哪个模型名称进行定价查找',
|
||||
|
||||
@ -471,6 +471,7 @@ const statusEditOptions = computed(() => [
|
||||
])
|
||||
|
||||
const billingModelSourceOptions = computed(() => [
|
||||
{ value: 'channel_mapped', label: t('admin.channels.form.billingModelSourceChannelMapped', 'Bill by channel-mapped model') },
|
||||
{ value: 'requested', label: t('admin.channels.form.billingModelSourceRequested', 'Bill by requested model') },
|
||||
{ value: 'upstream', label: t('admin.channels.form.billingModelSourceUpstream', 'Bill by final upstream model') }
|
||||
])
|
||||
@ -504,7 +505,7 @@ const form = reactive({
|
||||
description: '',
|
||||
status: 'active',
|
||||
restrict_models: false,
|
||||
billing_model_source: 'requested' as string,
|
||||
billing_model_source: 'channel_mapped' as string,
|
||||
platforms: [] as PlatformSection[]
|
||||
})
|
||||
|
||||
@ -819,7 +820,7 @@ function resetForm() {
|
||||
form.description = ''
|
||||
form.status = 'active'
|
||||
form.restrict_models = false
|
||||
form.billing_model_source = 'requested'
|
||||
form.billing_model_source = 'channel_mapped'
|
||||
form.platforms = []
|
||||
activeTab.value = 'basic'
|
||||
}
|
||||
@ -837,7 +838,7 @@ async function openEditDialog(channel: Channel) {
|
||||
form.description = channel.description || ''
|
||||
form.status = channel.status
|
||||
form.restrict_models = channel.restrict_models || false
|
||||
form.billing_model_source = channel.billing_model_source || 'requested'
|
||||
form.billing_model_source = channel.billing_model_source || 'channel_mapped'
|
||||
// Must load groups first so apiToForm can map groupID → platform
|
||||
await loadGroups()
|
||||
form.platforms = apiToForm(channel)
|
||||
@ -932,7 +933,7 @@ async function handleSubmit() {
|
||||
status: form.status,
|
||||
group_ids,
|
||||
model_pricing,
|
||||
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : undefined,
|
||||
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
|
||||
billing_model_source: form.billing_model_source,
|
||||
restrict_models: form.restrict_models
|
||||
}
|
||||
@ -944,7 +945,7 @@ async function handleSubmit() {
|
||||
description: form.description.trim() || undefined,
|
||||
group_ids,
|
||||
model_pricing,
|
||||
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : undefined,
|
||||
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
|
||||
billing_model_source: form.billing_model_source,
|
||||
restrict_models: form.restrict_models
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user