diff --git a/crates/openshell-server/migrations/postgres/001_create_objects.sql b/crates/openshell-server/migrations/postgres/001_create_objects.sql index 0d2cb17da..3978f8456 100644 --- a/crates/openshell-server/migrations/postgres/001_create_objects.sql +++ b/crates/openshell-server/migrations/postgres/001_create_objects.sql @@ -1,10 +1,29 @@ CREATE TABLE IF NOT EXISTS objects ( - object_type TEXT NOT NULL, - id TEXT NOT NULL, - name TEXT NOT NULL, - payload BYTEA NOT NULL, + id TEXT PRIMARY KEY, + object_type TEXT NOT NULL, + name TEXT, + scope TEXT, + version BIGINT, + status TEXT, + dedup_key TEXT, + hit_count BIGINT NOT NULL DEFAULT 0, + payload BYTEA NOT NULL, created_at_ms BIGINT NOT NULL, - updated_at_ms BIGINT NOT NULL, - PRIMARY KEY (id), - UNIQUE (object_type, name) + updated_at_ms BIGINT NOT NULL ); + +CREATE UNIQUE INDEX IF NOT EXISTS objects_name_uq + ON objects (object_type, name) + WHERE name IS NOT NULL; + +CREATE UNIQUE INDEX IF NOT EXISTS objects_version_uq + ON objects (object_type, scope, version) + WHERE scope IS NOT NULL AND version IS NOT NULL; + +CREATE INDEX IF NOT EXISTS objects_scope_status_idx + ON objects (object_type, scope, status, version) + WHERE scope IS NOT NULL; + +CREATE UNIQUE INDEX IF NOT EXISTS objects_dedup_uq + ON objects (object_type, scope, dedup_key) + WHERE dedup_key IS NOT NULL; diff --git a/crates/openshell-server/migrations/postgres/002_create_sandbox_policies.sql b/crates/openshell-server/migrations/postgres/002_create_sandbox_policies.sql deleted file mode 100644 index 1fdd49b05..000000000 --- a/crates/openshell-server/migrations/postgres/002_create_sandbox_policies.sql +++ /dev/null @@ -1,15 +0,0 @@ -CREATE TABLE IF NOT EXISTS sandbox_policies ( - id TEXT PRIMARY KEY, - sandbox_id TEXT NOT NULL, - version INTEGER NOT NULL, - policy_payload BYTEA NOT NULL, - policy_hash TEXT NOT NULL, - status TEXT NOT NULL DEFAULT 'pending', - load_error TEXT, - created_at_ms BIGINT NOT NULL, - loaded_at_ms BIGINT, - UNIQUE (sandbox_id, version) -); - -CREATE INDEX IF NOT EXISTS idx_sandbox_policies_lookup - ON sandbox_policies (sandbox_id, version DESC); diff --git a/crates/openshell-server/migrations/postgres/003_create_policy_recommendations.sql b/crates/openshell-server/migrations/postgres/003_create_policy_recommendations.sql deleted file mode 100644 index 0c6fa09ee..000000000 --- a/crates/openshell-server/migrations/postgres/003_create_policy_recommendations.sql +++ /dev/null @@ -1,33 +0,0 @@ --- Draft policy chunks: proposed network policy rules awaiting user approval. --- --- One row per (sandbox_id, host, port, binary). The toggle model allows: --- pending -> approved | rejected (initial decision) --- approved <-> rejected (toggle via approve/revoke) --- --- Upserts bump hit_count / last_seen_ms when the same denial recurs. -CREATE TABLE IF NOT EXISTS draft_policy_chunks ( - id TEXT PRIMARY KEY, - sandbox_id TEXT NOT NULL, - draft_version BIGINT NOT NULL, - status TEXT NOT NULL DEFAULT 'pending', - rule_name TEXT NOT NULL, - proposed_rule BYTEA NOT NULL, - rationale TEXT NOT NULL DEFAULT '', - security_notes TEXT NOT NULL DEFAULT '', - confidence DOUBLE PRECISION NOT NULL DEFAULT 0.0, - host TEXT NOT NULL DEFAULT '', - port INTEGER NOT NULL DEFAULT 0, - binary TEXT NOT NULL DEFAULT '', - hit_count INTEGER NOT NULL DEFAULT 1, - first_seen_ms BIGINT NOT NULL, - last_seen_ms BIGINT NOT NULL, - created_at_ms BIGINT NOT NULL, - decided_at_ms BIGINT -); - -CREATE INDEX IF NOT EXISTS idx_draft_chunks_sandbox - ON draft_policy_chunks (sandbox_id, status); - -CREATE UNIQUE INDEX IF NOT EXISTS idx_draft_chunks_endpoint - ON draft_policy_chunks (sandbox_id, host, port, binary) - WHERE status IN ('pending', 'approved', 'rejected'); diff --git a/crates/openshell-server/migrations/sqlite/001_create_objects.sql b/crates/openshell-server/migrations/sqlite/001_create_objects.sql index 26a77e9da..749a6460e 100644 --- a/crates/openshell-server/migrations/sqlite/001_create_objects.sql +++ b/crates/openshell-server/migrations/sqlite/001_create_objects.sql @@ -1,10 +1,29 @@ CREATE TABLE IF NOT EXISTS objects ( - object_type TEXT NOT NULL, - id TEXT NOT NULL, - name TEXT NOT NULL, - payload BLOB NOT NULL, + id TEXT PRIMARY KEY, + object_type TEXT NOT NULL, + name TEXT, + scope TEXT, + version INTEGER, + status TEXT, + dedup_key TEXT, + hit_count INTEGER NOT NULL DEFAULT 0, + payload BLOB NOT NULL, created_at_ms INTEGER NOT NULL, - updated_at_ms INTEGER NOT NULL, - PRIMARY KEY (id), - UNIQUE (object_type, name) + updated_at_ms INTEGER NOT NULL ); + +CREATE UNIQUE INDEX IF NOT EXISTS objects_name_uq + ON objects (object_type, name) + WHERE name IS NOT NULL; + +CREATE UNIQUE INDEX IF NOT EXISTS objects_version_uq + ON objects (object_type, scope, version) + WHERE scope IS NOT NULL AND version IS NOT NULL; + +CREATE INDEX IF NOT EXISTS objects_scope_status_idx + ON objects (object_type, scope, status, version) + WHERE scope IS NOT NULL; + +CREATE UNIQUE INDEX IF NOT EXISTS objects_dedup_uq + ON objects (object_type, scope, dedup_key) + WHERE dedup_key IS NOT NULL; diff --git a/crates/openshell-server/migrations/sqlite/002_create_sandbox_policies.sql b/crates/openshell-server/migrations/sqlite/002_create_sandbox_policies.sql deleted file mode 100644 index 395b1897a..000000000 --- a/crates/openshell-server/migrations/sqlite/002_create_sandbox_policies.sql +++ /dev/null @@ -1,15 +0,0 @@ -CREATE TABLE IF NOT EXISTS sandbox_policies ( - id TEXT PRIMARY KEY, - sandbox_id TEXT NOT NULL, - version INTEGER NOT NULL, - policy_payload BLOB NOT NULL, - policy_hash TEXT NOT NULL, - status TEXT NOT NULL DEFAULT 'pending', - load_error TEXT, - created_at_ms INTEGER NOT NULL, - loaded_at_ms INTEGER, - UNIQUE (sandbox_id, version) -); - -CREATE INDEX IF NOT EXISTS idx_sandbox_policies_lookup - ON sandbox_policies (sandbox_id, version DESC); diff --git a/crates/openshell-server/migrations/sqlite/003_create_policy_recommendations.sql b/crates/openshell-server/migrations/sqlite/003_create_policy_recommendations.sql deleted file mode 100644 index 069ce2048..000000000 --- a/crates/openshell-server/migrations/sqlite/003_create_policy_recommendations.sql +++ /dev/null @@ -1,35 +0,0 @@ --- Draft policy chunks: proposed network policy rules awaiting user approval. --- --- One row per (sandbox_id, host, port, binary). The toggle model allows: --- pending -> approved | rejected (initial decision) --- approved <-> rejected (toggle via approve/revoke) --- --- Upserts bump hit_count / last_seen_ms when the same denial recurs. -CREATE TABLE IF NOT EXISTS draft_policy_chunks ( - id TEXT PRIMARY KEY, - sandbox_id TEXT NOT NULL, - draft_version INTEGER NOT NULL, - status TEXT NOT NULL DEFAULT 'pending', - rule_name TEXT NOT NULL, - proposed_rule BLOB NOT NULL, - rationale TEXT NOT NULL DEFAULT '', - security_notes TEXT NOT NULL DEFAULT '', - confidence REAL NOT NULL DEFAULT 0.0, - host TEXT NOT NULL DEFAULT '', - port INTEGER NOT NULL DEFAULT 0, - binary TEXT NOT NULL DEFAULT '', - hit_count INTEGER NOT NULL DEFAULT 1, - first_seen_ms INTEGER NOT NULL, - last_seen_ms INTEGER NOT NULL, - created_at_ms INTEGER NOT NULL, - decided_at_ms INTEGER -); - -CREATE INDEX IF NOT EXISTS idx_draft_chunks_sandbox - ON draft_policy_chunks (sandbox_id, status); - --- Only one active chunk per (sandbox, endpoint, binary). Covers all three --- statuses so rejected chunks block duplicate proposals until re-approved. -CREATE UNIQUE INDEX IF NOT EXISTS idx_draft_chunks_endpoint - ON draft_policy_chunks (sandbox_id, host, port, binary) - WHERE status IN ('pending', 'approved', 'rejected'); diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 19cfd5faf..9770ab0d5 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -7,7 +7,7 @@ pub mod vm; pub use vm::VmComputeConfig; -use crate::grpc::policy::{SANDBOX_SETTINGS_OBJECT_TYPE, sandbox_settings_id}; +use crate::grpc::policy::SANDBOX_SETTINGS_OBJECT_TYPE; use crate::persistence::{ObjectId, ObjectName, ObjectRecord, ObjectType, Store}; use crate::sandbox_index::SandboxIndex; use crate::sandbox_watch::SandboxWatchBus; @@ -406,7 +406,7 @@ impl ComputeRuntime { if let Err(e) = self .store - .delete(SANDBOX_SETTINGS_OBJECT_TYPE, &sandbox_settings_id(&id)) + .delete_by_name(SANDBOX_SETTINGS_OBJECT_TYPE, &sandbox.name) .await { warn!( @@ -1183,6 +1183,115 @@ fn is_terminal_failure_reason(reason: &str) -> bool { !transient_reasons.contains(&reason.as_str()) } +#[cfg(test)] +#[derive(Debug, Default)] +pub(crate) struct NoopTestDriver; + +#[cfg(test)] +#[tonic::async_trait] +impl ComputeDriver for NoopTestDriver { + type WatchSandboxesStream = DriverWatchStream; + + async fn get_capabilities( + &self, + _request: Request, + ) -> Result, Status> + { + Ok(tonic::Response::new( + openshell_core::proto::compute::v1::GetCapabilitiesResponse { + driver_name: "noop-test-driver".to_string(), + driver_version: "test".to_string(), + default_image: "openshell/sandbox:test".to_string(), + supports_gpu: false, + }, + )) + } + + async fn validate_sandbox_create( + &self, + _request: Request, + ) -> Result< + tonic::Response, + Status, + > { + Ok(tonic::Response::new( + openshell_core::proto::compute::v1::ValidateSandboxCreateResponse {}, + )) + } + + async fn get_sandbox( + &self, + _request: Request, + ) -> Result, Status> + { + Err(Status::not_found("sandbox not found")) + } + + async fn list_sandboxes( + &self, + _request: Request, + ) -> Result, Status> + { + Ok(tonic::Response::new( + openshell_core::proto::compute::v1::ListSandboxesResponse { + sandboxes: Vec::new(), + }, + )) + } + + async fn create_sandbox( + &self, + _request: Request, + ) -> Result, Status> + { + Ok(tonic::Response::new( + openshell_core::proto::compute::v1::CreateSandboxResponse {}, + )) + } + + async fn stop_sandbox( + &self, + _request: Request, + ) -> Result, Status> + { + Ok(tonic::Response::new( + openshell_core::proto::compute::v1::StopSandboxResponse {}, + )) + } + + async fn delete_sandbox( + &self, + _request: Request, + ) -> Result, Status> + { + Ok(tonic::Response::new( + openshell_core::proto::compute::v1::DeleteSandboxResponse { deleted: true }, + )) + } + + async fn watch_sandboxes( + &self, + _request: Request, + ) -> Result, Status> { + Ok(tonic::Response::new(Box::pin(futures::stream::empty()))) + } +} + +#[cfg(test)] +pub(crate) async fn new_test_runtime(store: Arc) -> ComputeRuntime { + ComputeRuntime { + driver: Arc::new(NoopTestDriver), + _driver_process: None, + default_image: "openshell/sandbox:test".to_string(), + store, + sandbox_index: SandboxIndex::new(), + sandbox_watch_bus: SandboxWatchBus::new(), + tracing_log_bus: TracingLogBus::new(), + supervisor_sessions: Arc::new(SupervisorSessionRegistry::new()), + sync_lock: Arc::new(Mutex::new(())), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index 8ef8cb5c7..ea4daa700 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -62,8 +62,6 @@ use super::{MAX_PAGE_SIZE, StoredSettingValue, StoredSettings, clamp_limit, curr /// Internal object type for durable gateway-global settings. const GLOBAL_SETTINGS_OBJECT_TYPE: &str = "gateway_settings"; -/// Internal object id for the singleton global settings record. -const GLOBAL_SETTINGS_ID: &str = "gateway_settings:global"; const GLOBAL_SETTINGS_NAME: &str = "global"; /// Internal object type for durable sandbox-scoped settings. pub(crate) const SANDBOX_SETTINGS_OBJECT_TYPE: &str = "sandbox_settings"; @@ -386,7 +384,7 @@ pub(super) async fn handle_get_sandbox_config( }; let global_settings = load_global_settings(state.store.as_ref()).await?; - let sandbox_settings = load_sandbox_settings(state.store.as_ref(), &sandbox_id).await?; + let sandbox_settings = load_sandbox_settings(state.store.as_ref(), &sandbox.name).await?; let mut global_policy_version: u32 = 0; @@ -685,17 +683,12 @@ pub(super) async fn handle_update_config( } let mut sandbox_settings = - load_sandbox_settings(state.store.as_ref(), &sandbox_id).await?; + load_sandbox_settings(state.store.as_ref(), &sandbox.name).await?; let removed = sandbox_settings.settings.remove(key).is_some(); if removed { sandbox_settings.revision = sandbox_settings.revision.wrapping_add(1); - save_sandbox_settings( - state.store.as_ref(), - &sandbox_id, - &sandbox.name, - &sandbox_settings, - ) - .await?; + save_sandbox_settings(state.store.as_ref(), &sandbox.name, &sandbox_settings) + .await?; } return Ok(Response::new(UpdateConfigResponse { @@ -718,17 +711,12 @@ pub(super) async fn handle_update_config( .ok_or_else(|| Status::invalid_argument("setting_value is required"))?; let stored = proto_setting_to_stored(key, setting)?; - let mut sandbox_settings = load_sandbox_settings(state.store.as_ref(), &sandbox_id).await?; + let mut sandbox_settings = + load_sandbox_settings(state.store.as_ref(), &sandbox.name).await?; let changed = upsert_setting_value(&mut sandbox_settings.settings, key, stored); if changed { sandbox_settings.revision = sandbox_settings.revision.wrapping_add(1); - save_sandbox_settings( - state.store.as_ref(), - &sandbox_id, - &sandbox.name, - &sandbox_settings, - ) - .await?; + save_sandbox_settings(state.store.as_ref(), &sandbox.name, &sandbox_settings).await?; } return Ok(Response::new(UpdateConfigResponse { @@ -2258,8 +2246,7 @@ async fn apply_merge_operations_with_retry( return Ok((next_version, hash)); } Err(e) => { - let msg = e.to_string(); - if msg.contains("UNIQUE") || msg.contains("unique") || msg.contains("duplicate") { + if e.is_unique_violation_on("objects_version_uq") { warn!( sandbox_id = %sandbox_id, attempt, @@ -2400,7 +2387,7 @@ fn upsert_setting_value( } pub(super) async fn load_global_settings(store: &Store) -> Result { - load_settings_record(store, GLOBAL_SETTINGS_OBJECT_TYPE, GLOBAL_SETTINGS_ID).await + load_settings_record(store, GLOBAL_SETTINGS_OBJECT_TYPE, GLOBAL_SETTINGS_NAME).await } pub(super) async fn save_global_settings( @@ -2410,53 +2397,34 @@ pub(super) async fn save_global_settings( save_settings_record( store, GLOBAL_SETTINGS_OBJECT_TYPE, - GLOBAL_SETTINGS_ID, GLOBAL_SETTINGS_NAME, settings, ) .await } -/// Derive a distinct settings record ID from a sandbox UUID. -pub(crate) fn sandbox_settings_id(sandbox_id: &str) -> String { - format!("settings:{sandbox_id}") -} - pub(super) async fn load_sandbox_settings( store: &Store, - sandbox_id: &str, + sandbox_name: &str, ) -> Result { - load_settings_record( - store, - SANDBOX_SETTINGS_OBJECT_TYPE, - &sandbox_settings_id(sandbox_id), - ) - .await + load_settings_record(store, SANDBOX_SETTINGS_OBJECT_TYPE, sandbox_name).await } pub(super) async fn save_sandbox_settings( store: &Store, - sandbox_id: &str, sandbox_name: &str, settings: &StoredSettings, ) -> Result<(), Status> { - save_settings_record( - store, - SANDBOX_SETTINGS_OBJECT_TYPE, - &sandbox_settings_id(sandbox_id), - sandbox_name, - settings, - ) - .await + save_settings_record(store, SANDBOX_SETTINGS_OBJECT_TYPE, sandbox_name, settings).await } async fn load_settings_record( store: &Store, object_type: &str, - id: &str, + name: &str, ) -> Result { let record = store - .get(object_type, id) + .get_by_name(object_type, name) .await .map_err(|e| Status::internal(format!("fetch settings failed: {e}")))?; if let Some(record) = record { @@ -2470,14 +2438,18 @@ async fn load_settings_record( async fn save_settings_record( store: &Store, object_type: &str, - id: &str, name: &str, settings: &StoredSettings, ) -> Result<(), Status> { let payload = serde_json::to_vec(settings) .map_err(|e| Status::internal(format!("encode settings payload failed: {e}")))?; store - .put(object_type, id, name, &payload) + .put( + object_type, + &uuid::Uuid::new_v4().to_string(), + name, + &payload, + ) .await .map_err(|e| Status::internal(format!("persist settings failed: {e}")))?; Ok(()) @@ -2576,7 +2548,14 @@ fn materialize_global_settings( #[cfg(test)] mod tests { use super::*; + use crate::ServerState; + use crate::compute::new_test_runtime; use crate::persistence::Store; + use crate::sandbox_index::SandboxIndex; + use crate::sandbox_watch::SandboxWatchBus; + use crate::supervisor_session::SupervisorSessionRegistry; + use crate::tracing_bus::TracingLogBus; + use openshell_core::Config; use std::collections::HashMap; use std::sync::Arc; use tonic::Code; @@ -2667,6 +2646,229 @@ mod tests { assert_eq!(policy.process.unwrap().run_as_user, "sandbox"); } + async fn test_server_state() -> Arc { + let store = Arc::new( + Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(), + ); + let compute = new_test_runtime(store.clone()).await; + Arc::new(ServerState::new( + Config::new(None) + .with_database_url("sqlite::memory:?cache=shared") + .with_ssh_handshake_secret("test-secret"), + store, + compute, + SandboxIndex::new(), + SandboxWatchBus::new(), + TracingLogBus::new(), + Arc::new(SupervisorSessionRegistry::new()), + )) + } + + #[tokio::test] + async fn draft_chunk_handler_lifecycle_round_trip() { + use openshell_core::proto::{ + GetDraftPolicyRequest, NetworkBinary, NetworkEndpoint, SandboxPhase, SandboxSpec, + }; + + let state = test_server_state().await; + let sandbox = Sandbox { + id: "sb-draft-flow".to_string(), + name: "draft-flow".to_string(), + namespace: "default".to_string(), + spec: Some(SandboxSpec { + policy: None, + ..Default::default() + }), + phase: SandboxPhase::Ready as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); + + let proposed_rule = NetworkPolicyRule { + name: "allow_example".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.example.com".to_string(), + port: 443, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }; + + let submit = handle_submit_policy_analysis( + &state, + Request::new(SubmitPolicyAnalysisRequest { + name: sandbox.name.clone(), + proposed_chunks: vec![PolicyChunk { + rule_name: "allow_example".to_string(), + proposed_rule: Some(proposed_rule.clone()), + rationale: "observed denied request".to_string(), + confidence: 0.85, + hit_count: 3, + first_seen_ms: 100, + last_seen_ms: 200, + binary: "/usr/bin/curl".to_string(), + ..Default::default() + }], + ..Default::default() + }), + ) + .await + .unwrap() + .into_inner(); + assert_eq!(submit.accepted_chunks, 1); + assert_eq!(submit.rejected_chunks, 0); + + let draft_policy = handle_get_draft_policy( + &state, + Request::new(GetDraftPolicyRequest { + name: sandbox.name.clone(), + status_filter: String::new(), + }), + ) + .await + .unwrap() + .into_inner(); + assert_eq!(draft_policy.draft_version, 1); + assert_eq!(draft_policy.chunks.len(), 1); + assert_eq!(draft_policy.chunks[0].status, "pending"); + let chunk_id = draft_policy.chunks[0].id.clone(); + + let approve = handle_approve_draft_chunk( + &state, + Request::new(ApproveDraftChunkRequest { + name: sandbox.name.clone(), + chunk_id: chunk_id.clone(), + }), + ) + .await + .unwrap() + .into_inner(); + assert_eq!(approve.policy_version, 1); + assert!(!approve.policy_hash.is_empty()); + + let history_after_approve = handle_get_draft_history( + &state, + Request::new(GetDraftHistoryRequest { + name: sandbox.name.clone(), + }), + ) + .await + .unwrap() + .into_inner(); + assert_eq!(history_after_approve.entries.len(), 2); + assert_eq!(history_after_approve.entries[0].event_type, "proposed"); + assert_eq!(history_after_approve.entries[1].event_type, "approved"); + assert_eq!(history_after_approve.entries[1].chunk_id, chunk_id); + + let policies_after_approve = handle_list_sandbox_policies( + &state, + Request::new(ListSandboxPoliciesRequest { + name: sandbox.name.clone(), + limit: 10, + offset: 0, + global: false, + }), + ) + .await + .unwrap() + .into_inner(); + assert_eq!(policies_after_approve.revisions.len(), 1); + assert_eq!(policies_after_approve.revisions[0].version, 1); + + let undo = handle_undo_draft_chunk( + &state, + Request::new(UndoDraftChunkRequest { + name: sandbox.name.clone(), + chunk_id: chunk_id.clone(), + }), + ) + .await + .unwrap() + .into_inner(); + assert_eq!(undo.policy_version, 2); + assert!(!undo.policy_hash.is_empty()); + + let draft_policy_after_undo = handle_get_draft_policy( + &state, + Request::new(GetDraftPolicyRequest { + name: sandbox.name.clone(), + status_filter: String::new(), + }), + ) + .await + .unwrap() + .into_inner(); + assert_eq!(draft_policy_after_undo.chunks.len(), 1); + assert_eq!(draft_policy_after_undo.chunks[0].status, "pending"); + + let history_after_undo = handle_get_draft_history( + &state, + Request::new(GetDraftHistoryRequest { + name: sandbox.name.clone(), + }), + ) + .await + .unwrap() + .into_inner(); + assert_eq!(history_after_undo.entries.len(), 1); + assert_eq!(history_after_undo.entries[0].event_type, "proposed"); + + let policies_after_undo = handle_list_sandbox_policies( + &state, + Request::new(ListSandboxPoliciesRequest { + name: sandbox.name.clone(), + limit: 10, + offset: 0, + global: false, + }), + ) + .await + .unwrap() + .into_inner(); + assert_eq!(policies_after_undo.revisions.len(), 2); + assert_eq!(policies_after_undo.revisions[0].version, 2); + assert_eq!(policies_after_undo.revisions[1].version, 1); + + let cleared = handle_clear_draft_chunks( + &state, + Request::new(ClearDraftChunksRequest { + name: sandbox.name.clone(), + }), + ) + .await + .unwrap() + .into_inner(); + assert_eq!(cleared.chunks_cleared, 1); + + let draft_policy_after_clear = handle_get_draft_policy( + &state, + Request::new(GetDraftPolicyRequest { + name: sandbox.name.clone(), + status_filter: String::new(), + }), + ) + .await + .unwrap() + .into_inner(); + assert!(draft_policy_after_clear.chunks.is_empty()); + + let history_after_clear = handle_get_draft_history( + &state, + Request::new(GetDraftHistoryRequest { + name: sandbox.name.clone(), + }), + ) + .await + .unwrap() + .into_inner(); + assert!(history_after_clear.entries.is_empty()); + } + #[test] fn build_gateway_policy_audit_message_formats_ocsf_config_line() { let message = build_gateway_policy_audit_message( @@ -3392,25 +3594,9 @@ mod tests { } #[test] - fn sandbox_settings_id_has_prefix_preventing_collision() { - let sandbox_id = "abc-123"; - let settings_id = sandbox_settings_id(sandbox_id); - assert!(settings_id.starts_with("settings:")); - assert_ne!(settings_id, sandbox_id); - } - - #[test] - fn sandbox_settings_id_different_sandboxes_produce_different_ids() { - let id_a = sandbox_settings_id("sandbox-1"); - let id_b = sandbox_settings_id("sandbox-2"); - assert_ne!(id_a, id_b); - } - - #[test] - fn sandbox_settings_id_embeds_sandbox_id() { - let sandbox_id = "some-uuid-value"; - let settings_id = sandbox_settings_id(sandbox_id); - assert!(settings_id.contains(sandbox_id)); + fn sandbox_settings_names_match_sandbox_names() { + let sandbox_name = "my-sandbox"; + assert_eq!(sandbox_name, "my-sandbox"); } // ---- compute_config_revision ---- @@ -3620,17 +3806,17 @@ mod tests { .await .unwrap(); - let sandbox_id = "sb-uuid-123"; + let sandbox_name = "my-sandbox"; let mut settings = StoredSettings::default(); settings .settings .insert("dummy_int".to_string(), StoredSettingValue::Int(99)); settings.revision = 3; - save_sandbox_settings(&store, sandbox_id, "my-sandbox", &settings) + save_sandbox_settings(&store, sandbox_name, &settings) .await .unwrap(); - let loaded = load_sandbox_settings(&store, sandbox_id).await.unwrap(); + let loaded = load_sandbox_settings(&store, sandbox_name).await.unwrap(); assert_eq!(loaded.revision, 3); assert_eq!( loaded.settings.get("dummy_int"), @@ -3775,8 +3961,8 @@ mod tests { let loaded = load_global_settings(&store).await.unwrap(); assert!(!loaded.settings.contains_key("log_level")); - let sandbox_id = "test-sandbox-uuid"; - let mut sandbox_settings = load_sandbox_settings(&store, sandbox_id).await.unwrap(); + let sandbox_name = "test-sandbox"; + let mut sandbox_settings = load_sandbox_settings(&store, sandbox_name).await.unwrap(); let changed = upsert_setting_value( &mut sandbox_settings.settings, "log_level", @@ -3784,11 +3970,11 @@ mod tests { ); assert!(changed); sandbox_settings.revision = sandbox_settings.revision.wrapping_add(1); - save_sandbox_settings(&store, sandbox_id, "test-sandbox", &sandbox_settings) + save_sandbox_settings(&store, sandbox_name, &sandbox_settings) .await .unwrap(); - let reloaded = load_sandbox_settings(&store, sandbox_id).await.unwrap(); + let reloaded = load_sandbox_settings(&store, sandbox_name).await.unwrap(); assert_eq!( reloaded.settings.get("log_level"), Some(&StoredSettingValue::String("debug".to_string())), diff --git a/crates/openshell-server/src/persistence/mod.rs b/crates/openshell-server/src/persistence/mod.rs index 5cd36693b..f6bcda8b4 100644 --- a/crates/openshell-server/src/persistence/mod.rs +++ b/crates/openshell-server/src/persistence/mod.rs @@ -6,14 +6,68 @@ mod postgres; mod sqlite; -use openshell_core::{Error, Result}; +use openshell_core::{ + Error as CoreError, Result as CoreResult, + proto::{ + DraftChunkPayload, NetworkPolicyRule, PolicyRevisionPayload, + SandboxPolicy as ProtoSandboxPolicy, + }, +}; use prost::Message; use rand::Rng; use std::time::{SystemTime, UNIX_EPOCH}; +use thiserror::Error; pub use postgres::PostgresStore; pub use sqlite::SqliteStore; +pub type PersistenceResult = Result; + +/// Persistence-layer error type. +#[derive(Debug, Error, Clone)] +pub enum PersistenceError { + #[error("configuration error: {0}")] + Config(String), + #[error("database error: {0}")] + Database(String), + #[error("migration error: {0}")] + Migration(String), + #[error("decode error: {0}")] + Decode(String), + #[error("encode error: {0}")] + Encode(String), + #[error("unique violation{constraint_msg}")] + UniqueViolation { + constraint: Option, + detail: Option, + constraint_msg: String, + }, +} + +impl PersistenceError { + pub fn unique_violation(constraint: Option, detail: Option) -> Self { + let constraint_msg = constraint + .as_ref() + .map(|value| format!(" on {value}")) + .unwrap_or_default(); + Self::UniqueViolation { + constraint, + detail, + constraint_msg, + } + } + + pub fn is_unique_violation_on(&self, constraint: &str) -> bool { + matches!( + self, + Self::UniqueViolation { + constraint: Some(value), + .. + } if value == constraint + ) + } +} + /// Stored object record. #[derive(Debug, Clone)] pub struct ObjectRecord { @@ -39,6 +93,34 @@ pub struct PolicyRecord { pub loaded_at_ms: Option, } +/// Stored draft policy chunk record. +#[derive(Debug, Clone)] +pub struct DraftChunkRecord { + pub id: String, + pub sandbox_id: String, + pub draft_version: i64, + pub status: String, + pub rule_name: String, + pub proposed_rule: Vec, + pub rationale: String, + pub security_notes: String, + pub confidence: f64, + pub created_at_ms: i64, + pub decided_at_ms: Option, + /// Denormalized endpoint host (lowercase) for DB-level dedup. + pub host: String, + /// Denormalized endpoint port for DB-level dedup. + pub port: i32, + /// Binary path that triggered the denial (for per-binary dedup). + pub binary: String, + /// How many times this endpoint has been seen across denial flush cycles. + pub hit_count: i32, + /// First time this endpoint was proposed (ms since epoch). + pub first_seen_ms: i64, + /// Most recent time this endpoint was re-proposed (ms since epoch). + pub last_seen_ms: i64, +} + /// Persistence store implementations. #[derive(Debug, Clone)] pub enum Store { @@ -71,24 +153,40 @@ pub fn generate_name() -> String { impl Store { /// Connect to a persistence store based on the database URL. - pub async fn connect(url: &str) -> Result { + pub async fn connect(url: &str) -> CoreResult { if url.starts_with("postgres://") || url.starts_with("postgresql://") { - let store = PostgresStore::connect(url).await?; - store.migrate().await?; + let store = PostgresStore::connect(url) + .await + .map_err(|e| CoreError::execution(e.to_string()))?; + store + .migrate() + .await + .map_err(|e| CoreError::execution(e.to_string()))?; Ok(Self::Postgres(store)) } else if url.starts_with("sqlite:") { - let store = SqliteStore::connect(url).await?; - store.migrate().await?; + let store = SqliteStore::connect(url) + .await + .map_err(|e| CoreError::execution(e.to_string()))?; + store + .migrate() + .await + .map_err(|e| CoreError::execution(e.to_string()))?; Ok(Self::Sqlite(store)) } else { - Err(Error::config(format!( + Err(CoreError::config(format!( "unsupported database URL scheme: {url}" ))) } } - /// Insert or update an object. - pub async fn put(&self, object_type: &str, id: &str, name: &str, payload: &[u8]) -> Result<()> { + /// Insert or update a generic named object. + pub async fn put( + &self, + object_type: &str, + id: &str, + name: &str, + payload: &[u8], + ) -> PersistenceResult<()> { match self { Self::Postgres(store) => store.put(object_type, id, name, payload).await, Self::Sqlite(store) => store.put(object_type, id, name, payload).await, @@ -96,7 +194,11 @@ impl Store { } /// Fetch an object by id. - pub async fn get(&self, object_type: &str, id: &str) -> Result> { + pub async fn get( + &self, + object_type: &str, + id: &str, + ) -> PersistenceResult> { match self { Self::Postgres(store) => store.get(object_type, id).await, Self::Sqlite(store) => store.get(object_type, id).await, @@ -104,7 +206,11 @@ impl Store { } /// Fetch an object by name within an object type. - pub async fn get_by_name(&self, object_type: &str, name: &str) -> Result> { + pub async fn get_by_name( + &self, + object_type: &str, + name: &str, + ) -> PersistenceResult> { match self { Self::Postgres(store) => store.get_by_name(object_type, name).await, Self::Sqlite(store) => store.get_by_name(object_type, name).await, @@ -112,7 +218,7 @@ impl Store { } /// Delete an object by id. - pub async fn delete(&self, object_type: &str, id: &str) -> Result { + pub async fn delete(&self, object_type: &str, id: &str) -> PersistenceResult { match self { Self::Postgres(store) => store.delete(object_type, id).await, Self::Sqlite(store) => store.delete(object_type, id).await, @@ -120,7 +226,7 @@ impl Store { } /// Delete an object by name within an object type. - pub async fn delete_by_name(&self, object_type: &str, name: &str) -> Result { + pub async fn delete_by_name(&self, object_type: &str, name: &str) -> PersistenceResult { match self { Self::Postgres(store) => store.delete_by_name(object_type, name).await, Self::Sqlite(store) => store.delete_by_name(object_type, name).await, @@ -133,17 +239,13 @@ impl Store { object_type: &str, limit: u32, offset: u32, - ) -> Result> { + ) -> PersistenceResult> { match self { Self::Postgres(store) => store.list(object_type, limit, offset).await, Self::Sqlite(store) => store.list(object_type, limit, offset).await, } } - // ----------------------------------------------------------------------- - // Policy revision operations - // ----------------------------------------------------------------------- - /// Insert a new policy revision. pub async fn put_policy_revision( &self, @@ -152,7 +254,7 @@ impl Store { version: i64, payload: &[u8], hash: &str, - ) -> Result<()> { + ) -> PersistenceResult<()> { match self { Self::Postgres(store) => { store @@ -168,7 +270,10 @@ impl Store { } /// Get the latest policy revision for a sandbox (by highest version, any status). - pub async fn get_latest_policy(&self, sandbox_id: &str) -> Result> { + pub async fn get_latest_policy( + &self, + sandbox_id: &str, + ) -> PersistenceResult> { match self { Self::Postgres(store) => store.get_latest_policy(sandbox_id).await, Self::Sqlite(store) => store.get_latest_policy(sandbox_id).await, @@ -176,7 +281,10 @@ impl Store { } /// Get the latest loaded policy revision for a sandbox. - pub async fn get_latest_loaded_policy(&self, sandbox_id: &str) -> Result> { + pub async fn get_latest_loaded_policy( + &self, + sandbox_id: &str, + ) -> PersistenceResult> { match self { Self::Postgres(store) => store.get_latest_loaded_policy(sandbox_id).await, Self::Sqlite(store) => store.get_latest_loaded_policy(sandbox_id).await, @@ -188,7 +296,7 @@ impl Store { &self, sandbox_id: &str, version: i64, - ) -> Result> { + ) -> PersistenceResult> { match self { Self::Postgres(store) => store.get_policy_by_version(sandbox_id, version).await, Self::Sqlite(store) => store.get_policy_by_version(sandbox_id, version).await, @@ -201,7 +309,7 @@ impl Store { sandbox_id: &str, limit: u32, offset: u32, - ) -> Result> { + ) -> PersistenceResult> { match self { Self::Postgres(store) => store.list_policies(sandbox_id, limit, offset).await, Self::Sqlite(store) => store.list_policies(sandbox_id, limit, offset).await, @@ -216,7 +324,7 @@ impl Store { status: &str, load_error: Option<&str>, loaded_at_ms: Option, - ) -> Result { + ) -> PersistenceResult { match self { Self::Postgres(store) => { store @@ -236,7 +344,7 @@ impl Store { &self, sandbox_id: &str, before_version: i64, - ) -> Result { + ) -> PersistenceResult { match self { Self::Postgres(store) => { store @@ -251,60 +359,8 @@ impl Store { } } - // ----------------------------------------------------------------------- - // Generic protobuf message helpers - // ----------------------------------------------------------------------- - - /// Insert or update a protobuf message using its inferred object type, id, and name. - pub async fn put_message( - &self, - message: &T, - ) -> Result<()> { - self.put( - T::object_type(), - message.object_id(), - message.object_name(), - &message.encode_to_vec(), - ) - .await - } - - /// Fetch and decode a protobuf message by id. - pub async fn get_message( - &self, - id: &str, - ) -> Result> { - let record = self.get(T::object_type(), id).await?; - let Some(record) = record else { - return Ok(None); - }; - - T::decode(record.payload.as_slice()) - .map(Some) - .map_err(|e| Error::execution(format!("protobuf decode error: {e}"))) - } - - /// Fetch and decode a protobuf message by name. - pub async fn get_message_by_name( - &self, - name: &str, - ) -> Result> { - let record = self.get_by_name(T::object_type(), name).await?; - let Some(record) = record else { - return Ok(None); - }; - - T::decode(record.payload.as_slice()) - .map(Some) - .map_err(|e| Error::execution(format!("protobuf decode error: {e}"))) - } - - // ----------------------------------------------------------------------- - // Draft policy chunk operations - // ----------------------------------------------------------------------- - - /// Insert a new draft policy chunk. - pub async fn put_draft_chunk(&self, chunk: &DraftChunkRecord) -> Result<()> { + /// Insert or merge a new draft policy chunk. + pub async fn put_draft_chunk(&self, chunk: &DraftChunkRecord) -> PersistenceResult<()> { match self { Self::Postgres(store) => store.put_draft_chunk(chunk).await, Self::Sqlite(store) => store.put_draft_chunk(chunk).await, @@ -312,7 +368,7 @@ impl Store { } /// Fetch a single draft chunk by id. - pub async fn get_draft_chunk(&self, id: &str) -> Result> { + pub async fn get_draft_chunk(&self, id: &str) -> PersistenceResult> { match self { Self::Postgres(store) => store.get_draft_chunk(id).await, Self::Sqlite(store) => store.get_draft_chunk(id).await, @@ -324,7 +380,7 @@ impl Store { &self, sandbox_id: &str, status_filter: Option<&str>, - ) -> Result> { + ) -> PersistenceResult> { match self { Self::Postgres(store) => store.list_draft_chunks(sandbox_id, status_filter).await, Self::Sqlite(store) => store.list_draft_chunks(sandbox_id, status_filter).await, @@ -337,7 +393,7 @@ impl Store { id: &str, status: &str, decided_at_ms: Option, - ) -> Result { + ) -> PersistenceResult { match self { Self::Postgres(store) => { store @@ -353,7 +409,11 @@ impl Store { } /// Update the proposed rule on a pending draft chunk. - pub async fn update_draft_chunk_rule(&self, id: &str, proposed_rule: &[u8]) -> Result { + pub async fn update_draft_chunk_rule( + &self, + id: &str, + proposed_rule: &[u8], + ) -> PersistenceResult { match self { Self::Postgres(store) => store.update_draft_chunk_rule(id, proposed_rule).await, Self::Sqlite(store) => store.update_draft_chunk_rule(id, proposed_rule).await, @@ -361,7 +421,11 @@ impl Store { } /// Delete all draft chunks for a sandbox with a given status. - pub async fn delete_draft_chunks(&self, sandbox_id: &str, status: &str) -> Result { + pub async fn delete_draft_chunks( + &self, + sandbox_id: &str, + status: &str, + ) -> PersistenceResult { match self { Self::Postgres(store) => store.delete_draft_chunks(sandbox_id, status).await, Self::Sqlite(store) => store.delete_draft_chunks(sandbox_id, status).await, @@ -369,56 +433,199 @@ impl Store { } /// Get the current maximum draft version for a sandbox. - pub async fn get_draft_version(&self, sandbox_id: &str) -> Result { + pub async fn get_draft_version(&self, sandbox_id: &str) -> PersistenceResult { match self { Self::Postgres(store) => store.get_draft_version(sandbox_id).await, Self::Sqlite(store) => store.get_draft_version(sandbox_id).await, } } + + /// Insert or update a protobuf message using its inferred object type, id, and name. + pub async fn put_message( + &self, + message: &T, + ) -> PersistenceResult<()> { + self.put( + T::object_type(), + message.object_id(), + message.object_name(), + &message.encode_to_vec(), + ) + .await + } + + /// Fetch and decode a protobuf message by id. + pub async fn get_message( + &self, + id: &str, + ) -> PersistenceResult> { + let record = self.get(T::object_type(), id).await?; + let Some(record) = record else { + return Ok(None); + }; + + T::decode(record.payload.as_slice()) + .map(Some) + .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}"))) + } + + /// Fetch and decode a protobuf message by name. + pub async fn get_message_by_name( + &self, + name: &str, + ) -> PersistenceResult> { + let record = self.get_by_name(T::object_type(), name).await?; + let Some(record) = record else { + return Ok(None); + }; + + T::decode(record.payload.as_slice()) + .map(Some) + .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}"))) + } } -/// Stored draft policy chunk record. -#[derive(Debug, Clone)] -pub struct DraftChunkRecord { - pub id: String, - pub sandbox_id: String, - pub draft_version: i64, - pub status: String, - pub rule_name: String, - pub proposed_rule: Vec, - pub rationale: String, - pub security_notes: String, - pub confidence: f64, - pub created_at_ms: i64, - pub decided_at_ms: Option, - /// Denormalized endpoint host (lowercase) for DB-level dedup. - pub host: String, - /// Denormalized endpoint port for DB-level dedup. - pub port: i32, - /// Binary path that triggered the denial (for per-binary dedup). - pub binary: String, - /// How many times this endpoint has been seen across denial flush cycles. - pub hit_count: i32, - /// First time this endpoint was proposed (ms since epoch). - pub first_seen_ms: i64, - /// Most recent time this endpoint was re-proposed (ms since epoch). - pub last_seen_ms: i64, +pub(crate) fn policy_payload_from_record(record: &PolicyRecord) -> PersistenceResult> { + let policy = ProtoSandboxPolicy::decode(record.policy_payload.as_slice()) + .map_err(|e| PersistenceError::Decode(format!("decode policy payload failed: {e}")))?; + Ok(PolicyRevisionPayload { + policy: Some(policy), + hash: record.policy_hash.clone(), + load_error: record.load_error.clone().unwrap_or_default(), + loaded_at_ms: record.loaded_at_ms.unwrap_or(0), + } + .encode_to_vec()) +} + +pub(crate) fn policy_record_from_parts( + id: String, + sandbox_id: String, + version: i64, + status: String, + payload: &[u8], + created_at_ms: i64, +) -> PersistenceResult { + let wrapper = PolicyRevisionPayload::decode(payload) + .map_err(|e| PersistenceError::Decode(format!("decode policy wrapper failed: {e}")))?; + let policy = wrapper + .policy + .ok_or_else(|| PersistenceError::Decode("policy wrapper missing policy".to_string()))?; + Ok(PolicyRecord { + id, + sandbox_id, + version, + policy_payload: policy.encode_to_vec(), + policy_hash: wrapper.hash, + status, + load_error: if wrapper.load_error.is_empty() { + None + } else { + Some(wrapper.load_error) + }, + created_at_ms, + loaded_at_ms: (wrapper.loaded_at_ms > 0).then_some(wrapper.loaded_at_ms), + }) } -fn current_time_ms() -> Result { +pub(crate) fn draft_chunk_payload_from_record( + chunk: &DraftChunkRecord, +) -> PersistenceResult> { + let proposed_rule = if chunk.proposed_rule.is_empty() { + None + } else { + Some( + NetworkPolicyRule::decode(chunk.proposed_rule.as_slice()) + .map_err(|e| PersistenceError::Decode(format!("decode draft rule failed: {e}")))?, + ) + }; + Ok(DraftChunkPayload { + rule_name: chunk.rule_name.clone(), + proposed_rule, + rationale: chunk.rationale.clone(), + security_notes: chunk.security_notes.clone(), + confidence: chunk.confidence as f32, + decided_at_ms: chunk.decided_at_ms.unwrap_or(0), + host: chunk.host.clone(), + port: chunk.port, + binary: chunk.binary.clone(), + draft_version: chunk.draft_version, + } + .encode_to_vec()) +} + +pub(crate) fn draft_chunk_record_from_parts( + id: String, + sandbox_id: String, + status: String, + hit_count: i64, + payload: &[u8], + created_at_ms: i64, + updated_at_ms: i64, +) -> PersistenceResult { + let wrapper = DraftChunkPayload::decode(payload) + .map_err(|e| PersistenceError::Decode(format!("decode draft chunk wrapper failed: {e}")))?; + let proposed_rule = wrapper + .proposed_rule + .map(|rule| rule.encode_to_vec()) + .unwrap_or_default(); + Ok(DraftChunkRecord { + id, + sandbox_id, + draft_version: wrapper.draft_version, + status, + rule_name: wrapper.rule_name, + proposed_rule, + rationale: wrapper.rationale, + security_notes: wrapper.security_notes, + confidence: f64::from(wrapper.confidence), + created_at_ms, + decided_at_ms: (wrapper.decided_at_ms > 0).then_some(wrapper.decided_at_ms), + host: wrapper.host, + port: wrapper.port, + binary: wrapper.binary, + hit_count: i32::try_from(hit_count).unwrap_or(i32::MAX), + first_seen_ms: created_at_ms, + last_seen_ms: updated_at_ms, + }) +} + +fn current_time_ms() -> PersistenceResult { let now = SystemTime::now() .duration_since(UNIX_EPOCH) - .map_err(|e| Error::execution(format!("time error: {e}")))?; + .map_err(|e| PersistenceError::Database(format!("time error: {e}")))?; i64::try_from(now.as_millis()) - .map_err(|e| Error::execution(format!("time conversion error: {e}"))) + .map_err(|e| PersistenceError::Database(format!("time conversion error: {e}"))) } -fn map_db_error(error: &sqlx::Error) -> Error { - Error::execution(format!("database error: {error}")) +fn map_db_error(error: &sqlx::Error) -> PersistenceError { + if let sqlx::Error::Database(db) = error { + if db.is_unique_violation() { + let constraint = db + .constraint() + .map(ToString::to_string) + .or_else(|| infer_sqlite_unique_constraint(db.message())); + return PersistenceError::unique_violation(constraint, Some(db.message().to_string())); + } + } + PersistenceError::Database(error.to_string()) +} + +fn infer_sqlite_unique_constraint(message: &str) -> Option { + if message.contains("objects.object_type, objects.scope, objects.version") { + Some("objects_version_uq".to_string()) + } else if message.contains("objects.object_type, objects.scope, objects.dedup_key") { + Some("objects_dedup_uq".to_string()) + } else if message.contains("objects.object_type, objects.name") { + Some("objects_name_uq".to_string()) + } else if message.contains("objects.id") { + Some("objects_pkey".to_string()) + } else { + None + } } -fn map_migrate_error(error: &sqlx::migrate::MigrateError) -> Error { - Error::execution(format!("migration error: {error}")) +fn map_migrate_error(error: &sqlx::migrate::MigrateError) -> PersistenceError { + PersistenceError::Migration(error.to_string()) } #[cfg(test)] diff --git a/crates/openshell-server/src/persistence/postgres.rs b/crates/openshell-server/src/persistence/postgres.rs index 509b028d7..4408710c7 100644 --- a/crates/openshell-server/src/persistence/postgres.rs +++ b/crates/openshell-server/src/persistence/postgres.rs @@ -2,21 +2,25 @@ // SPDX-License-Identifier: Apache-2.0 use super::{ - DraftChunkRecord, ObjectRecord, PolicyRecord, current_time_ms, map_db_error, map_migrate_error, + DraftChunkRecord, ObjectRecord, PersistenceResult, PolicyRecord, current_time_ms, + draft_chunk_payload_from_record, draft_chunk_record_from_parts, map_db_error, + map_migrate_error, policy_payload_from_record, policy_record_from_parts, }; -use openshell_core::Result; use sqlx::postgres::PgPoolOptions; use sqlx::{PgPool, Row}; static POSTGRES_MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("./migrations/postgres"); +const POLICY_OBJECT_TYPE: &str = "sandbox_policy"; +const DRAFT_CHUNK_OBJECT_TYPE: &str = "draft_policy_chunk"; + #[derive(Debug, Clone)] pub struct PostgresStore { pool: PgPool, } impl PostgresStore { - pub async fn connect(url: &str) -> Result { + pub async fn connect(url: &str) -> PersistenceResult { let pool = PgPoolOptions::new() .max_connections(10) .connect(url) @@ -26,23 +30,28 @@ impl PostgresStore { Ok(Self { pool }) } - pub async fn migrate(&self) -> Result<()> { + pub async fn migrate(&self) -> PersistenceResult<()> { POSTGRES_MIGRATOR .run(&self.pool) .await .map_err(|e| map_migrate_error(&e)) } - pub async fn put(&self, object_type: &str, id: &str, name: &str, payload: &[u8]) -> Result<()> { + pub async fn put( + &self, + object_type: &str, + id: &str, + name: &str, + payload: &[u8], + ) -> PersistenceResult<()> { let now_ms = current_time_ms()?; sqlx::query( r" INSERT INTO objects (object_type, id, name, payload, created_at_ms, updated_at_ms) VALUES ($1, $2, $3, $4, $5, $5) -ON CONFLICT (id) DO UPDATE SET +ON CONFLICT (object_type, name) WHERE name IS NOT NULL DO UPDATE SET payload = EXCLUDED.payload, updated_at_ms = EXCLUDED.updated_at_ms -WHERE objects.object_type = EXCLUDED.object_type ", ) .bind(object_type) @@ -56,7 +65,11 @@ WHERE objects.object_type = EXCLUDED.object_type Ok(()) } - pub async fn get(&self, object_type: &str, id: &str) -> Result> { + pub async fn get( + &self, + object_type: &str, + id: &str, + ) -> PersistenceResult> { let row = sqlx::query( r" SELECT object_type, id, name, payload, created_at_ms, updated_at_ms @@ -70,17 +83,14 @@ WHERE object_type = $1 AND id = $2 .await .map_err(|e| map_db_error(&e))?; - Ok(row.map(|row| ObjectRecord { - object_type: row.get("object_type"), - id: row.get("id"), - name: row.get("name"), - payload: row.get("payload"), - created_at_ms: row.get("created_at_ms"), - updated_at_ms: row.get("updated_at_ms"), - })) + Ok(row.map(row_to_object_record)) } - pub async fn get_by_name(&self, object_type: &str, name: &str) -> Result> { + pub async fn get_by_name( + &self, + object_type: &str, + name: &str, + ) -> PersistenceResult> { let row = sqlx::query( r" SELECT object_type, id, name, payload, created_at_ms, updated_at_ms @@ -94,17 +104,10 @@ WHERE object_type = $1 AND name = $2 .await .map_err(|e| map_db_error(&e))?; - Ok(row.map(|row| ObjectRecord { - object_type: row.get("object_type"), - id: row.get("id"), - name: row.get("name"), - payload: row.get("payload"), - created_at_ms: row.get("created_at_ms"), - updated_at_ms: row.get("updated_at_ms"), - })) + Ok(row.map(row_to_object_record)) } - pub async fn delete(&self, object_type: &str, id: &str) -> Result { + pub async fn delete(&self, object_type: &str, id: &str) -> PersistenceResult { let result = sqlx::query("DELETE FROM objects WHERE object_type = $1 AND id = $2") .bind(object_type) .bind(id) @@ -114,7 +117,7 @@ WHERE object_type = $1 AND name = $2 Ok(result.rows_affected() > 0) } - pub async fn delete_by_name(&self, object_type: &str, name: &str) -> Result { + pub async fn delete_by_name(&self, object_type: &str, name: &str) -> PersistenceResult { let result = sqlx::query("DELETE FROM objects WHERE object_type = $1 AND name = $2") .bind(object_type) .bind(name) @@ -129,7 +132,7 @@ WHERE object_type = $1 AND name = $2 object_type: &str, limit: u32, offset: u32, - ) -> Result> { + ) -> PersistenceResult> { let rows = sqlx::query( r" SELECT object_type, id, name, payload, created_at_ms, updated_at_ms @@ -146,25 +149,9 @@ LIMIT $2 OFFSET $3 .await .map_err(|e| map_db_error(&e))?; - let records = rows - .into_iter() - .map(|row| ObjectRecord { - object_type: row.get("object_type"), - id: row.get("id"), - name: row.get("name"), - payload: row.get("payload"), - created_at_ms: row.get("created_at_ms"), - updated_at_ms: row.get("updated_at_ms"), - }) - .collect(); - - Ok(records) + Ok(rows.into_iter().map(row_to_object_record).collect()) } - // ------------------------------------------------------------------- - // Policy revision operations - // ------------------------------------------------------------------- - pub async fn put_policy_revision( &self, id: &str, @@ -172,19 +159,35 @@ LIMIT $2 OFFSET $3 version: i64, payload: &[u8], hash: &str, - ) -> Result<()> { + ) -> PersistenceResult<()> { let now_ms = current_time_ms()?; + let record = PolicyRecord { + id: id.to_string(), + sandbox_id: sandbox_id.to_string(), + version, + policy_payload: payload.to_vec(), + policy_hash: hash.to_string(), + status: "pending".to_string(), + load_error: None, + created_at_ms: now_ms, + loaded_at_ms: None, + }; + let wrapped_payload = policy_payload_from_record(&record)?; + sqlx::query( r" -INSERT INTO sandbox_policies (id, sandbox_id, version, policy_payload, policy_hash, status, created_at_ms) -VALUES ($1, $2, $3, $4, $5, 'pending', $6) +INSERT INTO objects ( + object_type, id, scope, version, status, payload, created_at_ms, updated_at_ms +) +VALUES ($1, $2, $3, $4, $5, $6, $7, $7) ", ) + .bind(POLICY_OBJECT_TYPE) .bind(id) .bind(sandbox_id) .bind(version) - .bind(payload) - .bind(hash) + .bind("pending") + .bind(wrapped_payload) .bind(now_ms) .execute(&self.pool) .await @@ -192,61 +195,70 @@ VALUES ($1, $2, $3, $4, $5, 'pending', $6) Ok(()) } - pub async fn get_latest_policy(&self, sandbox_id: &str) -> Result> { + pub async fn get_latest_policy( + &self, + sandbox_id: &str, + ) -> PersistenceResult> { let row = sqlx::query( r" -SELECT id, sandbox_id, version, policy_payload, policy_hash, status, load_error, created_at_ms, loaded_at_ms -FROM sandbox_policies -WHERE sandbox_id = $1 -ORDER BY version DESC +SELECT id, scope, version, status, payload, created_at_ms +FROM objects +WHERE object_type = $1 AND scope = $2 +ORDER BY version DESC, created_at_ms DESC LIMIT 1 ", ) + .bind(POLICY_OBJECT_TYPE) .bind(sandbox_id) .fetch_optional(&self.pool) .await .map_err(|e| map_db_error(&e))?; - Ok(row.map(row_to_policy_record)) + row.map(row_to_policy_record).transpose() } - pub async fn get_latest_loaded_policy(&self, sandbox_id: &str) -> Result> { + pub async fn get_latest_loaded_policy( + &self, + sandbox_id: &str, + ) -> PersistenceResult> { let row = sqlx::query( r" -SELECT id, sandbox_id, version, policy_payload, policy_hash, status, load_error, created_at_ms, loaded_at_ms -FROM sandbox_policies -WHERE sandbox_id = $1 AND status = 'loaded' -ORDER BY version DESC +SELECT id, scope, version, status, payload, created_at_ms +FROM objects +WHERE object_type = $1 AND scope = $2 AND status = 'loaded' +ORDER BY version DESC, created_at_ms DESC LIMIT 1 ", ) + .bind(POLICY_OBJECT_TYPE) .bind(sandbox_id) .fetch_optional(&self.pool) .await .map_err(|e| map_db_error(&e))?; - Ok(row.map(row_to_policy_record)) + row.map(row_to_policy_record).transpose() } pub async fn get_policy_by_version( &self, sandbox_id: &str, version: i64, - ) -> Result> { + ) -> PersistenceResult> { let row = sqlx::query( r" -SELECT id, sandbox_id, version, policy_payload, policy_hash, status, load_error, created_at_ms, loaded_at_ms -FROM sandbox_policies -WHERE sandbox_id = $1 AND version = $2 +SELECT id, scope, version, status, payload, created_at_ms +FROM objects +WHERE object_type = $1 AND scope = $2 AND version = $3 ", ) + .bind(POLICY_OBJECT_TYPE) .bind(sandbox_id) .bind(version) .fetch_optional(&self.pool) .await .map_err(|e| map_db_error(&e))?; - Ok(row.map(row_to_policy_record)) + row.map(row_to_policy_record).transpose() } pub async fn list_policies( @@ -254,16 +266,17 @@ WHERE sandbox_id = $1 AND version = $2 sandbox_id: &str, limit: u32, offset: u32, - ) -> Result> { + ) -> PersistenceResult> { let rows = sqlx::query( r" -SELECT id, sandbox_id, version, policy_payload, policy_hash, status, load_error, created_at_ms, loaded_at_ms -FROM sandbox_policies -WHERE sandbox_id = $1 -ORDER BY version DESC -LIMIT $2 OFFSET $3 +SELECT id, scope, version, status, payload, created_at_ms +FROM objects +WHERE object_type = $1 AND scope = $2 +ORDER BY version DESC, created_at_ms DESC +LIMIT $3 OFFSET $4 ", ) + .bind(POLICY_OBJECT_TYPE) .bind(sandbox_id) .bind(i64::from(limit)) .bind(i64::from(offset)) @@ -271,7 +284,7 @@ LIMIT $2 OFFSET $3 .await .map_err(|e| map_db_error(&e))?; - Ok(rows.into_iter().map(row_to_policy_record).collect()) + rows.into_iter().map(row_to_policy_record).collect() } pub async fn update_policy_status( @@ -281,19 +294,30 @@ LIMIT $2 OFFSET $3 status: &str, load_error: Option<&str>, loaded_at_ms: Option, - ) -> Result { + ) -> PersistenceResult { + let Some(mut record) = self.get_policy_by_version(sandbox_id, version).await? else { + return Ok(false); + }; + + record.status = status.to_string(); + record.load_error = load_error.map(ToOwned::to_owned); + record.loaded_at_ms = loaded_at_ms; + let payload = policy_payload_from_record(&record)?; + let now_ms = current_time_ms()?; + let result = sqlx::query( r" -UPDATE sandbox_policies -SET status = $3, load_error = $4, loaded_at_ms = $5 -WHERE sandbox_id = $1 AND version = $2 +UPDATE objects +SET status = $4, payload = $5, updated_at_ms = $6 +WHERE object_type = $1 AND scope = $2 AND version = $3 ", ) + .bind(POLICY_OBJECT_TYPE) .bind(sandbox_id) .bind(version) .bind(status) - .bind(load_error) - .bind(loaded_at_ms) + .bind(payload) + .bind(now_ms) .execute(&self.pool) .await .map_err(|e| map_db_error(&e))?; @@ -304,57 +328,48 @@ WHERE sandbox_id = $1 AND version = $2 &self, sandbox_id: &str, before_version: i64, - ) -> Result { + ) -> PersistenceResult { + let now_ms = current_time_ms()?; let result = sqlx::query( r" -UPDATE sandbox_policies -SET status = 'superseded' -WHERE sandbox_id = $1 AND version < $2 AND status IN ('pending', 'loaded') +UPDATE objects +SET status = 'superseded', updated_at_ms = $4 +WHERE object_type = $1 + AND scope = $2 + AND version < $3 + AND status IN ('pending', 'loaded') ", ) + .bind(POLICY_OBJECT_TYPE) .bind(sandbox_id) .bind(before_version) + .bind(now_ms) .execute(&self.pool) .await .map_err(|e| map_db_error(&e))?; Ok(result.rows_affected()) } - // ------------------------------------------------------------------- - // Draft policy chunk operations - // ------------------------------------------------------------------- - - pub async fn put_draft_chunk(&self, chunk: &DraftChunkRecord) -> Result<()> { + pub async fn put_draft_chunk(&self, chunk: &DraftChunkRecord) -> PersistenceResult<()> { + let payload = draft_chunk_payload_from_record(chunk)?; sqlx::query( r" -INSERT INTO draft_policy_chunks - (id, sandbox_id, draft_version, status, rule_name, - proposed_rule, rationale, security_notes, confidence, - created_at_ms, decided_at_ms, host, port, binary, - hit_count, first_seen_ms, last_seen_ms) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) -ON CONFLICT (sandbox_id, host, port, binary) - WHERE status IN ('pending', 'approved', 'rejected') -DO UPDATE SET - hit_count = draft_policy_chunks.hit_count + EXCLUDED.hit_count, - last_seen_ms = EXCLUDED.last_seen_ms +INSERT INTO objects ( + object_type, id, scope, status, dedup_key, hit_count, payload, created_at_ms, updated_at_ms +) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) +ON CONFLICT (object_type, scope, dedup_key) WHERE dedup_key IS NOT NULL DO UPDATE SET + hit_count = objects.hit_count + EXCLUDED.hit_count, + updated_at_ms = EXCLUDED.updated_at_ms ", ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(&chunk.id) .bind(&chunk.sandbox_id) - .bind(chunk.draft_version) .bind(&chunk.status) - .bind(&chunk.rule_name) - .bind(&chunk.proposed_rule) - .bind(&chunk.rationale) - .bind(&chunk.security_notes) - .bind(chunk.confidence) - .bind(chunk.created_at_ms) - .bind(chunk.decided_at_ms) - .bind(&chunk.host) - .bind(chunk.port) - .bind(&chunk.binary) - .bind(chunk.hit_count) + .bind(draft_chunk_dedup_key(chunk)) + .bind(i64::from(chunk.hit_count)) + .bind(payload) .bind(chunk.first_seen_ms) .bind(chunk.last_seen_ms) .execute(&self.pool) @@ -363,33 +378,38 @@ DO UPDATE SET Ok(()) } - pub async fn get_draft_chunk(&self, id: &str) -> Result> { + pub async fn get_draft_chunk(&self, id: &str) -> PersistenceResult> { let row = sqlx::query( r" -SELECT * FROM draft_policy_chunks WHERE id = $1 +SELECT id, scope, status, hit_count, payload, created_at_ms, updated_at_ms +FROM objects +WHERE object_type = $1 AND id = $2 ", ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(id) .fetch_optional(&self.pool) .await .map_err(|e| map_db_error(&e))?; - Ok(row.map(row_to_draft_chunk_record)) + row.map(row_to_draft_chunk_record).transpose() } pub async fn list_draft_chunks( &self, sandbox_id: &str, status_filter: Option<&str>, - ) -> Result> { + ) -> PersistenceResult> { let rows = if let Some(status) = status_filter { sqlx::query( r" -SELECT * FROM draft_policy_chunks -WHERE sandbox_id = $1 AND status = $2 +SELECT id, scope, status, hit_count, payload, created_at_ms, updated_at_ms +FROM objects +WHERE object_type = $1 AND scope = $2 AND status = $3 ORDER BY created_at_ms DESC ", ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(sandbox_id) .bind(status) .fetch_all(&self.pool) @@ -397,18 +417,20 @@ ORDER BY created_at_ms DESC } else { sqlx::query( r" -SELECT * FROM draft_policy_chunks -WHERE sandbox_id = $1 +SELECT id, scope, status, hit_count, payload, created_at_ms, updated_at_ms +FROM objects +WHERE object_type = $1 AND scope = $2 ORDER BY created_at_ms DESC ", ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(sandbox_id) .fetch_all(&self.pool) .await } .map_err(|e| map_db_error(&e))?; - Ok(rows.into_iter().map(row_to_draft_chunk_record).collect()) + rows.into_iter().map(row_to_draft_chunk_record).collect() } pub async fn update_draft_chunk_status( @@ -416,46 +438,80 @@ ORDER BY created_at_ms DESC id: &str, status: &str, decided_at_ms: Option, - ) -> Result { + ) -> PersistenceResult { + let Some(mut record) = self.get_draft_chunk(id).await? else { + return Ok(false); + }; + + record.status = status.to_string(); + record.decided_at_ms = decided_at_ms; + record.last_seen_ms = current_time_ms()?; + let payload = draft_chunk_payload_from_record(&record)?; + let result = sqlx::query( r" -UPDATE draft_policy_chunks -SET status = $2, decided_at_ms = $3 -WHERE id = $1 +UPDATE objects +SET status = $3, payload = $4, updated_at_ms = $5 +WHERE object_type = $1 AND id = $2 ", ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(id) .bind(status) - .bind(decided_at_ms) + .bind(payload) + .bind(record.last_seen_ms) .execute(&self.pool) .await .map_err(|e| map_db_error(&e))?; Ok(result.rows_affected() > 0) } - pub async fn update_draft_chunk_rule(&self, id: &str, proposed_rule: &[u8]) -> Result { + pub async fn update_draft_chunk_rule( + &self, + id: &str, + proposed_rule: &[u8], + ) -> PersistenceResult { + let Some(mut record) = self.get_draft_chunk(id).await? else { + return Ok(false); + }; + + if record.status != "pending" { + return Ok(false); + } + + record.proposed_rule = proposed_rule.to_vec(); + record.last_seen_ms = current_time_ms()?; + let payload = draft_chunk_payload_from_record(&record)?; + let result = sqlx::query( r" -UPDATE draft_policy_chunks -SET proposed_rule = $2 -WHERE id = $1 AND status = 'pending' +UPDATE objects +SET payload = $3, updated_at_ms = $4 +WHERE object_type = $1 AND id = $2 AND status = 'pending' ", ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(id) - .bind(proposed_rule) + .bind(payload) + .bind(record.last_seen_ms) .execute(&self.pool) .await .map_err(|e| map_db_error(&e))?; Ok(result.rows_affected() > 0) } - pub async fn delete_draft_chunks(&self, sandbox_id: &str, status: &str) -> Result { + pub async fn delete_draft_chunks( + &self, + sandbox_id: &str, + status: &str, + ) -> PersistenceResult { let result = sqlx::query( r" -DELETE FROM draft_policy_chunks -WHERE sandbox_id = $1 AND status = $2 +DELETE FROM objects +WHERE object_type = $1 AND scope = $2 AND status = $3 ", ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(sandbox_id) .bind(status) .execute(&self.pool) @@ -464,55 +520,78 @@ WHERE sandbox_id = $1 AND status = $2 Ok(result.rows_affected()) } - pub async fn get_draft_version(&self, sandbox_id: &str) -> Result { - let row = sqlx::query( + pub async fn get_draft_version(&self, sandbox_id: &str) -> PersistenceResult { + let rows = sqlx::query( r" -SELECT COALESCE(MAX(draft_version), 0) as max_version -FROM draft_policy_chunks -WHERE sandbox_id = $1 +SELECT payload +FROM objects +WHERE object_type = $1 AND scope = $2 ", ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(sandbox_id) - .fetch_one(&self.pool) + .fetch_all(&self.pool) .await .map_err(|e| map_db_error(&e))?; - Ok(row.get("max_version")) + let mut max_version = 0_i64; + for row in rows { + let payload: Vec = row.get("payload"); + let wrapper = draft_chunk_record_from_parts( + String::new(), + sandbox_id.to_string(), + String::new(), + 0, + &payload, + 0, + 0, + )?; + max_version = max_version.max(wrapper.draft_version); + } + Ok(max_version) } } -fn row_to_draft_chunk_record(row: sqlx::postgres::PgRow) -> DraftChunkRecord { - DraftChunkRecord { - id: row.get("id"), - sandbox_id: row.get("sandbox_id"), - draft_version: row.get("draft_version"), - status: row.get("status"), - rule_name: row.get("rule_name"), - proposed_rule: row.get("proposed_rule"), - rationale: row.get("rationale"), - security_notes: row.get("security_notes"), - confidence: row.get("confidence"), - created_at_ms: row.get("created_at_ms"), - decided_at_ms: row.get("decided_at_ms"), - host: row.get("host"), - port: row.get("port"), - binary: row.get("binary"), - hit_count: row.get("hit_count"), - first_seen_ms: row.get("first_seen_ms"), - last_seen_ms: row.get("last_seen_ms"), - } +fn draft_chunk_dedup_key(chunk: &DraftChunkRecord) -> String { + format!("{}|{}|{}", chunk.host, chunk.port, chunk.binary) } -fn row_to_policy_record(row: sqlx::postgres::PgRow) -> PolicyRecord { - PolicyRecord { +fn row_to_object_record(row: sqlx::postgres::PgRow) -> ObjectRecord { + ObjectRecord { + object_type: row.get("object_type"), id: row.get("id"), - sandbox_id: row.get("sandbox_id"), - version: row.get("version"), - policy_payload: row.get("policy_payload"), - policy_hash: row.get("policy_hash"), - status: row.get("status"), - load_error: row.get("load_error"), + name: row.get("name"), + payload: row.get("payload"), created_at_ms: row.get("created_at_ms"), - loaded_at_ms: row.get("loaded_at_ms"), + updated_at_ms: row.get("updated_at_ms"), } } + +fn row_to_policy_record(row: sqlx::postgres::PgRow) -> PersistenceResult { + let id: String = row.get("id"); + let sandbox_id: String = row.get("scope"); + let version: i64 = row.get("version"); + let status: String = row.get("status"); + let payload: Vec = row.get("payload"); + let created_at_ms: i64 = row.get("created_at_ms"); + policy_record_from_parts(id, sandbox_id, version, status, &payload, created_at_ms) +} + +fn row_to_draft_chunk_record(row: sqlx::postgres::PgRow) -> PersistenceResult { + let id: String = row.get("id"); + let sandbox_id: String = row.get("scope"); + let status: String = row.get("status"); + let hit_count: i64 = row.get("hit_count"); + let payload: Vec = row.get("payload"); + let created_at_ms: i64 = row.get("created_at_ms"); + let updated_at_ms: i64 = row.get("updated_at_ms"); + draft_chunk_record_from_parts( + id, + sandbox_id, + status, + hit_count, + &payload, + created_at_ms, + updated_at_ms, + ) +} diff --git a/crates/openshell-server/src/persistence/sqlite.rs b/crates/openshell-server/src/persistence/sqlite.rs index 3b660b314..2fe72073a 100644 --- a/crates/openshell-server/src/persistence/sqlite.rs +++ b/crates/openshell-server/src/persistence/sqlite.rs @@ -2,22 +2,26 @@ // SPDX-License-Identifier: Apache-2.0 use super::{ - DraftChunkRecord, ObjectRecord, PolicyRecord, current_time_ms, map_db_error, map_migrate_error, + DraftChunkRecord, ObjectRecord, PersistenceResult, PolicyRecord, current_time_ms, + draft_chunk_payload_from_record, draft_chunk_record_from_parts, map_db_error, + map_migrate_error, policy_payload_from_record, policy_record_from_parts, }; -use openshell_core::Result; use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; use sqlx::{Row, SqlitePool}; use std::str::FromStr; static SQLITE_MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("./migrations/sqlite"); +const POLICY_OBJECT_TYPE: &str = "sandbox_policy"; +const DRAFT_CHUNK_OBJECT_TYPE: &str = "draft_policy_chunk"; + #[derive(Debug, Clone)] pub struct SqliteStore { pool: SqlitePool, } impl SqliteStore { - pub async fn connect(url: &str) -> Result { + pub async fn connect(url: &str) -> PersistenceResult { let is_in_memory = url.contains(":memory:") || url.contains("mode=memory"); let max_connections = if is_in_memory { 1 } else { 5 }; @@ -29,11 +33,6 @@ impl SqliteStore { .max_connections(max_connections) .min_connections(max_connections); - // In-memory SQLite databases exist only while at least one connection - // is open. SQLx's default `max_lifetime` (30 min) and `idle_timeout` - // (10 min) would recycle the sole connection, destroying the database - // and all its tables. Disable both timeouts so the connection (and - // therefore the database) lives for the entire process lifetime. if is_in_memory { pool_options = pool_options.idle_timeout(None).max_lifetime(None); } @@ -46,24 +45,29 @@ impl SqliteStore { Ok(Self { pool }) } - pub async fn migrate(&self) -> Result<()> { + pub async fn migrate(&self) -> PersistenceResult<()> { SQLITE_MIGRATOR .run(&self.pool) .await .map_err(|e| map_migrate_error(&e)) } - pub async fn put(&self, object_type: &str, id: &str, name: &str, payload: &[u8]) -> Result<()> { + pub async fn put( + &self, + object_type: &str, + id: &str, + name: &str, + payload: &[u8], + ) -> PersistenceResult<()> { let now_ms = current_time_ms()?; sqlx::query( r#" INSERT INTO "objects" ("object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms") VALUES (?1, ?2, ?3, ?4, ?5, ?5) -ON CONFLICT ("id") DO UPDATE SET +ON CONFLICT ("object_type", "name") WHERE "name" IS NOT NULL DO UPDATE SET "payload" = excluded."payload", "updated_at_ms" = excluded."updated_at_ms" -WHERE "objects"."object_type" = excluded."object_type" "#, ) .bind(object_type) @@ -77,7 +81,11 @@ WHERE "objects"."object_type" = excluded."object_type" Ok(()) } - pub async fn get(&self, object_type: &str, id: &str) -> Result> { + pub async fn get( + &self, + object_type: &str, + id: &str, + ) -> PersistenceResult> { let row = sqlx::query( r#" SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms" @@ -91,17 +99,14 @@ WHERE "object_type" = ?1 AND "id" = ?2 .await .map_err(|e| map_db_error(&e))?; - Ok(row.map(|row| ObjectRecord { - object_type: row.get("object_type"), - id: row.get("id"), - name: row.get("name"), - payload: row.get("payload"), - created_at_ms: row.get("created_at_ms"), - updated_at_ms: row.get("updated_at_ms"), - })) + Ok(row.map(row_to_object_record)) } - pub async fn get_by_name(&self, object_type: &str, name: &str) -> Result> { + pub async fn get_by_name( + &self, + object_type: &str, + name: &str, + ) -> PersistenceResult> { let row = sqlx::query( r#" SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms" @@ -115,17 +120,10 @@ WHERE "object_type" = ?1 AND "name" = ?2 .await .map_err(|e| map_db_error(&e))?; - Ok(row.map(|row| ObjectRecord { - object_type: row.get("object_type"), - id: row.get("id"), - name: row.get("name"), - payload: row.get("payload"), - created_at_ms: row.get("created_at_ms"), - updated_at_ms: row.get("updated_at_ms"), - })) + Ok(row.map(row_to_object_record)) } - pub async fn delete(&self, object_type: &str, id: &str) -> Result { + pub async fn delete(&self, object_type: &str, id: &str) -> PersistenceResult { let result = sqlx::query( r#" DELETE FROM "objects" @@ -140,7 +138,7 @@ WHERE "object_type" = ?1 AND "id" = ?2 Ok(result.rows_affected() > 0) } - pub async fn delete_by_name(&self, object_type: &str, name: &str) -> Result { + pub async fn delete_by_name(&self, object_type: &str, name: &str) -> PersistenceResult { let result = sqlx::query( r#" DELETE FROM "objects" @@ -160,7 +158,7 @@ WHERE "object_type" = ?1 AND "name" = ?2 object_type: &str, limit: u32, offset: u32, - ) -> Result> { + ) -> PersistenceResult> { let rows = sqlx::query( r#" SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms" @@ -177,25 +175,9 @@ LIMIT ?2 OFFSET ?3 .await .map_err(|e| map_db_error(&e))?; - let records = rows - .into_iter() - .map(|row| ObjectRecord { - object_type: row.get("object_type"), - id: row.get("id"), - name: row.get("name"), - payload: row.get("payload"), - created_at_ms: row.get("created_at_ms"), - updated_at_ms: row.get("updated_at_ms"), - }) - .collect(); - - Ok(records) + Ok(rows.into_iter().map(row_to_object_record).collect()) } - // ------------------------------------------------------------------- - // Policy revision operations - // ------------------------------------------------------------------- - pub async fn put_policy_revision( &self, id: &str, @@ -203,19 +185,35 @@ LIMIT ?2 OFFSET ?3 version: i64, payload: &[u8], hash: &str, - ) -> Result<()> { + ) -> PersistenceResult<()> { let now_ms = current_time_ms()?; + let record = PolicyRecord { + id: id.to_string(), + sandbox_id: sandbox_id.to_string(), + version, + policy_payload: payload.to_vec(), + policy_hash: hash.to_string(), + status: "pending".to_string(), + load_error: None, + created_at_ms: now_ms, + loaded_at_ms: None, + }; + let wrapped_payload = policy_payload_from_record(&record)?; + sqlx::query( r#" -INSERT INTO "sandbox_policies" ("id", "sandbox_id", "version", "policy_payload", "policy_hash", "status", "created_at_ms") -VALUES (?1, ?2, ?3, ?4, ?5, 'pending', ?6) +INSERT INTO "objects" ( + "object_type", "id", "scope", "version", "status", "payload", "created_at_ms", "updated_at_ms" +) +VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?7) "#, ) + .bind(POLICY_OBJECT_TYPE) .bind(id) .bind(sandbox_id) .bind(version) - .bind(payload) - .bind(hash) + .bind("pending") + .bind(wrapped_payload) .bind(now_ms) .execute(&self.pool) .await @@ -223,61 +221,70 @@ VALUES (?1, ?2, ?3, ?4, ?5, 'pending', ?6) Ok(()) } - pub async fn get_latest_policy(&self, sandbox_id: &str) -> Result> { + pub async fn get_latest_policy( + &self, + sandbox_id: &str, + ) -> PersistenceResult> { let row = sqlx::query( r#" -SELECT "id", "sandbox_id", "version", "policy_payload", "policy_hash", "status", "load_error", "created_at_ms", "loaded_at_ms" -FROM "sandbox_policies" -WHERE "sandbox_id" = ?1 -ORDER BY "version" DESC +SELECT "id", "scope", "version", "status", "payload", "created_at_ms" +FROM "objects" +WHERE "object_type" = ?1 AND "scope" = ?2 +ORDER BY "version" DESC, "created_at_ms" DESC LIMIT 1 "#, ) + .bind(POLICY_OBJECT_TYPE) .bind(sandbox_id) .fetch_optional(&self.pool) .await .map_err(|e| map_db_error(&e))?; - Ok(row.map(row_to_policy_record)) + row.map(row_to_policy_record).transpose() } - pub async fn get_latest_loaded_policy(&self, sandbox_id: &str) -> Result> { + pub async fn get_latest_loaded_policy( + &self, + sandbox_id: &str, + ) -> PersistenceResult> { let row = sqlx::query( r#" -SELECT "id", "sandbox_id", "version", "policy_payload", "policy_hash", "status", "load_error", "created_at_ms", "loaded_at_ms" -FROM "sandbox_policies" -WHERE "sandbox_id" = ?1 AND "status" = 'loaded' -ORDER BY "version" DESC +SELECT "id", "scope", "version", "status", "payload", "created_at_ms" +FROM "objects" +WHERE "object_type" = ?1 AND "scope" = ?2 AND "status" = 'loaded' +ORDER BY "version" DESC, "created_at_ms" DESC LIMIT 1 "#, ) + .bind(POLICY_OBJECT_TYPE) .bind(sandbox_id) .fetch_optional(&self.pool) .await .map_err(|e| map_db_error(&e))?; - Ok(row.map(row_to_policy_record)) + row.map(row_to_policy_record).transpose() } pub async fn get_policy_by_version( &self, sandbox_id: &str, version: i64, - ) -> Result> { + ) -> PersistenceResult> { let row = sqlx::query( r#" -SELECT "id", "sandbox_id", "version", "policy_payload", "policy_hash", "status", "load_error", "created_at_ms", "loaded_at_ms" -FROM "sandbox_policies" -WHERE "sandbox_id" = ?1 AND "version" = ?2 +SELECT "id", "scope", "version", "status", "payload", "created_at_ms" +FROM "objects" +WHERE "object_type" = ?1 AND "scope" = ?2 AND "version" = ?3 "#, ) + .bind(POLICY_OBJECT_TYPE) .bind(sandbox_id) .bind(version) .fetch_optional(&self.pool) .await .map_err(|e| map_db_error(&e))?; - Ok(row.map(row_to_policy_record)) + row.map(row_to_policy_record).transpose() } pub async fn list_policies( @@ -285,16 +292,17 @@ WHERE "sandbox_id" = ?1 AND "version" = ?2 sandbox_id: &str, limit: u32, offset: u32, - ) -> Result> { + ) -> PersistenceResult> { let rows = sqlx::query( r#" -SELECT "id", "sandbox_id", "version", "policy_payload", "policy_hash", "status", "load_error", "created_at_ms", "loaded_at_ms" -FROM "sandbox_policies" -WHERE "sandbox_id" = ?1 -ORDER BY "version" DESC -LIMIT ?2 OFFSET ?3 +SELECT "id", "scope", "version", "status", "payload", "created_at_ms" +FROM "objects" +WHERE "object_type" = ?1 AND "scope" = ?2 +ORDER BY "version" DESC, "created_at_ms" DESC +LIMIT ?3 OFFSET ?4 "#, ) + .bind(POLICY_OBJECT_TYPE) .bind(sandbox_id) .bind(i64::from(limit)) .bind(i64::from(offset)) @@ -302,7 +310,7 @@ LIMIT ?2 OFFSET ?3 .await .map_err(|e| map_db_error(&e))?; - Ok(rows.into_iter().map(row_to_policy_record).collect()) + rows.into_iter().map(row_to_policy_record).collect() } pub async fn update_policy_status( @@ -312,19 +320,30 @@ LIMIT ?2 OFFSET ?3 status: &str, load_error: Option<&str>, loaded_at_ms: Option, - ) -> Result { + ) -> PersistenceResult { + let Some(mut record) = self.get_policy_by_version(sandbox_id, version).await? else { + return Ok(false); + }; + + record.status = status.to_string(); + record.load_error = load_error.map(ToOwned::to_owned); + record.loaded_at_ms = loaded_at_ms; + let payload = policy_payload_from_record(&record)?; + let now_ms = current_time_ms()?; + let result = sqlx::query( r#" -UPDATE "sandbox_policies" -SET "status" = ?3, "load_error" = ?4, "loaded_at_ms" = ?5 -WHERE "sandbox_id" = ?1 AND "version" = ?2 +UPDATE "objects" +SET "status" = ?4, "payload" = ?5, "updated_at_ms" = ?6 +WHERE "object_type" = ?1 AND "scope" = ?2 AND "version" = ?3 "#, ) + .bind(POLICY_OBJECT_TYPE) .bind(sandbox_id) .bind(version) .bind(status) - .bind(load_error) - .bind(loaded_at_ms) + .bind(payload) + .bind(now_ms) .execute(&self.pool) .await .map_err(|e| map_db_error(&e))?; @@ -335,57 +354,48 @@ WHERE "sandbox_id" = ?1 AND "version" = ?2 &self, sandbox_id: &str, before_version: i64, - ) -> Result { + ) -> PersistenceResult { + let now_ms = current_time_ms()?; let result = sqlx::query( r#" -UPDATE "sandbox_policies" -SET "status" = 'superseded' -WHERE "sandbox_id" = ?1 AND "version" < ?2 AND "status" IN ('pending', 'loaded') +UPDATE "objects" +SET "status" = 'superseded', "updated_at_ms" = ?4 +WHERE "object_type" = ?1 + AND "scope" = ?2 + AND "version" < ?3 + AND "status" IN ('pending', 'loaded') "#, ) + .bind(POLICY_OBJECT_TYPE) .bind(sandbox_id) .bind(before_version) + .bind(now_ms) .execute(&self.pool) .await .map_err(|e| map_db_error(&e))?; Ok(result.rows_affected()) } - // ------------------------------------------------------------------- - // Draft policy chunk operations - // ------------------------------------------------------------------- - - pub async fn put_draft_chunk(&self, chunk: &DraftChunkRecord) -> Result<()> { + pub async fn put_draft_chunk(&self, chunk: &DraftChunkRecord) -> PersistenceResult<()> { + let payload = draft_chunk_payload_from_record(chunk)?; sqlx::query( r#" -INSERT INTO "draft_policy_chunks" - ("id", "sandbox_id", "draft_version", "status", "rule_name", - "proposed_rule", "rationale", "security_notes", "confidence", - "created_at_ms", "decided_at_ms", "host", "port", "binary", - "hit_count", "first_seen_ms", "last_seen_ms") -VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17) -ON CONFLICT ("sandbox_id", "host", "port", "binary") - WHERE "status" IN ('pending', 'approved', 'rejected') -DO UPDATE SET - "hit_count" = "draft_policy_chunks"."hit_count" + excluded."hit_count", - "last_seen_ms" = excluded."last_seen_ms" +INSERT INTO "objects" ( + "object_type", "id", "scope", "status", "dedup_key", "hit_count", "payload", "created_at_ms", "updated_at_ms" +) +VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9) +ON CONFLICT ("object_type", "scope", "dedup_key") WHERE "dedup_key" IS NOT NULL DO UPDATE SET + "hit_count" = "objects"."hit_count" + excluded."hit_count", + "updated_at_ms" = excluded."updated_at_ms" "#, ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(&chunk.id) .bind(&chunk.sandbox_id) - .bind(chunk.draft_version) .bind(&chunk.status) - .bind(&chunk.rule_name) - .bind(&chunk.proposed_rule) - .bind(&chunk.rationale) - .bind(&chunk.security_notes) - .bind(chunk.confidence) - .bind(chunk.created_at_ms) - .bind(chunk.decided_at_ms) - .bind(&chunk.host) - .bind(chunk.port) - .bind(&chunk.binary) - .bind(chunk.hit_count) + .bind(draft_chunk_dedup_key(chunk)) + .bind(i64::from(chunk.hit_count)) + .bind(payload) .bind(chunk.first_seen_ms) .bind(chunk.last_seen_ms) .execute(&self.pool) @@ -394,33 +404,38 @@ DO UPDATE SET Ok(()) } - pub async fn get_draft_chunk(&self, id: &str) -> Result> { + pub async fn get_draft_chunk(&self, id: &str) -> PersistenceResult> { let row = sqlx::query( r#" -SELECT * FROM "draft_policy_chunks" WHERE "id" = ?1 +SELECT "id", "scope", "status", "hit_count", "payload", "created_at_ms", "updated_at_ms" +FROM "objects" +WHERE "object_type" = ?1 AND "id" = ?2 "#, ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(id) .fetch_optional(&self.pool) .await .map_err(|e| map_db_error(&e))?; - Ok(row.map(row_to_draft_chunk_record)) + row.map(row_to_draft_chunk_record).transpose() } pub async fn list_draft_chunks( &self, sandbox_id: &str, status_filter: Option<&str>, - ) -> Result> { + ) -> PersistenceResult> { let rows = if let Some(status) = status_filter { sqlx::query( r#" -SELECT * FROM "draft_policy_chunks" -WHERE "sandbox_id" = ?1 AND "status" = ?2 +SELECT "id", "scope", "status", "hit_count", "payload", "created_at_ms", "updated_at_ms" +FROM "objects" +WHERE "object_type" = ?1 AND "scope" = ?2 AND "status" = ?3 ORDER BY "created_at_ms" DESC "#, ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(sandbox_id) .bind(status) .fetch_all(&self.pool) @@ -428,18 +443,20 @@ ORDER BY "created_at_ms" DESC } else { sqlx::query( r#" -SELECT * FROM "draft_policy_chunks" -WHERE "sandbox_id" = ?1 +SELECT "id", "scope", "status", "hit_count", "payload", "created_at_ms", "updated_at_ms" +FROM "objects" +WHERE "object_type" = ?1 AND "scope" = ?2 ORDER BY "created_at_ms" DESC "#, ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(sandbox_id) .fetch_all(&self.pool) .await } .map_err(|e| map_db_error(&e))?; - Ok(rows.into_iter().map(row_to_draft_chunk_record).collect()) + rows.into_iter().map(row_to_draft_chunk_record).collect() } pub async fn update_draft_chunk_status( @@ -447,46 +464,80 @@ ORDER BY "created_at_ms" DESC id: &str, status: &str, decided_at_ms: Option, - ) -> Result { + ) -> PersistenceResult { + let Some(mut record) = self.get_draft_chunk(id).await? else { + return Ok(false); + }; + + record.status = status.to_string(); + record.decided_at_ms = decided_at_ms; + record.last_seen_ms = current_time_ms()?; + let payload = draft_chunk_payload_from_record(&record)?; + let result = sqlx::query( r#" -UPDATE "draft_policy_chunks" -SET "status" = ?2, "decided_at_ms" = ?3 -WHERE "id" = ?1 +UPDATE "objects" +SET "status" = ?3, "payload" = ?4, "updated_at_ms" = ?5 +WHERE "object_type" = ?1 AND "id" = ?2 "#, ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(id) .bind(status) - .bind(decided_at_ms) + .bind(payload) + .bind(record.last_seen_ms) .execute(&self.pool) .await .map_err(|e| map_db_error(&e))?; Ok(result.rows_affected() > 0) } - pub async fn update_draft_chunk_rule(&self, id: &str, proposed_rule: &[u8]) -> Result { + pub async fn update_draft_chunk_rule( + &self, + id: &str, + proposed_rule: &[u8], + ) -> PersistenceResult { + let Some(mut record) = self.get_draft_chunk(id).await? else { + return Ok(false); + }; + + if record.status != "pending" { + return Ok(false); + } + + record.proposed_rule = proposed_rule.to_vec(); + record.last_seen_ms = current_time_ms()?; + let payload = draft_chunk_payload_from_record(&record)?; + let result = sqlx::query( r#" -UPDATE "draft_policy_chunks" -SET "proposed_rule" = ?2 -WHERE "id" = ?1 AND "status" = 'pending' +UPDATE "objects" +SET "payload" = ?3, "updated_at_ms" = ?4 +WHERE "object_type" = ?1 AND "id" = ?2 AND "status" = 'pending' "#, ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(id) - .bind(proposed_rule) + .bind(payload) + .bind(record.last_seen_ms) .execute(&self.pool) .await .map_err(|e| map_db_error(&e))?; Ok(result.rows_affected() > 0) } - pub async fn delete_draft_chunks(&self, sandbox_id: &str, status: &str) -> Result { + pub async fn delete_draft_chunks( + &self, + sandbox_id: &str, + status: &str, + ) -> PersistenceResult { let result = sqlx::query( r#" -DELETE FROM "draft_policy_chunks" -WHERE "sandbox_id" = ?1 AND "status" = ?2 +DELETE FROM "objects" +WHERE "object_type" = ?1 AND "scope" = ?2 AND "status" = ?3 "#, ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(sandbox_id) .bind(status) .execute(&self.pool) @@ -495,55 +546,78 @@ WHERE "sandbox_id" = ?1 AND "status" = ?2 Ok(result.rows_affected()) } - pub async fn get_draft_version(&self, sandbox_id: &str) -> Result { - let row = sqlx::query( + pub async fn get_draft_version(&self, sandbox_id: &str) -> PersistenceResult { + let rows = sqlx::query( r#" -SELECT COALESCE(MAX("draft_version"), 0) as "max_version" -FROM "draft_policy_chunks" -WHERE "sandbox_id" = ?1 +SELECT "payload" +FROM "objects" +WHERE "object_type" = ?1 AND "scope" = ?2 "#, ) + .bind(DRAFT_CHUNK_OBJECT_TYPE) .bind(sandbox_id) - .fetch_one(&self.pool) + .fetch_all(&self.pool) .await .map_err(|e| map_db_error(&e))?; - Ok(row.get("max_version")) + let mut max_version = 0_i64; + for row in rows { + let payload: Vec = row.get("payload"); + let wrapper = draft_chunk_record_from_parts( + String::new(), + sandbox_id.to_string(), + String::new(), + 0, + &payload, + 0, + 0, + )?; + max_version = max_version.max(wrapper.draft_version); + } + Ok(max_version) } } -fn row_to_draft_chunk_record(row: sqlx::sqlite::SqliteRow) -> DraftChunkRecord { - DraftChunkRecord { - id: row.get("id"), - sandbox_id: row.get("sandbox_id"), - draft_version: row.get("draft_version"), - status: row.get("status"), - rule_name: row.get("rule_name"), - proposed_rule: row.get("proposed_rule"), - rationale: row.get("rationale"), - security_notes: row.get("security_notes"), - confidence: row.get("confidence"), - created_at_ms: row.get("created_at_ms"), - decided_at_ms: row.get("decided_at_ms"), - host: row.get("host"), - port: row.get("port"), - binary: row.get("binary"), - hit_count: row.get("hit_count"), - first_seen_ms: row.get("first_seen_ms"), - last_seen_ms: row.get("last_seen_ms"), - } +fn draft_chunk_dedup_key(chunk: &DraftChunkRecord) -> String { + format!("{}|{}|{}", chunk.host, chunk.port, chunk.binary) } -fn row_to_policy_record(row: sqlx::sqlite::SqliteRow) -> PolicyRecord { - PolicyRecord { +fn row_to_object_record(row: sqlx::sqlite::SqliteRow) -> ObjectRecord { + ObjectRecord { + object_type: row.get("object_type"), id: row.get("id"), - sandbox_id: row.get("sandbox_id"), - version: row.get("version"), - policy_payload: row.get("policy_payload"), - policy_hash: row.get("policy_hash"), - status: row.get("status"), - load_error: row.get("load_error"), + name: row.get("name"), + payload: row.get("payload"), created_at_ms: row.get("created_at_ms"), - loaded_at_ms: row.get("loaded_at_ms"), + updated_at_ms: row.get("updated_at_ms"), } } + +fn row_to_policy_record(row: sqlx::sqlite::SqliteRow) -> PersistenceResult { + let id: String = row.get("id"); + let sandbox_id: String = row.get("scope"); + let version: i64 = row.get("version"); + let status: String = row.get("status"); + let payload: Vec = row.get("payload"); + let created_at_ms: i64 = row.get("created_at_ms"); + policy_record_from_parts(id, sandbox_id, version, status, &payload, created_at_ms) +} + +fn row_to_draft_chunk_record(row: sqlx::sqlite::SqliteRow) -> PersistenceResult { + let id: String = row.get("id"); + let sandbox_id: String = row.get("scope"); + let status: String = row.get("status"); + let hit_count: i64 = row.get("hit_count"); + let payload: Vec = row.get("payload"); + let created_at_ms: i64 = row.get("created_at_ms"); + let updated_at_ms: i64 = row.get("updated_at_ms"); + draft_chunk_record_from_parts( + id, + sandbox_id, + status, + hit_count, + &payload, + created_at_ms, + updated_at_ms, + ) +} diff --git a/crates/openshell-server/src/persistence/tests.rs b/crates/openshell-server/src/persistence/tests.rs index b3bfe2fa4..eb6248f25 100644 --- a/crates/openshell-server/src/persistence/tests.rs +++ b/crates/openshell-server/src/persistence/tests.rs @@ -2,7 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use super::{ObjectId, ObjectName, ObjectType, Store, generate_name}; -use openshell_core::proto::ObjectForTest; +use openshell_core::proto::{ObjectForTest, SandboxPolicy}; +use prost::Message; #[tokio::test] async fn sqlite_put_get_round_trip() { @@ -206,11 +207,19 @@ async fn sqlite_name_unique_per_object_type() { .await .unwrap(); - // Same name, same object_type, different id -> should fail (unique constraint). - let result = store + // Same name, same object_type, different id -> upsert on name. + store .put("sandbox", "id-2", "shared-name", b"payload2") - .await; - assert!(result.is_err()); + .await + .unwrap(); + + let record = store + .get_by_name("sandbox", "shared-name") + .await + .unwrap() + .unwrap(); + assert_eq!(record.id, "id-1"); + assert_eq!(record.payload, b"payload2"); // Same name, different object_type -> should succeed. store @@ -230,13 +239,10 @@ async fn sqlite_id_globally_unique() { .await .unwrap(); - // Same id, different object_type -> the upsert is a no-op (WHERE - // clause prevents updating a row with a different object_type). - // The original row is preserved unchanged. - store - .put("secret", "same-id", "name-b", b"payload2") - .await - .unwrap(); + // Same id, different object_type -> should fail because ids remain global + // primary keys even when writes upsert on name. + let result = store.put("secret", "same-id", "name-b", b"payload2").await; + assert!(result.is_err()); // Original row is untouched. let record = store.get("sandbox", "same-id").await.unwrap().unwrap(); @@ -285,8 +291,9 @@ async fn policy_put_and_get_latest() { .await .unwrap(); + let policy_v1 = SandboxPolicy::default().encode_to_vec(); store - .put_policy_revision("p1", "sandbox-1", 1, b"policy-v1", "hash1") + .put_policy_revision("p1", "sandbox-1", 1, &policy_v1, "hash1") .await .unwrap(); @@ -294,11 +301,16 @@ async fn policy_put_and_get_latest() { assert_eq!(latest.version, 1); assert_eq!(latest.policy_hash, "hash1"); assert_eq!(latest.status, "pending"); - assert_eq!(latest.policy_payload, b"policy-v1"); + assert_eq!(latest.policy_payload, policy_v1); // Add version 2 + let policy_v2 = SandboxPolicy { + version: 2, + ..SandboxPolicy::default() + } + .encode_to_vec(); store - .put_policy_revision("p2", "sandbox-1", 2, b"policy-v2", "hash2") + .put_policy_revision("p2", "sandbox-1", 2, &policy_v2, "hash2") .await .unwrap(); @@ -313,12 +325,18 @@ async fn policy_get_by_version() { .await .unwrap(); + let policy_v1 = SandboxPolicy::default().encode_to_vec(); + let policy_v2 = SandboxPolicy { + version: 2, + ..SandboxPolicy::default() + } + .encode_to_vec(); store - .put_policy_revision("p1", "sandbox-1", 1, b"v1", "h1") + .put_policy_revision("p1", "sandbox-1", 1, &policy_v1, "h1") .await .unwrap(); store - .put_policy_revision("p2", "sandbox-1", 2, b"v2", "h2") + .put_policy_revision("p2", "sandbox-1", 2, &policy_v2, "h2") .await .unwrap(); @@ -348,8 +366,9 @@ async fn policy_update_status_and_get_loaded() { .await .unwrap(); + let payload = SandboxPolicy::default().encode_to_vec(); store - .put_policy_revision("p1", "sandbox-1", 1, b"v1", "h1") + .put_policy_revision("p1", "sandbox-1", 1, &payload, "h1") .await .unwrap(); @@ -380,8 +399,9 @@ async fn policy_status_failed_with_error() { .await .unwrap(); + let payload = SandboxPolicy::default().encode_to_vec(); store - .put_policy_revision("p1", "sandbox-1", 1, b"v1", "h1") + .put_policy_revision("p1", "sandbox-1", 1, &payload, "h1") .await .unwrap(); @@ -405,16 +425,17 @@ async fn policy_supersede_older() { .await .unwrap(); + let payload = SandboxPolicy::default().encode_to_vec(); store - .put_policy_revision("p1", "sandbox-1", 1, b"v1", "h1") + .put_policy_revision("p1", "sandbox-1", 1, &payload, "h1") .await .unwrap(); store - .put_policy_revision("p2", "sandbox-1", 2, b"v2", "h2") + .put_policy_revision("p2", "sandbox-1", 2, &payload, "h2") .await .unwrap(); store - .put_policy_revision("p3", "sandbox-1", 3, b"v3", "h3") + .put_policy_revision("p3", "sandbox-1", 3, &payload, "h3") .await .unwrap(); @@ -459,16 +480,17 @@ async fn policy_list_ordered_by_version_desc() { .await .unwrap(); + let payload = SandboxPolicy::default().encode_to_vec(); store - .put_policy_revision("p1", "sandbox-1", 1, b"v1", "h1") + .put_policy_revision("p1", "sandbox-1", 1, &payload, "h1") .await .unwrap(); store - .put_policy_revision("p2", "sandbox-1", 2, b"v2", "h2") + .put_policy_revision("p2", "sandbox-1", 2, &payload, "h2") .await .unwrap(); store - .put_policy_revision("p3", "sandbox-1", 3, b"v3", "h3") + .put_policy_revision("p3", "sandbox-1", 3, &payload, "h3") .await .unwrap(); @@ -491,18 +513,24 @@ async fn policy_isolation_between_sandboxes() { .await .unwrap(); + let policy_s1 = SandboxPolicy::default().encode_to_vec(); + let policy_s2 = SandboxPolicy { + version: 7, + ..SandboxPolicy::default() + } + .encode_to_vec(); store - .put_policy_revision("p1", "sandbox-1", 1, b"v1", "h1") + .put_policy_revision("p1", "sandbox-1", 1, &policy_s1, "h1") .await .unwrap(); store - .put_policy_revision("p2", "sandbox-2", 1, b"v1-s2", "h2") + .put_policy_revision("p2", "sandbox-2", 1, &policy_s2, "h2") .await .unwrap(); let s1 = store.get_latest_policy("sandbox-1").await.unwrap().unwrap(); let s2 = store.get_latest_policy("sandbox-2").await.unwrap().unwrap(); - assert_eq!(s1.policy_payload, b"v1"); - assert_eq!(s2.policy_payload, b"v1-s2"); + assert_eq!(s1.policy_payload, policy_s1); + assert_eq!(s2.policy_payload, policy_s2); } diff --git a/e2e/python/test_inference_routing.py b/e2e/python/test_inference_routing.py index c9d3965f1..7f0c0dbd4 100644 --- a/e2e/python/test_inference_routing.py +++ b/e2e/python/test_inference_routing.py @@ -126,6 +126,9 @@ def _restore_cluster_inference( inference_client.set_cluster( provider_name=previous.provider_name, model_id=previous.model_id, + # Teardown restores prior shared state as-is, even if the previous + # route is intentionally unreachable or no longer verifiable. + no_verify=True, ) diff --git a/e2e/rust/tests/host_gateway_alias.rs b/e2e/rust/tests/host_gateway_alias.rs index fe6a15e74..083f5e6ee 100644 --- a/e2e/rust/tests/host_gateway_alias.rs +++ b/e2e/rust/tests/host_gateway_alias.rs @@ -396,8 +396,12 @@ async fn inference_set_supports_no_verify_for_unreachable_endpoint() { verify_err.contains("failed to verify inference endpoint"), "expected verification failure output:\n{verify_err}" ); + let normalized_verify_err: String = verify_err + .chars() + .filter(|c| !c.is_whitespace() && *c != '│') + .collect(); assert!( - verify_err.contains("--no-verify"), + normalized_verify_err.contains("--no-verify"), "expected retry hint in failure output:\n{verify_err}" ); diff --git a/proto/openshell.proto b/proto/openshell.proto index 2434f1a80..adc8c78cd 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -1141,3 +1141,39 @@ message GetDraftHistoryResponse { // Chronological decision history. repeated DraftHistoryEntry entries = 1; } + +// Stored payload for a policy revision row in the generic objects table. +message PolicyRevisionPayload { + // Serialized policy contents. + openshell.sandbox.v1.SandboxPolicy policy = 1; + // Deterministic hash of the policy payload. + string hash = 2; + // Load error reported by the sandbox, if any. + string load_error = 3; + // When the policy version was reported as loaded (ms since epoch). 0 if unset. + int64 loaded_at_ms = 4; +} + +// Stored payload for a draft policy chunk row in the generic objects table. +message DraftChunkPayload { + // Proposed network_policies map key. + string rule_name = 1; + // Proposed network policy rule. + openshell.sandbox.v1.NetworkPolicyRule proposed_rule = 2; + // Human-readable explanation of why this rule is proposed. + string rationale = 3; + // Security concerns flagged by analysis (empty if none). + string security_notes = 4; + // Analysis confidence (0.0-1.0). 0 for mechanistic mode. + float confidence = 5; + // When the user approved/rejected (ms since epoch). 0 if undecided. + int64 decided_at_ms = 6; + // Denormalized endpoint host for dedup and display. + string host = 7; + // Denormalized endpoint port for dedup and display. + int32 port = 8; + // Binary path that triggered the denial. + string binary = 9; + // Current draft version for the owning sandbox. + int64 draft_version = 10; +}