Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 77 additions & 39 deletions internal/virtualmodels/batch_preparer.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,86 @@ func NewBatchPreparer(provider core.RoutableProvider, service *Service) *BatchPr
// PrepareBatchRequest rewrites redirect sources for inline and file-backed batch
// items and validates model access for each resolved selector.
func (p *BatchPreparer) PrepareBatchRequest(ctx context.Context, providerType string, req *core.BatchRequest) (*core.BatchRewriteResult, error) {
return rewriteBatchSource(ctx, providerType, req, p.service, p.provider, p.batchFileTransport(), p.validateAccess)
}

// validateAccess enforces the access policy for one resolved batch selector.
func (p *BatchPreparer) validateAccess(ctx context.Context, resolved core.ModelSelector) error {
if p.service == nil {
return nil
}
return p.service.ValidateModelAccess(ctx, resolved)
}

// batchFileTransport returns the provider's native file transport when it can
// rewrite file-backed batch requests directly.
func (p *BatchPreparer) batchFileTransport() core.BatchFileTransport {
if p == nil || p.provider == nil {
return nil
}
if files, ok := p.provider.(core.NativeFileRoutableProvider); ok {
return files
}
return nil
}

// rewriteBatchSource resolves redirects for inline and file-backed batch items
// and rewrites each for upstream submission. validate, when non-nil, is called
// with the resolved selector before rewriting — the server-side preparer enforces
// access there; the provider wrapper passes nil.
func rewriteBatchSource(
ctx context.Context,
providerType string,
req *core.BatchRequest,
service *Service,
checker modelSupportChecker,
fileTransport core.BatchFileTransport,
validate func(context.Context, core.ModelSelector) error,
) (*core.BatchRewriteResult, error) {
return core.RewriteBatchSource(
ctx,
providerType,
req,
p.batchFileTransport(),
fileTransport,
[]core.Operation{core.OperationChatCompletions, core.OperationResponses, core.OperationEmbeddings},
func(ctx context.Context, _ core.BatchRequestItem, decoded *core.DecodedBatchItemRequest) (json.RawMessage, error) {
requested, err := requestedSelectorForDecodedRequest(decoded.Request)
if err != nil {
return nil, err
}
// Resolve the redirect target and verify catalog support + single
// provider per batch, mirroring the alias rewrite pass.
resolved, err := resolveRedirectRoutableSelector(ctx, p.service, p.provider, requested, providerType)
if err != nil {
return nil, err
}
// Validate access against the resolved selector, mirroring the
// access-override pass.
if p.service != nil {
if err := p.service.ValidateModelAccess(ctx, resolved); err != nil {
return nil, err
}
}
return rewriteDecodedBatchItem(decoded.Request, resolved)
return rewriteBatchItem(ctx, service, checker, providerType, decoded, validate)
},
)
}

// rewriteBatchItem resolves one decoded batch item's redirect (verifying catalog
// support and single-provider-per-batch), optionally validates access, then
// re-encodes the item for upstream. It is the single per-item rewrite shared by
// the provider wrapper and the server-side preparer.
func rewriteBatchItem(
ctx context.Context,
service *Service,
checker modelSupportChecker,
providerType string,
decoded *core.DecodedBatchItemRequest,
validate func(context.Context, core.ModelSelector) error,
) (json.RawMessage, error) {
requested, err := decoded.RequestedModelSelector()
if err != nil {
return nil, core.NewInvalidRequestError(err.Error(), err)
}
// resolveRedirectRoutableSelector is user-path aware (scoped redirects), so a
// caller outside a scoped alias's user_paths gets the literal name here too.
resolved, err := resolveRedirectRoutableSelector(ctx, service, checker, requested, providerType)
if err != nil {
return nil, err
}
if validate != nil {
if err := validate(ctx, resolved); err != nil {
return nil, err
}
}
return rewriteDecodedBatchItem(decoded.Request, resolved)
Comment thread
SantiagoDePolonia marked this conversation as resolved.
}

