Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 68 additions & 31 deletions vertexai/agent_engines/templates/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@
RunnableSerializable = Any

try:
from langchain_google_vertexai.functions_utils import _ToolsType

_ToolsType = _ToolsType
from langchain_google_genai.functions_utils import _ToolsType
except ImportError:
_ToolsType = Any
try:
from langchain_google_vertexai.functions_utils import _ToolsType
except ImportError:
_ToolsType = Any

try:
from opentelemetry.sdk import trace
Expand Down Expand Up @@ -81,13 +82,15 @@ def _default_runnable_kwargs(has_history: bool) -> Mapping[str, Any]:

def _default_output_parser():
try:
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain_classic.agents.output_parsers.tools import ToolsAgentOutputParser
except (ModuleNotFoundError, ImportError):
# Fallback to an older version if needed.
from langchain.agents.output_parsers.openai_tools import (
OpenAIToolsAgentOutputParser as ToolsAgentOutputParser,
)

try:
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
except (ModuleNotFoundError, ImportError):
# Fallback to an older version if needed.
from langchain.agents.output_parsers.openai_tools import (
OpenAIToolsAgentOutputParser as ToolsAgentOutputParser,
)
return ToolsAgentOutputParser()


Expand All @@ -98,17 +101,29 @@ def _default_model_builder(
location: str,
model_kwargs: Optional[Mapping[str, Any]] = None,
) -> "BaseLanguageModel":
import vertexai
from google.cloud.aiplatform import initializer
from langchain_google_vertexai import ChatVertexAI

model_kwargs = model_kwargs or {}
current_project = initializer.global_config.project
current_location = initializer.global_config.location
vertexai.init(project=project, location=location)
model = ChatVertexAI(model_name=model_name, **model_kwargs)
vertexai.init(project=current_project, location=current_location)
return model
try:
from langchain_google_genai import ChatGoogleGenerativeAI

model = ChatGoogleGenerativeAI(
model=model_name,
project=project,
location=location,
vertexai=True,
**model_kwargs,
)
return model
except ImportError:
import vertexai
from google.cloud.aiplatform import initializer
from langchain_google_vertexai import ChatVertexAI

current_project = initializer.global_config.project
current_location = initializer.global_config.location
vertexai.init(project=project, location=location)
model = ChatVertexAI(model_name=model_name, **model_kwargs)
vertexai.init(project=current_project, location=current_location)
return model


