diff --git a/vertexai/agent_engines/templates/langchain.py b/vertexai/agent_engines/templates/langchain.py index 31dabcbdd3..4bcf40f55b 100644 --- a/vertexai/agent_engines/templates/langchain.py +++ b/vertexai/agent_engines/templates/langchain.py @@ -43,11 +43,12 @@ RunnableSerializable = Any try: - from langchain_google_vertexai.functions_utils import _ToolsType - - _ToolsType = _ToolsType + from langchain_google_genai.functions_utils import _ToolsType except ImportError: - _ToolsType = Any + try: + from langchain_google_vertexai.functions_utils import _ToolsType + except ImportError: + _ToolsType = Any try: from opentelemetry.sdk import trace @@ -81,13 +82,15 @@ def _default_runnable_kwargs(has_history: bool) -> Mapping[str, Any]: def _default_output_parser(): try: - from langchain.agents.output_parsers.tools import ToolsAgentOutputParser + from langchain_classic.agents.output_parsers.tools import ToolsAgentOutputParser except (ModuleNotFoundError, ImportError): - # Fallback to an older version if needed. - from langchain.agents.output_parsers.openai_tools import ( - OpenAIToolsAgentOutputParser as ToolsAgentOutputParser, - ) - + try: + from langchain.agents.output_parsers.tools import ToolsAgentOutputParser + except (ModuleNotFoundError, ImportError): + # Fallback to an older version if needed. + from langchain.agents.output_parsers.openai_tools import ( + OpenAIToolsAgentOutputParser as ToolsAgentOutputParser, + ) return ToolsAgentOutputParser() @@ -98,17 +101,29 @@ def _default_model_builder( location: str, model_kwargs: Optional[Mapping[str, Any]] = None, ) -> "BaseLanguageModel": - import vertexai - from google.cloud.aiplatform import initializer - from langchain_google_vertexai import ChatVertexAI - model_kwargs = model_kwargs or {} - current_project = initializer.global_config.project - current_location = initializer.global_config.location - vertexai.init(project=project, location=location) - model = ChatVertexAI(model_name=model_name, **model_kwargs) - vertexai.init(project=current_project, location=current_location) - return model + try: + from langchain_google_genai import ChatGoogleGenerativeAI + + model = ChatGoogleGenerativeAI( + model=model_name, + project=project, + location=location, + vertexai=True, + **model_kwargs, + ) + return model + except ImportError: + import vertexai + from google.cloud.aiplatform import initializer + from langchain_google_vertexai import ChatVertexAI + + current_project = initializer.global_config.project + current_location = initializer.global_config.location + vertexai.init(project=project, location=location) + model = ChatVertexAI(model_name=model_name, **model_kwargs) + vertexai.init(project=current_project, location=current_location) + return model def _default_runnable_builder( @@ -124,8 +139,16 @@ def _default_runnable_builder( runnable_kwargs: Optional[Mapping[str, Any]] = None, ) -> "RunnableSerializable": from langchain_core import tools as lc_tools - from langchain.agents import AgentExecutor - from langchain.tools.base import StructuredTool + + try: + from langchain_classic.agents import AgentExecutor + except ImportError: + from langchain.agents import AgentExecutor + + try: + from langchain_core.tools import StructuredTool + except ImportError: + from langchain.tools.base import StructuredTool # The prompt template and runnable_kwargs needs to be customized depending # on whether the user intends for the agent to have history. The way the @@ -261,12 +284,16 @@ def _default_prompt( from langchain_core import prompts try: - from langchain.agents.format_scratchpad.tools import format_to_tool_messages - except (ModuleNotFoundError, ImportError): - # Fallback to an older version if needed. - from langchain.agents.format_scratchpad.openai_tools import ( - format_to_openai_tool_messages as format_to_tool_messages, + from langchain_classic.agents.format_scratchpad.tools import ( + format_to_tool_messages, ) + except (ModuleNotFoundError, ImportError): + try: + from langchain.agents.format_scratchpad.tools import format_to_tool_messages + except (ModuleNotFoundError, ImportError): + from langchain.agents.format_scratchpad.openai_tools import ( + format_to_openai_tool_messages as format_to_tool_messages, + ) system_instructions = [] if system_instruction: @@ -629,13 +656,18 @@ def query( Returns: The output of querying the Agent with the given input and config. """ - from langchain.load import dump as langchain_load_dump + try: + from langchain_core.load import dumpd + except ImportError: + from langchain.load import dump as langchain_load_dump + + dumpd = langchain_load_dump.dumpd if isinstance(input, str): input = {"input": input} if not self._tmpl_attrs.get("runnable"): self.set_up() - return langchain_load_dump.dumpd( + return dumpd( self._tmpl_attrs.get("runnable").invoke( input=input, config=config, **kwargs ) @@ -662,7 +694,12 @@ def stream_query( Yields: The output of querying the Agent with the given input and config. """ - from langchain.load import dump as langchain_load_dump + try: + from langchain_core.load import dumpd + except ImportError: + from langchain.load import dump as langchain_load_dump + + dumpd = langchain_load_dump.dumpd if isinstance(input, str): input = {"input": input} @@ -673,4 +710,4 @@ def stream_query( config=config, **kwargs, ): - yield langchain_load_dump.dumpd(chunk) + yield dumpd(chunk) diff --git a/vertexai/preview/reasoning_engines/templates/langchain.py b/vertexai/preview/reasoning_engines/templates/langchain.py index eddb105d5b..cb39fb172e 100644 --- a/vertexai/preview/reasoning_engines/templates/langchain.py +++ b/vertexai/preview/reasoning_engines/templates/langchain.py @@ -43,11 +43,12 @@ RunnableSerializable = Any try: - from langchain_google_vertexai.functions_utils import _ToolsType - - _ToolsType = _ToolsType + from langchain_google_genai.functions_utils import _ToolsType except ImportError: - _ToolsType = Any + try: + from langchain_google_vertexai.functions_utils import _ToolsType + except ImportError: + _ToolsType = Any try: from opentelemetry.sdk import trace @@ -81,13 +82,15 @@ def _default_runnable_kwargs(has_history: bool) -> Mapping[str, Any]: def _default_output_parser(): try: - from langchain.agents.output_parsers.tools import ToolsAgentOutputParser + from langchain_classic.agents.output_parsers.tools import ToolsAgentOutputParser except (ModuleNotFoundError, ImportError): - # Fallback to an older version if needed. - from langchain.agents.output_parsers.openai_tools import ( - OpenAIToolsAgentOutputParser as ToolsAgentOutputParser, - ) - + try: + from langchain.agents.output_parsers.tools import ToolsAgentOutputParser + except (ModuleNotFoundError, ImportError): + # Fallback to an older version if needed. + from langchain.agents.output_parsers.openai_tools import ( + OpenAIToolsAgentOutputParser as ToolsAgentOutputParser, + ) return ToolsAgentOutputParser() @@ -98,17 +101,29 @@ def _default_model_builder( location: str, model_kwargs: Optional[Mapping[str, Any]] = None, ) -> "BaseLanguageModel": - import vertexai - from google.cloud.aiplatform import initializer - from langchain_google_vertexai import ChatVertexAI - model_kwargs = model_kwargs or {} - current_project = initializer.global_config.project - current_location = initializer.global_config.location - vertexai.init(project=project, location=location) - model = ChatVertexAI(model_name=model_name, **model_kwargs) - vertexai.init(project=current_project, location=current_location) - return model + try: + from langchain_google_genai import ChatGoogleGenerativeAI + + model = ChatGoogleGenerativeAI( + model=model_name, + project=project, + location=location, + vertexai=True, + **model_kwargs, + ) + return model + except ImportError: + import vertexai + from google.cloud.aiplatform import initializer + from langchain_google_vertexai import ChatVertexAI + + current_project = initializer.global_config.project + current_location = initializer.global_config.location + vertexai.init(project=project, location=location) + model = ChatVertexAI(model_name=model_name, **model_kwargs) + vertexai.init(project=current_project, location=current_location) + return model def _default_runnable_builder( @@ -124,8 +139,16 @@ def _default_runnable_builder( runnable_kwargs: Optional[Mapping[str, Any]] = None, ) -> "RunnableSerializable": from langchain_core import tools as lc_tools - from langchain.agents import AgentExecutor - from langchain.tools.base import StructuredTool + + try: + from langchain_classic.agents import AgentExecutor + except ImportError: + from langchain.agents import AgentExecutor + + try: + from langchain_core.tools import StructuredTool + except ImportError: + from langchain.tools.base import StructuredTool # The prompt template and runnable_kwargs needs to be customized depending # on whether the user intends for the agent to have history. The way the @@ -175,12 +198,16 @@ def _default_prompt( from langchain_core import prompts try: - from langchain.agents.format_scratchpad.tools import format_to_tool_messages - except (ModuleNotFoundError, ImportError): - # Fallback to an older version if needed. - from langchain.agents.format_scratchpad.openai_tools import ( - format_to_openai_tool_messages as format_to_tool_messages, + from langchain_classic.agents.format_scratchpad.tools import ( + format_to_tool_messages, ) + except (ModuleNotFoundError, ImportError): + try: + from langchain.agents.format_scratchpad.tools import format_to_tool_messages + except (ModuleNotFoundError, ImportError): + from langchain.agents.format_scratchpad.openai_tools import ( + format_to_openai_tool_messages as format_to_tool_messages, + ) system_instructions = [] if system_instruction: @@ -605,15 +632,18 @@ def query( Returns: The output of querying the Agent with the given input and config. """ - from langchain.load import dump as langchain_load_dump + try: + from langchain_core.load import dumpd + except ImportError: + from langchain.load import dump as langchain_load_dump + + dumpd = langchain_load_dump.dumpd if isinstance(input, str): input = {"input": input} if not self._runnable: self.set_up() - return langchain_load_dump.dumpd( - self._runnable.invoke(input=input, config=config, **kwargs) - ) + return dumpd(self._runnable.invoke(input=input, config=config, **kwargs)) def stream_query( self, @@ -636,11 +666,16 @@ def stream_query( Yields: The output of querying the Agent with the given input and config. """ - from langchain.load import dump as langchain_load_dump + try: + from langchain_core.load import dumpd + except ImportError: + from langchain.load import dump as langchain_load_dump + + dumpd = langchain_load_dump.dumpd if isinstance(input, str): input = {"input": input} if not self._runnable: self.set_up() for chunk in self._runnable.stream(input=input, config=config, **kwargs): - yield langchain_load_dump.dumpd(chunk) + yield dumpd(chunk)