// rewriteDecodedBatchItem writes the resolved model into a supported decoded
// batch request and clears the provider before upstream submission.
func rewriteDecodedBatchItem(request any, resolved core.ModelSelector) (json.RawMessage, error) {
switch typed := request.(type) {
case *core.ChatRequest:
Expand All @@ -76,25 +127,12 @@ func rewriteDecodedBatchItem(request any, resolved core.ModelSelector) (json.Raw
}
}

func requestedSelectorForDecodedRequest(request any) (core.RequestedModelSelector, error) {
switch typed := request.(type) {
case *core.ChatRequest:
return core.NewRequestedModelSelector(typed.Model, typed.Provider), nil
case *core.ResponsesRequest:
return core.NewRequestedModelSelector(typed.Model, typed.Provider), nil
case *core.EmbeddingRequest:
return core.NewRequestedModelSelector(typed.Model, typed.Provider), nil
default:
return core.RequestedModelSelector{}, core.NewInvalidRequestError("unsupported batch item request", nil)
}
}

func (p *BatchPreparer) batchFileTransport() core.BatchFileTransport {
if p == nil || p.provider == nil {
return nil
// marshalBatchItem encodes a rewritten batch item as JSON for the upstream
// provider payload.
func marshalBatchItem(v any) (json.RawMessage, error) {
body, err := json.Marshal(v)
if err != nil {
return nil, core.NewInvalidRequestError("failed to encode batch item", err)
}
if files, ok := p.provider.(core.NativeFileRoutableProvider); ok {
return files
}
return nil
return body, nil
}
154 changes: 154 additions & 0 deletions internal/virtualmodels/batch_preparer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package virtualmodels

import (
"context"
"encoding/json"
"errors"
"strings"
"testing"

"gomodel/internal/core"
)

// decodedChatItem builds a decoded chat batch request for per-item rewrite tests.
func decodedChatItem(model, provider string) *core.DecodedBatchItemRequest {
return &core.DecodedBatchItemRequest{
Endpoint: "/v1/chat/completions",
Request: &core.ChatRequest{Model: model, Provider: provider},
}
}

// newRedirectService creates a service with the "fast" redirect used by batch
// rewrite tests.
func newRedirectService(t *testing.T) *Service {
t.Helper()
svc := newTestService(t)
if err := svc.Upsert(context.Background(), VirtualModel{
Source: "fast",
Targets: []Target{{Provider: "openai", Model: "gpt-4o"}},
Enabled: true,
}); err != nil {
t.Fatalf("Upsert(redirect) error = %v", err)
}
return svc
}

// requireGatewayError asserts the gateway error contract while returning the
// typed error for any additional test-specific checks.
func requireGatewayError(t *testing.T, err error, wantType core.ErrorType, wantCode string) *core.GatewayError {
t.Helper()
var gatewayErr *core.GatewayError
if !errors.As(err, &gatewayErr) {
t.Fatalf("error type = %T, want *core.GatewayError", err)
}
if gatewayErr.Type != wantType {
t.Fatalf("error type = %q, want %q", gatewayErr.Type, wantType)
}
if wantCode != "" {
if gatewayErr.Code == nil {
t.Fatalf("error code = nil, want %q", wantCode)
}
if *gatewayErr.Code != wantCode {
t.Fatalf("error code = %q, want %q", *gatewayErr.Code, wantCode)
}
}
return gatewayErr
}

// Provider-wrapper-style call: nil validation, redirect rewritten and the
// per-item provider cleared before upstream submission.
func TestRewriteBatchItem_RewritesAndClearsProvider(t *testing.T) {
t.Parallel()
// No explicit provider on the item, so the "fast" redirect applies; the
// resolved target (openai/gpt-4o) is written as the model with the provider
// cleared for upstream.
body, err := rewriteBatchItem(context.Background(), newRedirectService(t), testCatalog(), "", decodedChatItem("fast", ""), nil)
if err != nil {
t.Fatalf("rewriteBatchItem() error = %v", err)
}
var out core.ChatRequest
if err := json.Unmarshal(body, &out); err != nil {
t.Fatalf("unmarshal error = %v", err)
}
if out.Model != "gpt-4o" {
t.Fatalf("rewritten model = %q, want gpt-4o (redirect resolved)", out.Model)
}
if out.Provider != "" {
t.Fatalf("rewritten provider = %q, want empty (cleared for upstream)", out.Provider)
}
}

// Server-side preparer call: the validate hook denies an unauthorized resolved
// selector and the error is surfaced.
func TestRewriteBatchItem_ValidateRejectsUnauthorized(t *testing.T) {
t.Parallel()
denied := errors.New("denied")
var validated core.ModelSelector
_, err := rewriteBatchItem(context.Background(), newRedirectService(t), testCatalog(), "", decodedChatItem("fast", ""),
func(_ context.Context, resolved core.ModelSelector) error {
validated = resolved
return denied
})
if !errors.Is(err, denied) {
t.Fatalf("rewriteBatchItem() error = %v, want denied", err)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
if validated.Provider != "openai" || validated.Model != "gpt-4o" {
t.Fatalf("validated selector = %q/%q, want openai/gpt-4o", validated.Provider, validated.Model)
}
}

// A malformed / unsupported batch item is rejected rather than silently passed.
func TestRewriteBatchItem_UnsupportedItem(t *testing.T) {
t.Parallel()
decoded := &core.DecodedBatchItemRequest{Endpoint: "/v1/unknown", Request: "not a request"}
_, err := rewriteBatchItem(context.Background(), newRedirectService(t), testCatalog(), "", decoded, nil)
if err == nil {
t.Fatal("rewriteBatchItem(unsupported item) error = nil, want error")
}
_ = requireGatewayError(t, err, core.ErrorTypeInvalidRequest, "")
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// Native batch is single-provider: a resolved target whose provider differs from
// the batch provider is rejected.
func TestRewriteBatchItem_RejectsCrossProviderBatch(t *testing.T) {
t.Parallel()
_, err := rewriteBatchItem(context.Background(), newRedirectService(t), testCatalog(), "anthropic", decodedChatItem("fast", ""), nil)
if err == nil {
t.Fatal("rewriteBatchItem(cross-provider batch) error = nil, want single-provider-per-batch error")
}
gatewayErr := requireGatewayError(t, err, core.ErrorTypeInvalidRequest, "")
if !strings.Contains(gatewayErr.Message, "single provider per batch") {
t.Fatalf("rewriteBatchItem(cross-provider batch) error = %q, want single-provider reason", gatewayErr.Message)
}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// BatchPreparer.validateAccess enforces the access policy; a nil-service preparer
// (provider-wrapper parity) never blocks.
func TestBatchPreparerValidateAccess(t *testing.T) {
t.Parallel()
ctx := context.Background()
selector := core.ModelSelector{Provider: "openai", Model: "gpt-4o"}

enabledSvc := newTestService(t)
if err := enabledSvc.Upsert(ctx, VirtualModel{Source: "openai/gpt-4o", Enabled: true}); err != nil {
t.Fatalf("Upsert(enabled policy) error = %v", err)
}
if err := NewBatchPreparer(nil, enabledSvc).validateAccess(ctx, selector); err != nil {
t.Fatalf("validateAccess(enabled model) error = %v, want nil", err)
}

svc := newTestService(t)
if err := svc.Upsert(ctx, VirtualModel{Source: "openai/gpt-4o", Enabled: false}); err != nil {
t.Fatalf("Upsert(disabled policy) error = %v", err)
}

err := NewBatchPreparer(nil, svc).validateAccess(ctx, selector)
if err == nil {
t.Fatal("validateAccess(disabled model) error = nil, want denied")
}
_ = requireGatewayError(t, err, core.ErrorTypeInvalidRequest, "model_access_denied")

if err := (&BatchPreparer{}).validateAccess(ctx, selector); err != nil {
t.Fatalf("validateAccess(nil service) error = %v, want nil", err)
}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Loading