def _default_runnable_builder(
Expand All @@ -124,8 +139,16 @@ def _default_runnable_builder(
runnable_kwargs: Optional[Mapping[str, Any]] = None,
) -> "RunnableSerializable":
from langchain_core import tools as lc_tools
from langchain.agents import AgentExecutor
from langchain.tools.base import StructuredTool

try:
from langchain_classic.agents import AgentExecutor
except ImportError:
from langchain.agents import AgentExecutor

try:
from langchain_core.tools import StructuredTool
except ImportError:
from langchain.tools.base import StructuredTool

# The prompt template and runnable_kwargs needs to be customized depending
# on whether the user intends for the agent to have history. The way the
Expand Down Expand Up @@ -261,12 +284,16 @@ def _default_prompt(
from langchain_core import prompts

try:
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
except (ModuleNotFoundError, ImportError):
# Fallback to an older version if needed.
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages as format_to_tool_messages,
from langchain_classic.agents.format_scratchpad.tools import (
format_to_tool_messages,
)
except (ModuleNotFoundError, ImportError):
try:
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
except (ModuleNotFoundError, ImportError):
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages as format_to_tool_messages,
)

system_instructions = []
if system_instruction:
Expand Down Expand Up @@ -629,13 +656,18 @@ def query(
Returns:
The output of querying the Agent with the given input and config.
"""
from langchain.load import dump as langchain_load_dump
try:
from langchain_core.load import dumpd
except ImportError:
from langchain.load import dump as langchain_load_dump

dumpd = langchain_load_dump.dumpd

if isinstance(input, str):
input = {"input": input}
if not self._tmpl_attrs.get("runnable"):
self.set_up()
return langchain_load_dump.dumpd(
return dumpd(
self._tmpl_attrs.get("runnable").invoke(
input=input, config=config, **kwargs
)
Expand All @@ -662,7 +694,12 @@ def stream_query(
Yields:
The output of querying the Agent with the given input and config.
"""
from langchain.load import dump as langchain_load_dump
try:
from langchain_core.load import dumpd
except ImportError:
from langchain.load import dump as langchain_load_dump

dumpd = langchain_load_dump.dumpd

if isinstance(input, str):
input = {"input": input}
Expand All @@ -673,4 +710,4 @@ def stream_query(
config=config,
**kwargs,
):
yield langchain_load_dump.dumpd(chunk)
yield dumpd(chunk)
101 changes: 68 additions & 33 deletions vertexai/preview/reasoning_engines/templates/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@
RunnableSerializable = Any

try:
from langchain_google_vertexai.functions_utils import _ToolsType

_ToolsType = _ToolsType
from langchain_google_genai.functions_utils import _ToolsType
except ImportError:
_ToolsType = Any
try:
from langchain_google_vertexai.functions_utils import _ToolsType
except ImportError:
_ToolsType = Any

try:
from opentelemetry.sdk import trace
Expand Down Expand Up @@ -81,13 +82,15 @@ def _default_runnable_kwargs(has_history: bool) -> Mapping[str, Any]:

def _default_output_parser():
try:
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain_classic.agents.output_parsers.tools import ToolsAgentOutputParser
except (ModuleNotFoundError, ImportError):
# Fallback to an older version if needed.
from langchain.agents.output_parsers.openai_tools import (
OpenAIToolsAgentOutputParser as ToolsAgentOutputParser,
)

try:
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
except (ModuleNotFoundError, ImportError):
# Fallback to an older version if needed.
from langchain.agents.output_parsers.openai_tools import (
OpenAIToolsAgentOutputParser as ToolsAgentOutputParser,
)
return ToolsAgentOutputParser()


Expand All @@ -98,17 +101,29 @@ def _default_model_builder(
location: str,
model_kwargs: Optional[Mapping[str, Any]] = None,
) -> "BaseLanguageModel":
import vertexai
from google.cloud.aiplatform import initializer
from langchain_google_vertexai import ChatVertexAI

model_kwargs = model_kwargs or {}
current_project = initializer.global_config.project
current_location = initializer.global_config.location
vertexai.init(project=project, location=location)
model = ChatVertexAI(model_name=model_name, **model_kwargs)
vertexai.init(project=current_project, location=current_location)
return model
try:
from langchain_google_genai import ChatGoogleGenerativeAI

model = ChatGoogleGenerativeAI(
model=model_name,
project=project,
location=location,
vertexai=True,
**model_kwargs,
)
return model
except ImportError:
import vertexai
from google.cloud.aiplatform import initializer
from langchain_google_vertexai import ChatVertexAI

current_project = initializer.global_config.project
current_location = initializer.global_config.location
vertexai.init(project=project, location=location)
model = ChatVertexAI(model_name=model_name, **model_kwargs)
vertexai.init(project=current_project, location=current_location)
return model


def _default_runnable_builder(
Expand All @@ -124,8 +139,16 @@ def _default_runnable_builder(
runnable_kwargs: Optional[Mapping[str, Any]] = None,
) -> "RunnableSerializable":
from langchain_core import tools as lc_tools
from langchain.agents import AgentExecutor
from langchain.tools.base import StructuredTool

try:
from langchain_classic.agents import AgentExecutor
except ImportError:
from langchain.agents import AgentExecutor

try:
from langchain_core.tools import StructuredTool
except ImportError:
from langchain.tools.base import StructuredTool

# The prompt template and runnable_kwargs needs to be customized depending
# on whether the user intends for the agent to have history. The way the
Expand Down Expand Up @@ -175,12 +198,16 @@ def _default_prompt(
from langchain_core import prompts

try:
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
except (ModuleNotFoundError, ImportError):
# Fallback to an older version if needed.
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages as format_to_tool_messages,
from langchain_classic.agents.format_scratchpad.tools import (
format_to_tool_messages,
)
except (ModuleNotFoundError, ImportError):
try:
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
except (ModuleNotFoundError, ImportError):
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages as format_to_tool_messages,
)

system_instructions = []
if system_instruction:
Expand Down Expand Up @@ -605,15 +632,18 @@ def query(
Returns:
The output of querying the Agent with the given input and config.
"""
from langchain.load import dump as langchain_load_dump
try:
from langchain_core.load import dumpd
except ImportError:
from langchain.load import dump as langchain_load_dump

dumpd = langchain_load_dump.dumpd

if isinstance(input, str):
input = {"input": input}
if not self._runnable:
self.set_up()
return langchain_load_dump.dumpd(
self._runnable.invoke(input=input, config=config, **kwargs)
)
return dumpd(self._runnable.invoke(input=input, config=config, **kwargs))

def stream_query(
self,
Expand All @@ -636,11 +666,16 @@ def stream_query(
Yields:
The output of querying the Agent with the given input and config.
"""
from langchain.load import dump as langchain_load_dump
try:
from langchain_core.load import dumpd
except ImportError:
from langchain.load import dump as langchain_load_dump

dumpd = langchain_load_dump.dumpd

if isinstance(input, str):
input = {"input": input}
if not self._runnable:
self.set_up()
for chunk in self._runnable.stream(input=input, config=config, **kwargs):
yield langchain_load_dump.dumpd(chunk)
yield dumpd(chunk)
Loading