diff --git a/src/google/adk_community/plugins/__init__.py b/src/google/adk_community/plugins/__init__.py
index ab61116..2e4b2ee 100644
--- a/src/google/adk_community/plugins/__init__.py
+++ b/src/google/adk_community/plugins/__init__.py
@@ -15,5 +15,25 @@
 from google.adk_community.plugins.agent_governance_plugin import (
     AgentGovernancePlugin,
 )
+from google.adk_community.plugins.taxonomy import (
+    DefaultKeywordResolver,
+    DefaultSkillPolicy,
+    SkillPolicy,
+    TaxonomyPipeline,
+    TaxonomyPlugin,
+    TaxonomyRegistry,
+    TaxonomyResolver,
+    TaxonomyTerm,
+)
 
-__all__ = ["AgentGovernancePlugin"]
+__all__ = [
+    "AgentGovernancePlugin",
+    "DefaultKeywordResolver",
+    "DefaultSkillPolicy",
+    "SkillPolicy",
+    "TaxonomyPipeline",
+    "TaxonomyPlugin",
+    "TaxonomyRegistry",
+    "TaxonomyResolver",
+    "TaxonomyTerm",
+]
diff --git a/src/google/adk_community/plugins/taxonomy/__init__.py b/src/google/adk_community/plugins/taxonomy/__init__.py
new file mode 100644
index 0000000..a745b61
--- /dev/null
+++ b/src/google/adk_community/plugins/taxonomy/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Pluggable Policy & Taxonomy Security Engine for ADK Community."""
+
+from .policy import DefaultKeywordResolver
+from .policy import DefaultSkillPolicy
+from .policy import SkillPolicy
+from .policy import TaxonomyPipeline
+from .policy import TaxonomyResolver
+from .taxonomy_config import TaxonomyRegistry
+from .taxonomy_config import TaxonomyTerm
+from .taxonomy_plugin import TaxonomyPlugin
+
+__all__ = [
+    "DefaultKeywordResolver",
+    "DefaultSkillPolicy",
+    "SkillPolicy",
+    "TaxonomyPipeline",
+    "TaxonomyPlugin",
+    "TaxonomyRegistry",
+    "TaxonomyResolver",
+    "TaxonomyTerm",
+]
diff --git a/src/google/adk_community/plugins/taxonomy/policy.py b/src/google/adk_community/plugins/taxonomy/policy.py
new file mode 100644
index 0000000..3a392c9
--- /dev/null
+++ b/src/google/adk_community/plugins/taxonomy/policy.py
@@ -0,0 +1,405 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Abstract interfaces for taxonomy resolution and skill policy enforcement."""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+import logging
+import re
+from typing import Any, Optional
+
+from google.adk.agents.readonly_context import ReadonlyContext
+from google.adk.models.llm_request import LlmRequest
+from google.adk.skills.models import Frontmatter
+from google.adk.skills.models import Skill
+
+logger = logging.getLogger("google_adk_community." + __name__)
+
+# Matches ``{taxonomy:<name>}`` placeholders in instructions and descriptions.
+_TAXONOMY_VARIABLE_PATTERN = re.compile(r"\{taxonomy:([a-zA-Z0-9_-]+)\}")
+
+
+class TaxonomyResolver(ABC):
+  """Abstract base class for taxonomy resolution.
+
+  Resolvers analyze context and LLM history to determine which taxonomy
+  classification domains (e.g. URI strings) are currently active and relevant.
+  """
+
+  @abstractmethod
+  async def resolve_taxonomies(
+      self, context: ReadonlyContext, llm_request: LlmRequest
+  ) -> list[str]:
+    """Resolves active taxonomy domain URIs from context and LLM history.
+
+    Args:
+      context: The current read-only execution context.
+      llm_request: The upcoming LLM request holding prompt configurations.
+
+    Returns:
+      A list of resolved active taxonomy strings/URIs.
+    """
+    pass
+
+
+class TaxonomyPipeline(TaxonomyResolver):
+  """Executes a sequence of taxonomy resolvers in order (multi-step pipeline).
+
+  This implements a composite/pipeline pattern to merge active taxonomy domains
+  identified by multiple independent heuristics (e.g. lexical, model-based).
+  """
+
+  def __init__(self, resolvers: list[TaxonomyResolver]):
+    self.resolvers = resolvers
+
+  async def resolve_taxonomies(
+      self, context: ReadonlyContext, llm_request: LlmRequest
+  ) -> list[str]:
+    """Runs every resolver and merges the results without duplicates.
+
+    Args:
+      context: The current read-only execution context.
+      llm_request: The upcoming LLM request handed to each resolver.
+
+    Returns:
+      The unique taxonomy ids reported by the resolvers, in first-seen order.
+    """
+    # A dict keeps insertion order, so the merged list is deterministic; a set
+    # would reorder ids between runs because of string hash randomization.
+    active_domains: dict[str, None] = {}
+    for resolver in self.resolvers:
+      domains = await resolver.resolve_taxonomies(context, llm_request)
+      if domains:
+        active_domains.update(dict.fromkeys(domains))
+    return list(active_domains)
+
+
+def _collect_request_texts(llm_request: LlmRequest) -> list[str]:
+  """Returns the upper-cased text of every non-empty text part in the request."""
+  texts = []
+  for turn in llm_request.contents or []:
+    # A turn may carry no parts at all; treat that as an empty list.
+    for part in getattr(turn, "parts", None) or []:
+      text = getattr(part, "text", None)
+      if text:
+        texts.append(text.upper())
+  return texts
+
+
+def _get_trigger_phrases(term: Any) -> list[str]:
+  """Returns the upper-cased, non-empty trigger phrases of a taxonomy term.
+
+  Reads ``term.triggers`` first, then a ``triggers`` entry in the pydantic
+  ``model_extra`` dict, and finally falls back to ``term.alt_labels``.
+  """
+  triggers = getattr(term, "triggers", None)
+  if not triggers:
+    triggers = (getattr(term, "model_extra", None) or {}).get("triggers")
+  if not triggers:
+    triggers = getattr(term, "alt_labels", None)
+  # Empty phrases are dropped: "" is a substring of every text and would
+  # otherwise activate the term unconditionally.
+  return [str(p).upper() for p in (triggers or []) if str(p)]
+
+
+class DefaultKeywordResolver(TaxonomyResolver):
+  """Declarative, configuration-driven keyword/phrase resolver.
+
+  Scans the text parts of the LLM request for the phrases listed in each
+  taxonomy term's ``triggers`` (or ``alt_labels`` as a fallback) and activates
+  every term whose phrase occurs. Matching is a case-insensitive substring
+  test.
+  """
+
+  def __init__(self, registry: Any):
+    self.registry = registry
+
+  async def resolve_taxonomies(
+      self, context: ReadonlyContext, llm_request: LlmRequest
+  ) -> list[str]:
+    """Returns the ids of terms whose trigger phrases occur in the request.
+
+    Args:
+      context: The current read-only execution context (not inspected).
+      llm_request: The upcoming LLM request whose text parts are scanned.
+
+    Returns:
+      The matching term ids, in the order the registry lists them.
+    """
+    texts = _collect_request_texts(llm_request)
+    if not texts:
+      return []
+
+    active_domains: dict[str, None] = {}
+    for term_id in self.registry.list_ids():
+      term = self.registry.get_term(term_id)
+      if not term:
+        continue
+      phrases = _get_trigger_phrases(term)
+      if any(phrase in text for text in texts for phrase in phrases):
+        active_domains[term_id] = None
+    return list(active_domains)
+
+
+class SkillPolicy(ABC):
+  """Abstract policy engine determining skill execution permissions and instruction shaping.
+
+  This class defines the interface for two main responsibilities:
+  1. Access Control (Authorization): Blocking or permitting skills based on active taxonomies.
+  2. Cognitive Steering (Behavioral Shaping): Altering skill instructions, descriptions,
+     prioritization, and global system prompts to steer agent execution dynamically.
+
+  Implements the Hook Method pattern, providing concrete default pass-throughs
+  for steering while keeping authorization and core shaping abstract.
+  """
+
+  registry: Optional[Any] = None
+
+  @abstractmethod
+  def is_skill_allowed(
+      self,
+      skill: Skill,
+      context: ReadonlyContext,
+      active_taxonomies: list[str],
+  ) -> bool:
+    """Determines if a skill can be loaded/used under the active taxonomies and context.
+
+    Args:
+      skill: The target Skill model instance.
+      context: The read-only interaction context.
+      active_taxonomies: The list of currently active taxonomy domains.
+
+    Returns:
+      True if the skill is permitted to run, False otherwise.
+    """
+    pass
+
+  @abstractmethod
+  def shape_instructions(
+      self,
+      skill: Skill,
+      context: ReadonlyContext,
+      original_instructions: str,
+  ) -> str:
+    """Applies dynamic instruction shaping/guardrails to a skill's instructions.
+
+    Use this to append safety restrictions, enforce compliance constraints,
+    or adjust operating parameters of a skill before execution.
+    """
+    pass
+
+  def shape_description(
+      self,
+      skill: Skill,
+      context: ReadonlyContext,
+      original_description: str,
+  ) -> str:
+    """Applies dynamic description shaping before the tool reaches the agent.
+
+    This can be used to emphasize specific features of a skill to the LLM or
+    prune redundant information to fit within context limits.
+    """
+    return original_description
+
+  def shape_system_instruction(
+      self,
+      context: ReadonlyContext,
+      active_taxonomies: list[str],
+      original_instructions: str,
+  ) -> str:
+    """Applies dynamic instruction shaping to the global agent system instructions.
+
+    Use this to dynamically inject directives (e.g. telling the LLM to trigger
+    certain tools almost by default or prioritize specific workflows) depending
+    on the current active taxonomy classification.
+    """
+    return original_instructions
+
+  def prioritize_skills(
+      self,
+      skills: list[Skill],
+      context: ReadonlyContext,
+      active_taxonomies: list[str],
+  ) -> list[Skill]:
+    """Prioritizes, reorders, or accentuates skills under the active taxonomy.
+
+    Allows the policy to sort key tools to the top of the available_skills XML list
+    presented in the prompt, encouraging the LLM to select preferred actions.
+    """
+    return skills
+
+  def shape_skill(
+      self,
+      skill: Skill,
+      context: ReadonlyContext,
+      shaped_description: Optional[str],
+  ) -> Skill:
+    """Prepares and shapes a skill representation for presentation to the agent.
+
+    Builds a new ``Skill`` from the original name, the shaped description, any
+    extra frontmatter fields and the original instructions; other declared
+    frontmatter fields are not copied. Custom policies can override this to
+    use ``model_copy()`` or other strategies.
+
+    Raises:
+      ValueError: If ``skill`` is None.
+    """
+    # Explicit check instead of ``assert`` so it survives ``python -O``.
+    if skill is None:
+      raise ValueError("Skill instance cannot be None")
+
+    extra = getattr(skill.frontmatter, "model_extra", None) or {}
+    return Skill(
+        frontmatter=Frontmatter(
+            name=skill.frontmatter.name,
+            description=shaped_description,
+            **extra,
+        ),
+        instructions=skill.instructions,
+    )
+
+
+def _get_taxonomy_binds(skill: Skill) -> list[str]:
+  """Extracts the taxonomy binds declared in a skill's frontmatter.
+
+  Reads the ``taxonomy_binds`` attribute when the frontmatter exposes one and
+  otherwise falls back to the ``taxonomy-binds`` / ``taxonomy_binds`` keys of
+  the pydantic ``model_extra`` dict. Whatever the source, the value is
+  normalized so callers always receive a fresh list.
+
+  Args:
+    skill: The skill whose frontmatter is inspected.
+
+  Returns:
+    The bound taxonomy ids; empty when the skill declares none.
+  """
+  frontmatter = skill.frontmatter
+  if hasattr(frontmatter, "taxonomy_binds"):
+    binds = frontmatter.taxonomy_binds
+  else:
+    extra = getattr(frontmatter, "model_extra", None) or {}
+    binds = extra.get("taxonomy-binds") or extra.get("taxonomy_binds")
+
+  if not binds:
+    return []
+  # A lone string must stay one bind; ``set("eng")`` downstream would
+  # otherwise split it into single characters.
+  if isinstance(binds, str):
+    return [binds]
+  return list(binds)
+
+
+def _interpolate_variables(
+    text: str, active_taxonomies: list[str], registry: Optional[Any]
+) -> str:
+  """Replaces ``{taxonomy:<name>}`` placeholders with taxonomy term variables.
+
+  Each placeholder takes the value of ``<name>`` from the first active
+  taxonomy term that defines it. Unresolved placeholders are replaced with an
+  empty string and reported with a warning.
+
+  Args:
+    text: The text that may contain placeholders.
+    active_taxonomies: Ids of the currently active taxonomy terms.
+    registry: Object exposing ``get_term(term_id)``; when it is None (or
+      ``text`` is empty) the text is returned untouched.
+
+  Returns:
+    The text with every placeholder substituted.
+  """
+  if not text or not registry:
+    return text
+
+  def replace(match: re.Match) -> str:
+    var_name = match.group(1)
+    for tax_id in active_taxonomies:
+      term = registry.get_term(tax_id)
+      if not term:
+        continue
+      variables = getattr(term, "variables", None)
+      if not variables:
+        variables = (getattr(term, "model_extra", None) or {}).get("variables")
+      if variables and var_name in variables:
+        return str(variables[var_name])
+
+    logger.warning(
+        "Taxonomy variable %r not found under active taxonomies: %s",
+        var_name,
+        active_taxonomies,
+    )
+    return ""
+
+  return _TAXONOMY_VARIABLE_PATTERN.sub(replace, text)
+
+
+class DefaultSkillPolicy(SkillPolicy):
+  """Default skill policy using taxonomy-bind set-intersection matching.
+
+  If a skill has no taxonomy binds defined, it is treated as unrestricted/allowed by default.
+  If it has binds, at least one bind must intersect with the active taxonomy set.
+  """
+
+  def __init__(self, registry: Optional[Any] = None):
+    self.registry = registry
+
+  def is_skill_allowed(
+      self,
+      skill: Skill,
+      context: ReadonlyContext,
+      active_taxonomies: list[str],
+  ) -> bool:
+    """Allows unbound skills and skills bound to an active taxonomy."""
+    binds = _get_taxonomy_binds(skill)
+    # Unrestricted skills are always allowed
+    if not binds:
+      return True
+    # Require at least one matching taxonomy between active set and skill binds
+    return not set(binds).isdisjoint(active_taxonomies or ())
+
+  def shape_instructions(
+      self,
+      skill: Skill,
+      context: ReadonlyContext,
+      original_instructions: str,
+  ) -> str:
+    """Interpolates taxonomy variables into the skill's instructions."""
+    active_taxonomies = context.state.get("_active_taxonomies") or []
+    return _interpolate_variables(
+        original_instructions, active_taxonomies, self.registry
+    )
+
+  def shape_description(
+      self,
+      skill: Skill,
+      context: ReadonlyContext,
+      original_description: str,
+  ) -> str:
+    """Interpolates taxonomy variables into the skill's description."""
+    active_taxonomies = context.state.get("_active_taxonomies") or []
+    return _interpolate_variables(
+        original_description, active_taxonomies, self.registry
+    )
+
+  def shape_system_instruction(
+      self,
+      context: ReadonlyContext,
+      active_taxonomies: list[str],
+      original_instructions: str,
+  ) -> str:
+    """Interpolates taxonomy variables into the global system instruction."""
+    return _interpolate_variables(
+        original_instructions, active_taxonomies, self.registry
+    )
diff --git a/src/google/adk_community/plugins/taxonomy/taxonomy_config.py b/src/google/adk_community/plugins/taxonomy/taxonomy_config.py
new file mode 100644
index 0000000..e82cddd
--- /dev/null
+++ b/src/google/adk_community/plugins/taxonomy/taxonomy_config.py
@@ -0,0 +1,136 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Pydantic models for taxonomy configuration parsing."""
+
+from __future__ import annotations
+
+from typing import Any
+from typing import Optional
+
+from pydantic import BaseModel
+from pydantic import ConfigDict
+from pydantic import Field
+
+
+def _literal_value(raw: Any) -> Optional[str]:
+  """Returns the text of a JSON-LD literal.
+
+  Accepts either a plain string or a ``{"@value": ...}`` dict; anything else,
+  including an empty string, yields None.
+  """
+  if isinstance(raw, dict):
+    raw = raw.get("@value")
+  return raw if isinstance(raw, str) and raw else None
+
+
+class TaxonomyTerm(BaseModel):
+  """A single taxonomy term with metadata for validation and LLM disambiguation.
+
+  Attributes:
+    id: Identifier of the term; the registry uses it as the lookup key.
+    parent_id: Identifier of the parent term, if any (alias ``parentId``).
+    name: Name of the term.
+    definition: Optional definition text.
+    alt_labels: Alternative labels (alias ``altLabels``).
+    variables: String-valued variables attached to the term.
+    triggers: Trigger phrases attached to the term.
+  """
+
+  model_config = ConfigDict(populate_by_name=True)
+
+  id: str
+  parent_id: Optional[str] = Field(None, alias="parentId")
+  name: str
+  definition: Optional[str] = None
+  alt_labels: list[str] = Field(default_factory=list, alias="altLabels")
+  variables: dict[str, str] = Field(default_factory=dict)
+  triggers: list[str] = Field(default_factory=list)
+
+
+class TaxonomyRegistry(BaseModel):
+  """Central registry for taxonomy term definitions.
+
+  Supported JSON Schemas:
+
+  **Flat Key-Value JSON** (``from_flat_json``):
+    id: str
+    parentId: Optional[str]
+    name: str
+    definition: Optional[str]
+
+  **JSON-LD with SKOS** (``from_json_ld``):
+    @context: str
+    @type: str
+    @id: str
+    prefLabel: dict (``{"@value": str, "@language": str}``) or str
+    altLabel: list of dict/str literals, or a single literal
+    definition: dict (``{"@value": str, "@language": str}``) or str
+    broader: Optional[str] or dict (``{"@id": str}``)
+  """
+
+  terms: dict[str, TaxonomyTerm] = Field(default_factory=dict)
+
+  @classmethod
+  def from_flat_json(cls, data: list[dict]) -> TaxonomyRegistry:
+    """Parse taxonomy terms from flat key-value JSON format."""
+    parsed = (TaxonomyTerm.model_validate(item) for item in data)
+    return cls(terms={term.id: term for term in parsed})
+
+  @classmethod
+  def from_json_ld(cls, data: list[dict]) -> TaxonomyRegistry:
+    """Parse JSON-LD SKOS format into TaxonomyRegistry.
+
+    Items without an ``@id`` are skipped. A later item with the same ``@id``
+    replaces an earlier one.
+    """
+    terms = {}
+    for item in data:
+      term_id = item.get("@id")
+      if not term_id:
+        continue
+
+      # altLabel may be a single literal; literals may be dicts or plain strs.
+      alt_labels_raw = item.get("altLabel") or []
+      if not isinstance(alt_labels_raw, list):
+        alt_labels_raw = [alt_labels_raw]
+      alt_labels = [v for v in map(_literal_value, alt_labels_raw) if v]
+
+      # ``broader`` may be a bare id or a JSON-LD node reference.
+      broader = item.get("broader")
+      if isinstance(broader, dict):
+        broader = broader.get("@id")
+
+      terms[term_id] = TaxonomyTerm(
+          id=term_id,
+          parent_id=broader,
+          name=_literal_value(item.get("prefLabel")) or "",
+          definition=_literal_value(item.get("definition")),
+          alt_labels=alt_labels,
+          variables=item.get("variables") or {},
+          triggers=item.get("triggers") or [],
+      )
+    return cls(terms=terms)
+
+  def get_term(self, term_id: str) -> Optional[TaxonomyTerm]:
+    """Lookup a term by its ID."""
+    return self.terms.get(term_id)
+
+  def get_children(self, parent_id: str) -> list[TaxonomyTerm]:
+    """Get all direct children of a term."""
+    return [t for t in self.terms.values() if t.parent_id == parent_id]
+
+  def list_ids(self) -> list[str]:
+    """List all term IDs in the registry."""
+    return list(self.terms.keys())
diff --git a/src/google/adk_community/plugins/taxonomy/taxonomy_plugin.py b/src/google/adk_community/plugins/taxonomy/taxonomy_plugin.py
new file mode 100644
index 0000000..7722982
--- /dev/null
+++ b/src/google/adk_community/plugins/taxonomy/taxonomy_plugin.py
@@ -0,0 +1,265 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TaxonomyPlugin — ADK BasePlugin for pluggable taxonomy policy enforcement."""
+
+from __future__ import annotations
+
+import logging
+from pathlib import PurePosixPath, PureWindowsPath
+from typing import Any, Optional
+
+from google.adk.plugins.base_plugin import BasePlugin
+from google.adk.agents.callback_context import CallbackContext
+from google.adk.models.llm_request import LlmRequest
+from google.adk.models.llm_response import LlmResponse
+from google.adk.skills import prompt
+from google.adk.tools.base_tool import BaseTool
+from google.adk.tools.tool_context import ToolContext
+
+from .policy import DefaultSkillPolicy
+from .policy import SkillPolicy
+from .policy import TaxonomyResolver
+from .taxonomy_config import TaxonomyRegistry
+
+logger = logging.getLogger("google_adk_community." + __name__)
+
+_ACTIVE_TAXONOMIES_STATE_KEY = "_active_taxonomies"
+
+_SKILL_GATE_TOOLS = frozenset({
+    "list_skills",
+    "load_skill",
+    "load_skill_resource",
+    "run_skill_script",
+})
+
+
+def _invalid_arguments(message: str) -> dict:
+  """Builds the error payload returned for rejected tool arguments."""
+  return {"error": message, "error_code": "INVALID_ARGUMENTS"}
+
+
+class TaxonomyPlugin(BasePlugin):
+  """Native ADK Plugin enforcing pluggable taxonomy policies."""
+
+  def __init__(
+      self,
+      name: str = "taxonomy_plugin",
+      *,
+      taxonomy_registry: Optional[TaxonomyRegistry] = None,
+      resolver: Optional[TaxonomyResolver] = None,
+      policy: Optional[SkillPolicy] = None,
+  ):
+    """Initializes the plugin.
+
+    Args:
+      name: Plugin name passed to ``BasePlugin``.
+      taxonomy_registry: Term registry; an empty one is created when omitted.
+      resolver: Taxonomy resolver; without one no policy is applied.
+      policy: Skill policy; defaults to ``DefaultSkillPolicy`` bound to the
+        registry. A policy whose ``registry`` is None receives the registry.
+    """
+    super().__init__(name)
+    self.taxonomy_registry = taxonomy_registry or TaxonomyRegistry()
+    self.resolver = resolver
+    self.policy = policy or DefaultSkillPolicy(self.taxonomy_registry)
+    if getattr(self.policy, "registry", None) is None:
+      try:
+        self.policy.registry = self.taxonomy_registry
+      except Exception:
+        # Best effort: the policy may not accept attribute assignment.
+        logger.debug("[%s] Policy rejected registry", name, exc_info=True)
+
+  async def before_model_callback(
+      self, *, callback_context: CallbackContext, llm_request: LlmRequest
+  ) -> Optional[LlmResponse]:
+    """Resolves active taxonomies into session state and shapes the system prompt."""
+    if not self.resolver:
+      return None
+
+    active_taxonomies = await self.resolver.resolve_taxonomies(
+        callback_context, llm_request
+    )
+    callback_context.state[_ACTIVE_TAXONOMIES_STATE_KEY] = active_taxonomies
+
+    logger.debug(
+        "[%s] Resolved active taxonomies: %s", self.name, active_taxonomies
+    )
+
+    if self.policy:
+      orig_instructions = llm_request.config.system_instruction or ""
+      shaped_instructions = self.policy.shape_system_instruction(
+          callback_context, active_taxonomies, orig_instructions
+      )
+      if shaped_instructions != orig_instructions:
+        logger.debug(
+            "[%s] Active taxonomy dynamic system prompt shaping applied.",
+            self.name,
+        )
+        llm_request.config.system_instruction = shaped_instructions
+
+    return None
+
+  async def before_tool_callback(
+      self,
+      *,
+      tool: BaseTool,
+      tool_args: dict[str, Any],
+      tool_context: ToolContext,
+  ) -> Optional[dict]:
+    """Intercepts skill tools to enforce taxonomy policy and path validation."""
+    if tool.name not in _SKILL_GATE_TOOLS:
+      return None
+
+    active_taxonomies = tool_context.state.get(_ACTIVE_TAXONOMIES_STATE_KEY) or []
+
+    if tool.name == "list_skills":
+      return self._filter_list_skills(tool, tool_context, active_taxonomies)
+
+    skill_name = tool_args.get("skill_name")
+    if not skill_name:
+      return None
+
+    # Inline path validation (avoids importing private _validate_path_segment).
+    # The isinstance check comes first: model-supplied values may not be strings.
+    if (
+        not isinstance(skill_name, str)
+        or "\x00" in skill_name
+        or "/" in skill_name
+        or "\\" in skill_name
+        or skill_name in (".", "..")
+    ):
+      return _invalid_arguments(f"Invalid skill_name parameter: {skill_name!r}")
+
+    file_path = tool_args.get("file_path")
+    if file_path:
+      if not isinstance(file_path, str) or "\x00" in file_path:
+        return _invalid_arguments(f"Invalid file_path parameter: {file_path!r}")
+
+      posix_p = PurePosixPath(file_path)
+      win_p = PureWindowsPath(file_path)
+
+      # Block absolute paths or presence of a drive letter
+      if posix_p.is_absolute() or win_p.is_absolute() or win_p.drive:
+        return _invalid_arguments(f"Absolute path blocked: {file_path}")
+
+      # Block traversal segments
+      if ".." in posix_p.parts or ".." in win_p.parts:
+        return _invalid_arguments(f"Path traversal attempt blocked: {file_path}")
+
+    if self.policy and self.resolver:
+      toolset = getattr(tool, "_toolset", None)
+      if toolset:
+        skill = await toolset._get_or_fetch_skill(
+            skill_name, tool_context.invocation_id
+        )
+        if skill and not self.policy.is_skill_allowed(
+            skill, tool_context, active_taxonomies
+        ):
+          logger.warning(
+              "[%s] Skill '%s' blocked by policy. Active taxonomies: %s",
+              self.name,
+              skill_name,
+              active_taxonomies,
+          )
+          return {
+              "error": (
+                  f"Access to skill '{skill_name}' is not permitted"
+                  " under active policy constraints."
+              ),
+              "error_code": "SKILL_NOT_PERMITTED",
+          }
+
+    return None
+
+  def _filter_list_skills(
+      self, tool: BaseTool, tool_context: ToolContext, active_taxonomies: list[str]
+  ) -> Optional[dict]:
+    """Filters the list_skills result to only show policy-permitted skills."""
+    if not self.policy or not self.resolver:
+      return None
+
+    toolset = getattr(tool, "_toolset", None)
+    if not toolset:
+      return None
+
+    all_skills = toolset._list_skills()
+    allowed_skills = [
+        skill
+        for skill in all_skills
+        if self.policy.is_skill_allowed(skill, tool_context, active_taxonomies)
+    ]
+
+    # Reorder and prioritize skills dynamically
+    prioritized_skills = self.policy.prioritize_skills(
+        allowed_skills, tool_context, active_taxonomies
+    )
+
+    shaped_skills = []
+    for skill in prioritized_skills:
+      original_desc = skill.frontmatter.description or ""
+      shaped_desc = self.policy.shape_description(skill, tool_context, original_desc)
+      new_skill = self.policy.shape_skill(skill, tool_context, shaped_desc)
+      shaped_skills.append(new_skill)
+
+    logger.debug(
+        "[%s] Filtered skills: %d/%d visible",
+        self.name,
+        len(shaped_skills),
+        len(all_skills),
+    )
+    return {"result": prompt.format_skills_as_xml(shaped_skills)}
+
+  async def after_tool_callback(
+      self,
+      *,
+      tool: BaseTool,
+      tool_args: dict[str, Any],
+      tool_context: ToolContext,
+      result: dict,
+  ) -> Optional[dict]:
+    """Applies dynamic instruction shaping to load_skill results."""
+    if tool.name != "load_skill":
+      return None
+    if not self.policy or not self.resolver:
+      return None
+    if not isinstance(result, dict) or "instructions" not in result:
+      return None
+
+    skill_name = tool_args.get("skill_name")
+    if not skill_name:
+      return None
+
+    toolset = getattr(tool, "_toolset", None)
+    if not toolset:
+      return None
+
+    skill = await toolset._get_or_fetch_skill(
+        skill_name, tool_context.invocation_id
+    )
+    if not skill:
+      return None
+
+    shaped_instructions = self.policy.shape_instructions(
+        skill, tool_context, result["instructions"]
+    )
+
+    if shaped_instructions != result["instructions"]:
+      logger.debug(
+          "[%s] Shaped instructions for skill '%s'", self.name, skill_name
+      )
+
+    shaped_result = dict(result)
+    shaped_result["instructions"] = shaped_instructions
+    return shaped_result
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..c4e4f3d
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,36 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from types import ModuleType
+
+# Ensure google.genai.types exposes AvatarConfig: patch it in when missing, stub the modules when the import fails.
+try:
+  import google.genai.types as genai_types
+  if not hasattr(genai_types, "AvatarConfig"):
+    from pydantic import BaseModel
+    class AvatarConfig(BaseModel):
+      pass
+    genai_types.AvatarConfig = AvatarConfig
+except Exception:
+  try:
+    sys.modules["google.genai"] = ModuleType("google.genai")
+    # The stub google.genai.types module below only provides an empty AvatarConfig model.
+    from pydantic import BaseModel
+    class AvatarConfig(BaseModel):
+      pass
+    genai_types = sys.modules["google.genai.types"] = ModuleType("google.genai.types")
+    genai_types.AvatarConfig = AvatarConfig
+  except Exception:
+    pass  # NOTE(review): any failure while installing the stubs is silently ignored.
diff --git a/tests/plugins/test_taxonomy_plugin.py b/tests/plugins/test_taxonomy_plugin.py
new file mode 100644
index 0000000..913baa7
--- /dev/null
+++ b/tests/plugins/test_taxonomy_plugin.py
@@ -0,0 +1,444 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for the Pluggable Policy & Taxonomy Security Engine in Community.
+
+This test suite covers taxonomy classification data loading formats, resolver aggregation,
+access-control authorization filtering, path validation/traversal prevention, and
+""" + +from unittest import mock +import pytest + +from google.adk_community.plugins.taxonomy import DefaultSkillPolicy +from google.adk_community.plugins.taxonomy import SkillPolicy +from google.adk_community.plugins.taxonomy import TaxonomyPipeline +from google.adk_community.plugins.taxonomy import TaxonomyPlugin +from google.adk_community.plugins.taxonomy import TaxonomyRegistry +from google.adk_community.plugins.taxonomy import TaxonomyResolver +from google.adk_community.plugins.taxonomy import TaxonomyTerm +from google.adk_community.plugins.taxonomy.policy import _get_taxonomy_binds +from google.adk.skills.models import Frontmatter +from google.adk.skills.models import Skill + + +def test_taxonomy_term(): + """Tests TaxonomyTerm model instantiation and defaults. + + Ensures taxonomy term instances hold core metadata and instantiate with standard + defaults (like empty alternate labels and no parents). + """ + term = TaxonomyTerm(id="tech", name="Technology", definition="Tech domain") + assert term.id == "tech" + assert term.name == "Technology" + assert term.definition == "Tech domain" + assert term.parent_id is None + assert term.alt_labels == [] + + +def test_registry_flat_json(): + """Tests parsing flat JSON structure into TaxonomyRegistry. + + Verifies that a plain list of objects defining IDs and parent IDs are correctly + loaded and indexed into hierarchical parent-child relationships. 
+ """ + data = [ + { + "id": "eng", + "parentId": None, + "name": "Engineering", + "definition": "Eng dept", + }, + { + "id": "ml", + "parentId": "eng", + "name": "Machine Learning", + "definition": "ML team", + }, + ] + registry = TaxonomyRegistry.from_flat_json(data) + assert len(registry.list_ids()) == 2 + assert "eng" in registry.list_ids() + assert "ml" in registry.list_ids() + + term_eng = registry.get_term("eng") + term_ml = registry.get_term("ml") + assert term_eng.name == "Engineering" + assert term_ml.parent_id == "eng" + + children = registry.get_children("eng") + assert len(children) == 1 + assert children[0].id == "ml" + + +def test_registry_json_ld(): + """Tests parsing JSON-LD SKOS structure into TaxonomyRegistry. + + Validates SKOS standard structure imports, including URI mapping, prefLabel + mapping, altLabel array conversions, and broader relation parsing. + """ + data = [ + { + "@context": "http://w3.org", + "@type": "Concept", + "@id": "https://example.com/eng", + "prefLabel": {"@value": "Engineering", "@language": "en"}, + "definition": {"@value": "Eng dept", "@language": "en"}, + }, + { + "@context": "http://w3.org", + "@type": "Concept", + "@id": "https://example.com/ml", + "prefLabel": {"@value": "Machine Learning", "@language": "en"}, + "altLabel": [{"@value": "ML", "@language": "en"}], + "definition": {"@value": "ML team", "@language": "en"}, + "broader": "https://example.com/eng", + }, + ] + registry = TaxonomyRegistry.from_json_ld(data) + assert len(registry.list_ids()) == 2 + + term_eng = registry.get_term("https://example.com/eng") + term_ml = registry.get_term("https://example.com/ml") + assert term_eng.name == "Engineering" + assert term_ml.parent_id == "https://example.com/eng" + assert term_ml.alt_labels == ["ML"] + + +@pytest.mark.asyncio +async def test_taxonomy_pipeline(): + """Tests pipeline resolution chaining multiple resolvers. 
+ + Ensures that the composite pipeline runs each individual resolver and merges + their outputs into a unique, aggregated active taxonomy list. + """ + + class SimpleResolver(TaxonomyResolver): + + def __init__(self, resolved_domains: list[str]): + self.resolved_domains = resolved_domains + + async def resolve_taxonomies(self, context, llm_request) -> list[str]: + return self.resolved_domains + + context = mock.MagicMock() + llm_request = mock.MagicMock() + + pipeline = TaxonomyPipeline([SimpleResolver(["eng"]), SimpleResolver(["finance"])]) + resolved = await pipeline.resolve_taxonomies(context, llm_request) + assert sorted(resolved) == ["eng", "finance"] + + +def test_default_skill_policy(): + """Tests DefaultSkillPolicy filter mechanism. + + Checks that the default intersection policy correctly authorizes matching skills, + blocks skills with non-overlapping binds, and always allows unrestricted skills. + """ + policy = DefaultSkillPolicy() + + skill_eng = Skill( + frontmatter=Frontmatter( + name="eng-skill", + description="Desc", + taxonomy_binds=["eng"], + ), + instructions="body", + ) + skill_finance = Skill( + frontmatter=Frontmatter( + name="finance-skill", + description="Desc", + taxonomy_binds=["finance"], + ), + instructions="body", + ) + + context = mock.MagicMock() + assert policy.is_skill_allowed(skill_eng, context, ["eng"]) is True + assert policy.is_skill_allowed(skill_finance, context, ["eng"]) is False + assert policy.is_skill_allowed(skill_finance, context, ["eng", "finance"]) is True + + skill_unrestricted = Skill( + frontmatter=Frontmatter(name="any-skill", description="Desc"), + instructions="body", + ) + assert policy.is_skill_allowed(skill_unrestricted, context, ["marketing"]) is True + + assert policy.shape_instructions(skill_eng, context, "original") == "original" + + +@pytest.mark.asyncio +async def test_taxonomy_plugin_list_skills(): + """Tests TaxonomyPlugin intercepts and filters skill lists correctly. 
+ + Verifies that list_skills tool calls are intercepted in before_tool_callback + and that the return payload contains only the policy-allowed skills in valid XML format. + """ + + class RestrictedPolicy(SkillPolicy): + + def is_skill_allowed(self, skill: Skill, context, active_taxonomies: list[str]) -> bool: + binds = _get_taxonomy_binds(skill) + return "eng" in binds + + def shape_instructions(self, skill: Skill, context, original_instructions: str) -> str: + return original_instructions + + mock_resolver = mock.MagicMock() + plugin = TaxonomyPlugin(policy=RestrictedPolicy(), resolver=mock_resolver) + + skills = { + "skill-1": Skill( + frontmatter=Frontmatter( + name="skill-1", + description="Desc", + taxonomy_binds=["eng"], + ), + instructions="body", + ), + "skill-2": Skill( + frontmatter=Frontmatter( + name="skill-2", + description="Desc", + taxonomy_binds=["finance"], + ), + instructions="body", + ), + } + + context = mock.MagicMock() + context.state = {"_active_taxonomies": ["eng"]} + + mock_tool = mock.MagicMock() + mock_tool.name = "list_skills" + mock_tool._toolset._list_skills.return_value = list(skills.values()) + + # Patch XML formatter to focus purely on verifying taxonomy filtration behavior + with mock.patch("google.adk_community.plugins.taxonomy.taxonomy_plugin.prompt.format_skills_as_xml") as mock_format: + mock_format.return_value = "" + + result = await plugin.before_tool_callback( + tool=mock_tool, + tool_args={}, + tool_context=context, + ) + + assert isinstance(result, dict) + assert "result" in result + assert "skill-1" in result["result"] + assert "skill-2" not in result["result"] + + +@pytest.mark.asyncio +async def test_taxonomy_steering_capabilities(): + """Tests prioritizing/sorting skills and injecting global system prompts. + + Verifies cognitive steering hooks: + 1. System Instruction Shaping (injecting dynamic instructions into LLM system prompts). + 2. Skill Prioritization (reordering skills in list_skills results). 
+ """ + + class SteeringPolicy(SkillPolicy): + + def is_skill_allowed(self, skill: Skill, context, active_taxonomies: list[str]) -> bool: + return True + + def shape_instructions(self, skill: Skill, context, original_instructions: str) -> str: + return original_instructions + + def shape_system_instruction(self, context, active_taxonomies: list[str], original_instructions: str) -> str: + if "strict" in active_taxonomies: + return original_instructions + " - MANDATED COMPLIANCE TURN" + return original_instructions + + def prioritize_skills(self, skills: list[Skill], context, active_taxonomies: list[str]) -> list[Skill]: + if "strict" in active_taxonomies: + return sorted(skills, key=lambda s: 0 if s.frontmatter.name == "important" else 1) + return skills + + class MockResolver(TaxonomyResolver): + async def resolve_taxonomies(self, context, llm_request) -> list[str]: + return ["strict"] + + plugin = TaxonomyPlugin(policy=SteeringPolicy(), resolver=MockResolver()) + + # Verify before_model_callback system instruction injection + context = mock.MagicMock() + context.state = {} + llm_request = mock.MagicMock() + llm_request.config = mock.MagicMock() + llm_request.config.system_instruction = "Original Prompt" + + await plugin.before_model_callback(callback_context=context, llm_request=llm_request) + assert context.state["_active_taxonomies"] == ["strict"] + assert llm_request.config.system_instruction == "Original Prompt - MANDATED COMPLIANCE TURN" + + # Verify skill prioritization/sorting in list_skills + skills = [ + Skill(frontmatter=Frontmatter(name="normal", description="Desc"), instructions="body"), + Skill(frontmatter=Frontmatter(name="important", description="Desc"), instructions="body"), + ] + + mock_tool = mock.MagicMock() + mock_tool.name = "list_skills" + mock_tool._toolset._list_skills.return_value = skills + + with mock.patch("google.adk_community.plugins.taxonomy.taxonomy_plugin.prompt.format_skills_as_xml") as mock_format: + await 
plugin.before_tool_callback( + tool=mock_tool, + tool_args={}, + tool_context=context, + ) + # Check that format_skills_as_xml was called with "important" sorted first + called_skills = mock_format.call_args[0][0] + assert called_skills[0].frontmatter.name == "important" + assert called_skills[1].frontmatter.name == "normal" + + +@pytest.mark.asyncio +async def test_taxonomy_variable_interpolation(): + """Tests that DefaultSkillPolicy correctly interpolates taxonomy variables.""" + taxonomy_data = [ + { + "id": "urn:adk:domain:finance", + "name": "Strict Finance", + "variables": { + "warning": "[PII WARNING]", + "guardrail": "Mask SSN" + } + } + ] + registry = TaxonomyRegistry.from_flat_json(taxonomy_data) + policy = DefaultSkillPolicy(registry) + + skill = Skill( + frontmatter=Frontmatter( + name="audit", + description="Read accounts. {taxonomy:warning}", + taxonomy_binds=["urn:adk:domain:finance"] + ), + instructions="Fetch logs.\n{taxonomy:guardrail}" + ) + + context = mock.MagicMock() + context.state = {"_active_taxonomies": ["urn:adk:domain:finance"]} + + # Test shape_description + desc = policy.shape_description(skill, context, skill.frontmatter.description) + assert desc == "Read accounts. [PII WARNING]" + + # Test shape_instructions + inst = policy.shape_instructions(skill, context, skill.instructions) + assert inst == "Fetch logs.\nMask SSN" + + # Test shape_system_instruction + sys_inst = policy.shape_system_instruction(context, ["urn:adk:domain:finance"], "Start. {taxonomy:warning}") + assert sys_inst == "Start. 
[PII WARNING]" + + +@pytest.mark.asyncio +async def test_taxonomy_plugin_path_validation(): + """Tests that absolute paths and path traversals are blocked on all platforms.""" + plugin = TaxonomyPlugin() + mock_tool = mock.MagicMock() + mock_tool.name = "load_skill" + context = mock.MagicMock() + context.state = {} + + # Test cases for blocked paths (absolute or traversal) + blocked_cases = [ + "/etc/passwd", + "C:\\Windows\\System32", + "\\\\unc\\share\\file", + "../../traversal", + "subdir\\..\\parent", + ] + + for file_path in blocked_cases: + result = await plugin.before_tool_callback( + tool=mock_tool, + tool_args={"skill_name": "test-skill", "file_path": file_path}, + tool_context=context, + ) + assert isinstance(result, dict) + assert result.get("error_code") == "INVALID_ARGUMENTS" + assert "blocked" in result.get("error", "").lower() + + +@pytest.mark.asyncio +async def test_taxonomy_variable_interpolation_warning(caplog): + """Tests that unresolved taxonomy variables log a warning and fallback to empty string.""" + registry = TaxonomyRegistry(terms={}) + policy = DefaultSkillPolicy(registry) + context = mock.MagicMock() + context.state = {"_active_taxonomies": ["urn:adk:domain:test"]} + + with caplog.at_level("WARNING"): + result = policy.shape_system_instruction( + context, ["urn:adk:domain:test"], "Prompt: {taxonomy:missing_variable}" + ) + + assert result == "Prompt: " + assert len(caplog.records) == 1 + assert "missing_variable" in caplog.records[0].message + + +def test_taxonomy_custom_shape_skill(): + """Tests default sanitization and custom policy shape_skill overriding.""" + + from pydantic import Field + + class ExtendedFrontmatter(Frontmatter): + billing_entitlement: str = Field("premium", alias="billingEntitlement") + custom_flag: bool = True + + original_skill = Skill( + frontmatter=ExtendedFrontmatter( + name="custom-skill", + description="My skill", + billingEntitlement="enterprise", + ), + instructions="Execute tasks" + ) + + policy = 
DefaultSkillPolicy() + context = mock.MagicMock() + + # It should drop the custom pydantic field billing_entitlement because it's not captured by standard Frontmatter + default_shaped = policy.shape_skill(original_skill, context, "Shaped My skill") + assert default_shaped.frontmatter.name == "custom-skill" + assert default_shaped.frontmatter.description == "Shaped My skill" + # Standard Frontmatter doesn't have custom_flag or billing_entitlement as defined properties + assert not hasattr(default_shaped.frontmatter, "custom_flag") + assert not hasattr(default_shaped.frontmatter, "billing_entitlement") + + # Verify Custom Policy Behavior using model_copy + class CustomCopyPolicy(DefaultSkillPolicy): + def shape_skill(self, skill, context, shaped_description): + new_fm = skill.frontmatter.model_copy(update={"description": shaped_description}) + return skill.model_copy(update={"frontmatter": new_fm}) + + custom_policy = CustomCopyPolicy() + custom_shaped = custom_policy.shape_skill(original_skill, context, "Shaped My skill") + + # Ensure all custom attributes, types, and values are fully preserved! + assert custom_shaped.frontmatter.name == "custom-skill" + assert custom_shaped.frontmatter.description == "Shaped My skill" + assert isinstance(custom_shaped.frontmatter, ExtendedFrontmatter) + assert custom_shaped.frontmatter.billing_entitlement == "enterprise" + assert custom_shaped.frontmatter.custom_flag is True + +