diff --git a/README.md b/README.md index f5b27c6..352b53c 100644 --- a/README.md +++ b/README.md @@ -305,6 +305,53 @@ Options: - `--function-invoke-opt TEXT`: Currently we support only `UnstructuredChunking` for functions. +## Testing LLM Gateway + +You can use AI models configured in Salesforce to generate responses while transforming your data. Below is a code example: + +``` +from datacustomcode.client import Client, llm_gateway_generate_text_col + + +def main(): + client = Client() + df = client.read_dlo("Input__dll") + df_generated = df.withColumn( + "greeting__c", + llm_gateway_generate_text_col( + "In one sentence, greet {name} from {city}.", + {"name": col("name__c"), "city": col("homecity__c")}, + model_id="sfdc_ai__DefaultGPT4Omni", # An AI model in your org + max_tokens=100, + ), + ) + + dlo_name = "Output__dll" + client.write_to_dlo(dlo_name, df_generated, write_mode=WriteMode.APPEND) + + greeting = client.llm_gateway_generate_text("In one sentence, generate a greeting message", "sfdc_ai__DefaultGPT52") + +if __name__ == "__main__": + main() +``` + +In order to test this code on your local machine before deploying it to Data Cloud, you must first set up an External Client App (ECA) that allows access to the Agent API. Follow this guide to create the ECA: https://developer.salesforce.com/docs/ai/agentforce/guide/agent-api-get-started.html#create-a-salesforce-app. You must use `http://localhost:1717/OauthRedirect` as the callback URL. + +Once the ECA is set up, log in to your org using this ECA: +``` +sf org login web \ + --alias myorg \ + --instance-url https://{MY_DOMAIN_URL} \ + --client-id {CONSUMER_KEY} \ + --scopes "sfap_api api" +``` + +Then you can test your code using the `myorg` alias: +``` +datacustomcode run ./payload/entrypoint.py --sf-cli-org myorg +``` + + ## Docker usage The SDK provides Docker-based development options that allow you to test your code in an environment that closely resembles Data Cloud's execution environment. 
diff --git a/src/datacustomcode/__init__.py b/src/datacustomcode/__init__.py index 85cfa54..76e24b7 100644 --- a/src/datacustomcode/__init__.py +++ b/src/datacustomcode/__init__.py @@ -17,8 +17,11 @@ "AuthType", "Client", "Credentials", + "DefaultSparkLLMGateway", "PrintDataCloudWriter", "QueryAPIDataCloudReader", + "SparkLLMGateway", + "llm_gateway_generate_text_col", ] @@ -44,4 +47,16 @@ def __getattr__(name: str): from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader return QueryAPIDataCloudReader + elif name == "SparkLLMGateway": + from datacustomcode.llm_gateway import SparkLLMGateway + + return SparkLLMGateway + elif name == "DefaultSparkLLMGateway": + from datacustomcode.llm_gateway import DefaultSparkLLMGateway + + return DefaultSparkLLMGateway + elif name == "llm_gateway_generate_text_col": + from datacustomcode.client import llm_gateway_generate_text_col + + return llm_gateway_generate_text_col raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/datacustomcode/client.py b/src/datacustomcode/client.py index 9ad95be..2cd64c0 100644 --- a/src/datacustomcode/client.py +++ b/src/datacustomcode/client.py @@ -18,24 +18,80 @@ from typing import ( TYPE_CHECKING, ClassVar, + Dict, Optional, + Union, ) from datacustomcode.config import config from datacustomcode.file.path.default import DefaultFindFilePath from datacustomcode.io.reader.base import BaseDataCloudReader +from datacustomcode.llm_gateway_config import spark_llm_gateway_config from datacustomcode.spark.default import DefaultSparkSessionProvider if TYPE_CHECKING: from pathlib import Path - from pyspark.sql import DataFrame as PySparkDataFrame + from pyspark.sql import Column, DataFrame as PySparkDataFrame from datacustomcode.io.reader.base import BaseDataCloudReader from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode + from datacustomcode.llm_gateway.spark_base import SparkLLMGateway from datacustomcode.spark.base import 
BaseSparkSessionProvider +def _build_spark_llm_gateway() -> "SparkLLMGateway": + """Instantiate the SDK-configured :class:`SparkLLMGateway`. + + Raises: + RuntimeError: If no ``spark_llm_gateway_config`` has been loaded. + """ + cfg = spark_llm_gateway_config.spark_llm_gateway_config + if cfg is None: + raise RuntimeError( + "spark_llm_gateway_config is not configured. Add a " + "'spark_llm_gateway_config' section to config.yaml." + ) + return cfg.to_object() + + +def llm_gateway_generate_text_col( + template: str, + values: Union[Dict[str, "Column"], "Column"], + model_id: Optional[str] = None, + max_tokens: Optional[int] = None, +) -> "Column": + """Build a Spark Column that runs the LLM Gateway per row. + + Example: + + >>> df.withColumn( + ... "greeting__c", + ... llm_gateway_generate_text_col( + ... "In one sentence, greet {name} from {city}.", + ... {"name": col("name__c"), "city": col("homecity__c")}, + ... model_id="sfdc_ai__DefaultGPT4Omni", + ... max_tokens=100, + ... ), + ... ) + + Args: + template: The prompt template, with ``{field}`` placeholders matching + keys in ``values``. Substitution uses ``str.format``. + values: Either a mapping from placeholder name to Spark ``Column``, or + a single ``Column`` whose value is already a struct. + model_id: LLM model id. Defaults to ``sfdc_ai__DefaultGPT4Omni``. + max_tokens: Maximum tokens to generate. Defaults to 200. + + Returns: + A Spark ``Column`` that, when evaluated, produces the generated text. + """ + gateway = Client()._get_spark_llm_gateway() + return gateway.llm_gateway_generate_text_col( + template, values, model_id=model_id, max_tokens=max_tokens + ) + + class DataCloudObjectType(Enum): DLO = "dlo" DMO = "dmo" @@ -94,18 +150,23 @@ class Client: finder: Find a file path reader: A custom reader to use for reading Data Cloud objects. writer: A custom writer to use for writing Data Cloud objects. + spark_llm_gateway: Optional custom :class:`SparkLLMGateway`. 
When + omitted, the gateway is lazily resolved from + ``spark_llm_gateway_config``. Example: >>> client = Client() >>> file_path = client.find_file_path("data.csv") >>> dlo = client.read_dlo("my_dlo") >>> client.write_to_dmo("my_dmo", dlo) + >>> answer = client.llm_gateway_generate_text("Generate a greeting message") """ _instance: ClassVar[Optional[Client]] = None _reader: BaseDataCloudReader _writer: BaseDataCloudWriter _file: DefaultFindFilePath + _spark_llm_gateway: Optional[SparkLLMGateway] _data_layer_history: dict[DataCloudObjectType, set[str]] _code_type: str @@ -114,11 +175,13 @@ def __new__( reader: Optional[BaseDataCloudReader] = None, writer: Optional[BaseDataCloudWriter] = None, spark_provider: Optional[BaseSparkSessionProvider] = None, + spark_llm_gateway: Optional[SparkLLMGateway] = None, code_type: str = "script", ) -> Client: if cls._instance is None: cls._instance = super().__new__(cls) + cls._instance._spark_llm_gateway = spark_llm_gateway # Initialize Readers and Writers from config # and/or provided reader and writer if reader is None or writer is None: @@ -225,6 +288,44 @@ def find_file_path(self, file_name: str) -> Path: return self._file.find_file_path(file_name) # type: ignore[no-any-return] + def llm_gateway_generate_text( + self, + prompt: str, + model_id: Optional[str] = None, + max_tokens: Optional[int] = None, + ) -> str: + """Issue a one-shot LLM Gateway call. This is the scalar counterpart to + :func:`llm_gateway_generate_text_col`: it runs **once** — not per row. + Use the column helper method instead when you want to fan a prompt out across + every row of a DataFrame. + + Example: + + >>> response = Client().llm_gateway_generate_text( + ... "Generate a greeting message" + ... ) + + Args: + prompt: The literal prompt to send. Plain text — no + ``{field}`` substitution is performed on this string. + model_id: LLM model id to target. Defaults to + ``sfdc_ai__DefaultGPT4Omni`` when ``None``. 
+ max_tokens: Hard upper bound on the number of tokens the model + may generate. Defaults to 200 when ``None``. + + Returns: + The generated text as a plain Python ``str``; empty when the + gateway response carries no generated text. + """ + return self._get_spark_llm_gateway().llm_gateway_generate_text( + prompt, model_id=model_id, max_tokens=max_tokens + ) + + def _get_spark_llm_gateway(self) -> SparkLLMGateway: + if self._spark_llm_gateway is None: + self._spark_llm_gateway = _build_spark_llm_gateway() + return self._spark_llm_gateway + def _validate_data_layer_history_does_not_contain( self, data_cloud_object_type: DataCloudObjectType ) -> None: diff --git a/src/datacustomcode/config.yaml b/src/datacustomcode/config.yaml index 8a6c334..25d233c 100644 --- a/src/datacustomcode/config.yaml +++ b/src/datacustomcode/config.yaml @@ -28,3 +28,6 @@ llm_gateway_config: type_config_name: DefaultLLMGateway options: credentials_profile: default + +spark_llm_gateway_config: + type_config_name: DefaultSparkLLMGateway diff --git a/src/datacustomcode/einstein_platform_config.py b/src/datacustomcode/einstein_platform_config.py index 135809d..1be4c4d 100644 --- a/src/datacustomcode/einstein_platform_config.py +++ b/src/datacustomcode/einstein_platform_config.py @@ -15,20 +15,23 @@ from typing import ( ClassVar, + Generic, Optional, Type, - cast, + TypeVar, ) from datacustomcode.common_config import BaseObjectConfig +_T = TypeVar("_T") -class CredentialsObjectConfig(BaseObjectConfig): + +class CredentialsObjectConfig(BaseObjectConfig, Generic[_T]): type_to_create: ClassVar[Type] credentials_profile: Optional[str] = None sf_cli_org: Optional[str] = None - def to_object(self): + def to_object(self) -> _T: """Create an object instance, automatically including credentials in options""" options = self.options.copy() @@ -38,4 +41,5 @@ def to_object(self): options["sf_cli_org"] = self.sf_cli_org type_ = self.type_to_create.subclass_from_config_name(self.type_config_name) - return 
cast(type_, type_(**options)) + instance: _T = type_(**options) + return instance diff --git a/src/datacustomcode/einstein_predictions_config.py b/src/datacustomcode/einstein_predictions_config.py index 1b4758f..cfd12d5 100644 --- a/src/datacustomcode/einstein_predictions_config.py +++ b/src/datacustomcode/einstein_predictions_config.py @@ -28,7 +28,7 @@ _E = TypeVar("_E", bound=EinsteinPredictions) -class EinsteinPredictionsObjectConfig(CredentialsObjectConfig, Generic[_E]): +class EinsteinPredictionsObjectConfig(CredentialsObjectConfig[_E], Generic[_E]): type_to_create: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions # type: ignore[type-abstract] diff --git a/src/datacustomcode/function/feature_types/chunking.py b/src/datacustomcode/function/feature_types/chunking.py index 31a1ccf..994489e 100644 --- a/src/datacustomcode/function/feature_types/chunking.py +++ b/src/datacustomcode/function/feature_types/chunking.py @@ -16,6 +16,7 @@ """ Pydantic models for Search Index Chunking V1 """ + from enum import Enum from typing import ( Dict, diff --git a/src/datacustomcode/llm_gateway/__init__.py b/src/datacustomcode/llm_gateway/__init__.py index ea8b4af..bad6f42 100644 --- a/src/datacustomcode/llm_gateway/__init__.py +++ b/src/datacustomcode/llm_gateway/__init__.py @@ -15,8 +15,12 @@ from datacustomcode.llm_gateway.base import LLMGateway from datacustomcode.llm_gateway.default import DefaultLLMGateway +from datacustomcode.llm_gateway.spark_base import SparkLLMGateway +from datacustomcode.llm_gateway.spark_default import DefaultSparkLLMGateway __all__ = [ "DefaultLLMGateway", + "DefaultSparkLLMGateway", "LLMGateway", + "SparkLLMGateway", ] diff --git a/src/datacustomcode/llm_gateway/default.py b/src/datacustomcode/llm_gateway/default.py index 54f5105..d3951b7 100644 --- a/src/datacustomcode/llm_gateway/default.py +++ b/src/datacustomcode/llm_gateway/default.py @@ -34,6 +34,8 @@ def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse: 
payload: Dict[str, Any] = {"prompt": request.prompt} + if request.max_tokens is not None: + payload["max_tokens"] = request.max_tokens if request.localization: payload["localization"] = request.localization if request.tags: diff --git a/src/datacustomcode/llm_gateway/spark_base.py b/src/datacustomcode/llm_gateway/spark_base.py new file mode 100644 index 0000000..88b9734 --- /dev/null +++ b/src/datacustomcode/llm_gateway/spark_base.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Optional, + Union, +) + +from datacustomcode.mixin import UserExtendableNamedConfigMixin + +if TYPE_CHECKING: + from pyspark.sql import Column + + +class SparkLLMGateway(ABC, UserExtendableNamedConfigMixin): + CONFIG_NAME: str + + def __init__(self, **kwargs: Any) -> None: + pass + + @abstractmethod + def llm_gateway_generate_text( + self, + prompt: str, + model_id: Optional[str] = None, + max_tokens: Optional[int] = None, + ) -> str: + """Issue a one-shot LLM Gateway call and return the generated text.""" + + @abstractmethod + def llm_gateway_generate_text_col( + self, + template: str, + values: Union[Dict[str, "Column"], "Column"], + model_id: Optional[str] = None, + max_tokens: Optional[int] = None, + ) -> "Column": + """Build a Spark ``Column`` that invokes the LLM Gateway per row.""" diff --git a/src/datacustomcode/llm_gateway/spark_default.py b/src/datacustomcode/llm_gateway/spark_default.py new file mode 100644 index 0000000..f963055 --- /dev/null +++ b/src/datacustomcode/llm_gateway/spark_default.py @@ -0,0 +1,119 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Optional, + Union, +) + +from datacustomcode.llm_gateway.spark_base import SparkLLMGateway + +if TYPE_CHECKING: + from pyspark.sql import Column + + from datacustomcode.llm_gateway.base import LLMGateway + + +_DEFAULT_LLM_MODEL_ID = "sfdc_ai__DefaultGPT4Omni" +_DEFAULT_LLM_MAX_TOKENS = 200 + + +class DefaultSparkLLMGateway(SparkLLMGateway): + + CONFIG_NAME = "DefaultSparkLLMGateway" + + def __init__( + self, + llm_gateway: Optional["LLMGateway"] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if llm_gateway is None: + llm_gateway = _build_underlying_gateway() + self._llm_gateway: "LLMGateway" = llm_gateway + + def llm_gateway_generate_text( + self, + prompt: str, + model_id: Optional[str] = None, + max_tokens: Optional[int] = None, + ) -> str: + return _invoke_llm_gateway(self._llm_gateway, prompt, model_id, max_tokens) + + def llm_gateway_generate_text_col( + self, + template: str, + values: Union[Dict[str, "Column"], "Column"], + model_id: Optional[str] = None, + max_tokens: Optional[int] = None, + ) -> "Column": + + from pyspark.sql.functions import struct, udf + from pyspark.sql.types import StringType + + if isinstance(values, dict): + values_col = struct(*[v.alias(k) for k, v in values.items()]) + else: + values_col = values + + gateway = self._llm_gateway + + def _generate(values_row: Any) -> str: + if values_row is None: + return "" + subs = ( + values_row.asDict() + if hasattr(values_row, "asDict") + else dict(values_row) + ) + prompt = template.format(**subs) + return _invoke_llm_gateway(gateway, prompt, model_id, max_tokens) + + return udf(_generate, StringType())(values_col) + + +def _build_underlying_gateway() -> "LLMGateway": + from datacustomcode.llm_gateway_config import llm_gateway_config + + cfg = llm_gateway_config.llm_gateway_config + if cfg is None: + raise RuntimeError( + "llm_gateway_config is not configured. 
Add an 'llm_gateway_config' " + "section to config.yaml." + ) + return cfg.to_object() + + +def _invoke_llm_gateway( + gateway: "LLMGateway", + prompt: str, + model_id: Optional[str], + max_tokens: Optional[int], +) -> str: + from datacustomcode.llm_gateway.types.generate_text_request_builder import ( + GenerateTextRequestBuilder, + ) + + builder = ( + GenerateTextRequestBuilder() + .set_prompt(prompt) + .set_model(model_id or _DEFAULT_LLM_MODEL_ID) + .set_max_tokens(max_tokens or _DEFAULT_LLM_MAX_TOKENS) + ) + return gateway.generate_text(builder.build()).text diff --git a/src/datacustomcode/llm_gateway/types/generate_text_request.py b/src/datacustomcode/llm_gateway/types/generate_text_request.py index a846ce0..c9098cb 100644 --- a/src/datacustomcode/llm_gateway/types/generate_text_request.py +++ b/src/datacustomcode/llm_gateway/types/generate_text_request.py @@ -40,6 +40,13 @@ class GenerateTextRequest(BaseModel): ) model_name: str = Field(..., min_length=1, description="Name of the model to use") prompt: str = Field(..., description="Input prompt") + max_tokens: Optional[int] = Field( + default=None, + ge=1, + description=( + "Maximum number of tokens to generate. If None, server default applies." 
+ ), + ) localization: Optional[Dict[str, Any]] = Field( default=None, description="Localization settings" ) diff --git a/src/datacustomcode/llm_gateway/types/generate_text_request_builder.py b/src/datacustomcode/llm_gateway/types/generate_text_request_builder.py index f9fb461..d707943 100644 --- a/src/datacustomcode/llm_gateway/types/generate_text_request_builder.py +++ b/src/datacustomcode/llm_gateway/types/generate_text_request_builder.py @@ -26,6 +26,7 @@ class GenerateTextRequestBuilder: def __init__(self) -> None: self._prompt = "" self._model_name = "" + self._max_tokens: Optional[int] = None self._localization: Optional[Dict[str, Any]] = None self._tags: Optional[Dict[str, Any]] = None @@ -37,6 +38,10 @@ def set_model(self, model_name: str) -> "GenerateTextRequestBuilder": self._model_name = model_name return self + def set_max_tokens(self, max_tokens: int) -> "GenerateTextRequestBuilder": + self._max_tokens = max_tokens + return self + def set_localization( self, localization: Optional[Dict[str, Any]] = None, @@ -75,6 +80,7 @@ def build(self) -> GenerateTextRequest: request = GenerateTextRequest( prompt=self._prompt, model_name=self._model_name, + max_tokens=self._max_tokens, localization=self._localization, tags=self._tags, ) diff --git a/src/datacustomcode/llm_gateway_config.py b/src/datacustomcode/llm_gateway_config.py index a65d0eb..3b4ebb5 100644 --- a/src/datacustomcode/llm_gateway_config.py +++ b/src/datacustomcode/llm_gateway_config.py @@ -21,14 +21,20 @@ Union, ) -from datacustomcode.common_config import BaseConfig, default_config_file +from datacustomcode.common_config import ( + BaseConfig, + BaseObjectConfig, + default_config_file, +) from datacustomcode.einstein_platform_config import CredentialsObjectConfig from datacustomcode.llm_gateway.base import LLMGateway +from datacustomcode.llm_gateway.spark_base import SparkLLMGateway _E = TypeVar("_E", bound=LLMGateway) +_S = TypeVar("_S", bound=SparkLLMGateway) -class 
LLMGatewayObjectConfig(CredentialsObjectConfig, Generic[_E]): +class LLMGatewayObjectConfig(CredentialsObjectConfig[_E], Generic[_E]): type_to_create: ClassVar[Type[LLMGateway]] = LLMGateway # type: ignore[type-abstract] @@ -52,6 +58,41 @@ def merge( return self +class SparkLLMGatewayObjectConfig(BaseObjectConfig, Generic[_S]): + type_to_create: ClassVar[Type[SparkLLMGateway]] = SparkLLMGateway # type: ignore[type-abstract] + + def to_object(self) -> SparkLLMGateway: + type_ = self.type_to_create.subclass_from_config_name(self.type_config_name) + return type_(**self.options) + + +class SparkLLMGatewayConfig(BaseConfig): + spark_llm_gateway_config: Union[ + SparkLLMGatewayObjectConfig[SparkLLMGateway], None + ] = None + + def update(self, other: "SparkLLMGatewayConfig") -> "SparkLLMGatewayConfig": + def merge( + config_a: Union[SparkLLMGatewayObjectConfig, None], + config_b: Union[SparkLLMGatewayObjectConfig, None], + ) -> Union[SparkLLMGatewayObjectConfig, None]: + if config_a is not None and config_a.force: + return config_a + if config_b: + return config_b + return config_a + + self.spark_llm_gateway_config = merge( + self.spark_llm_gateway_config, other.spark_llm_gateway_config + ) + return self + + # Global LLM Gateway config instance llm_gateway_config = LLMGatewayConfig() llm_gateway_config.load(default_config_file()) + + +# Global Spark LLM Gateway config instance +spark_llm_gateway_config = SparkLLMGatewayConfig() +spark_llm_gateway_config.load(default_config_file()) diff --git a/src/datacustomcode/proxy/__init__.py b/src/datacustomcode/proxy/__init__.py deleted file mode 100644 index 93988ff..0000000 --- a/src/datacustomcode/proxy/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2025, Salesforce, Inc. -# SPDX-License-Identifier: Apache-2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/src/datacustomcode/proxy/base.py b/src/datacustomcode/proxy/base.py deleted file mode 100644 index 71cf314..0000000 --- a/src/datacustomcode/proxy/base.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) 2025, Salesforce, Inc. -# SPDX-License-Identifier: Apache-2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -from abc import ABC - -from datacustomcode.mixin import UserExtendableNamedConfigMixin - - -class BaseProxyAccessLayer(ABC, UserExtendableNamedConfigMixin): - def __init__(self): - pass diff --git a/src/datacustomcode/proxy/client/__init__.py b/src/datacustomcode/proxy/client/__init__.py deleted file mode 100644 index 93988ff..0000000 --- a/src/datacustomcode/proxy/client/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2025, Salesforce, Inc. -# SPDX-License-Identifier: Apache-2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/src/datacustomcode/proxy/client/base.py b/src/datacustomcode/proxy/client/base.py deleted file mode 100644 index 85e304a..0000000 --- a/src/datacustomcode/proxy/client/base.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) 2025, Salesforce, Inc. -# SPDX-License-Identifier: Apache-2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -from abc import abstractmethod - -from datacustomcode.proxy.base import BaseProxyAccessLayer - - -class BaseProxyClient(BaseProxyAccessLayer): - def __init__(self): - pass - - @abstractmethod - def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str: ... - - @abstractmethod - def llm_gateway_generate_text( - self, template, values, llmModelId: str, maxTokens: int - ): ... 
diff --git a/src/datacustomcode/templates/function/example/chunking_with_llm/entrypoint.py b/src/datacustomcode/templates/function/example/chunking_with_llm/entrypoint.py index 0a5dbb3..2502cdf 100644 --- a/src/datacustomcode/templates/function/example/chunking_with_llm/entrypoint.py +++ b/src/datacustomcode/templates/function/example/chunking_with_llm/entrypoint.py @@ -7,6 +7,9 @@ - Requires Runtime parameter (for agentic capabilities) - Type-safe with direct field access (no wrappers) - Automatic validation and conversion + +You can use your AI models configured in Salesforce to generate texts. +See README.md for how to test locally before deploying to Data Cloud. """ import logging diff --git a/src/datacustomcode/templates/function/example/chunking_with_prediction/entrypoint.py b/src/datacustomcode/templates/function/example/chunking_with_prediction/entrypoint.py index df9b780..d7fe2fa 100644 --- a/src/datacustomcode/templates/function/example/chunking_with_prediction/entrypoint.py +++ b/src/datacustomcode/templates/function/example/chunking_with_prediction/entrypoint.py @@ -12,6 +12,9 @@ Type: Regression Input: Year_Built__c (numeric) Output: Predicted_SalePrice + +You can use your AI models configured in Salesforce to make predictions. +See README.md for how to test locally before deploying to Data Cloud. 
""" import logging diff --git a/src/datacustomcode/templates/function/payload/entrypoint.py b/src/datacustomcode/templates/function/payload/entrypoint.py index a1cd685..5231174 100644 --- a/src/datacustomcode/templates/function/payload/entrypoint.py +++ b/src/datacustomcode/templates/function/payload/entrypoint.py @@ -65,13 +65,9 @@ def make_einstein_prediction(runtime: Runtime) -> None: ) -def generate_text(runtime: Runtime): +def generate_text(runtime: Runtime, prompt: str, model: str = "sfdc_ai__DefaultGPT52"): builder = GenerateTextRequestBuilder() - llm_request = ( - builder.set_prompt("Generate 2 dog names") - .set_model("sfdc_ai__DefaultGPT52") - .build() - ) + llm_request = builder.set_prompt(prompt).set_model(model).build() llm_response = runtime.llm_gateway.generate_text(llm_request) logger.info( f"LLM Gateway generate text results - success: [{llm_response.is_success}] " @@ -88,13 +84,16 @@ def function(request: dict, runtime: Runtime) -> dict: current_seq_no = 1 # Start sequence number from 1 """ - You can use your AI models configured in Salesforce - to generate texts or predict an outcome. - First configure an external client app before using these AI APIs - https://developer.salesforce.com/docs/ai/agentforce/guide/agent-api-get-started.html#create-a-salesforce-app" + You can use your AI models configured in Salesforce to generate texts + or predict an outcome. See README.md for how to test locally before + deploying to Data Cloud. + + Example: + + >>> generated_text = generate_text(runtime, "Generate a greeting message") + ... 
prediction = make_einstein_prediction(runtime) + """ - # generate_text(runtime) - # make_einstein_prediction(runtime) for item in items: # Item is DocElement as dict diff --git a/src/datacustomcode/templates/script/payload/entrypoint.py b/src/datacustomcode/templates/script/payload/entrypoint.py index 10ba1d7..6287173 100644 --- a/src/datacustomcode/templates/script/payload/entrypoint.py +++ b/src/datacustomcode/templates/script/payload/entrypoint.py @@ -12,6 +12,33 @@ def main(): # Perform transformations on the DataFrame df_upper1 = df.withColumn("description__c", upper(col("description__c"))) + """ + You can use your AI models configured in Salesforce to generate column + values. See README.md for how to test locally before deploying to Data Cloud. + + Example: + + >>> from datacustomcode.client import llm_gateway_generate_text_col + df_generated = df.withColumn( + ... "greeting__c", + ... llm_gateway_generate_text_col( + ... "In one sentence, greet {name} from {city}.", + ... {"name": col("name__c"), "city": col("homecity__c")}, + ... model_id="sfdc_ai__DefaultGPT4Omni", + ... max_tokens=100, + ... ), + ... ) + + You can also invoke the LLM with a literal plain text prompt — no + ``{field}`` substitution is performed on this string. + + Example: + + >>> generated_text = client.llm_gateway_generate_text( + ... prompt, model_id, max_tokens + ... 
) + """ + # Drop specific columns related to relationships df_upper1 = df_upper1.drop("sfdcorganizationid__c") df_upper1 = df_upper1.drop("kq_id__c") diff --git a/tests/test_client.py b/tests/test_client.py index c2cf46a..e0382af 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -9,6 +9,7 @@ Client, DataCloudAccessLayerException, DataCloudObjectType, + llm_gateway_generate_text_col, ) from datacustomcode.config import ( AccessLayerObjectConfig, @@ -253,6 +254,92 @@ def test_read_pattern_flow(self, reset_client, mock_spark): assert "source_dmo" in client._data_layer_history[DataCloudObjectType.DMO] +class TestClientLlmGatewayGenerateText: + + @patch("datacustomcode.client._build_spark_llm_gateway") + def test_forwards_args_to_spark_llm_gateway(self, mock_build_gateway, reset_client): + mock_spark_gateway = MagicMock() + mock_spark_gateway.llm_gateway_generate_text.return_value = "reply" + mock_build_gateway.return_value = mock_spark_gateway + + reader = MagicMock(spec=BaseDataCloudReader) + writer = MagicMock(spec=BaseDataCloudWriter) + client = Client(reader=reader, writer=writer) + + result = client.llm_gateway_generate_text( + "ping", model_id="test-model", max_tokens=42 + ) + + assert result == "reply" + mock_spark_gateway.llm_gateway_generate_text.assert_called_once_with( + "ping", model_id="test-model", max_tokens=42 + ) + + @patch("datacustomcode.client._build_spark_llm_gateway") + def test_gateway_is_built_lazily_and_cached(self, mock_build_gateway, reset_client): + mock_spark_gateway = MagicMock() + mock_spark_gateway.llm_gateway_generate_text.return_value = "ok" + mock_build_gateway.return_value = mock_spark_gateway + + reader = MagicMock(spec=BaseDataCloudReader) + writer = MagicMock(spec=BaseDataCloudWriter) + client = Client(reader=reader, writer=writer) + + mock_build_gateway.assert_not_called() + + client.llm_gateway_generate_text("a") + client.llm_gateway_generate_text("b") + + mock_build_gateway.assert_called_once_with() + assert 
mock_spark_gateway.llm_gateway_generate_text.call_count == 2 + + @patch("datacustomcode.client._build_spark_llm_gateway") + def test_uses_injected_spark_llm_gateway_without_config_lookup( + self, mock_build_gateway, reset_client + ): + injected = MagicMock() + injected.llm_gateway_generate_text.return_value = "from-injected" + + reader = MagicMock(spec=BaseDataCloudReader) + writer = MagicMock(spec=BaseDataCloudWriter) + client = Client(reader=reader, writer=writer, spark_llm_gateway=injected) + + result = client.llm_gateway_generate_text("hello") + + assert result == "from-injected" + injected.llm_gateway_generate_text.assert_called_once_with( + "hello", model_id=None, max_tokens=None + ) + mock_build_gateway.assert_not_called() + + +class TestLLMGatewayGenerateTextCol: + """The module-level ``llm_gateway_generate_text_col`` is a thin wrapper + that resolves the client-owned :class:`SparkLLMGateway` and delegates. + """ + + @patch("datacustomcode.client._build_spark_llm_gateway") + def test_delegates_to_spark_llm_gateway(self, mock_build_gateway): + mock_spark_gateway = MagicMock() + sentinel_col = MagicMock(name="col") + mock_spark_gateway.llm_gateway_generate_text_col.return_value = sentinel_col + mock_build_gateway.return_value = mock_spark_gateway + + reader = MagicMock(spec=BaseDataCloudReader) + writer = MagicMock(spec=BaseDataCloudWriter) + Client(reader=reader, writer=writer) + + values = {"name": MagicMock()} + result = llm_gateway_generate_text_col( + "Greet {name}", values, model_id="m", max_tokens=7 + ) + + assert result is sentinel_col + mock_spark_gateway.llm_gateway_generate_text_col.assert_called_once_with( + "Greet {name}", values, model_id="m", max_tokens=7 + ) + + # Add tests for DefaultSparkSessionProvider class TestDefaultSparkSessionProvider: diff --git a/tests/test_llm_gateway.py b/tests/test_llm_gateway.py index 7875e21..e4a9f5f 100644 --- a/tests/test_llm_gateway.py +++ b/tests/test_llm_gateway.py @@ -53,6 +53,22 @@ def 
test_accepts_camel_case_input(self): request = GenerateTextRequest(modelName="gpt-4", prompt="Hello") assert request.model_name == "gpt-4" + def test_max_tokens_defaults_to_none(self): + """max_tokens is optional and defaults to None (server default applies).""" + request = GenerateTextRequest(model_name="gpt-4", prompt="Hello") + assert request.max_tokens is None + + def test_max_tokens_accepts_int(self): + """max_tokens accepts a positive int.""" + request = GenerateTextRequest(model_name="gpt-4", prompt="Hello", max_tokens=50) + assert request.max_tokens == 50 + + def test_max_tokens_must_be_positive(self): + """max_tokens is constrained to >= 1.""" + with pytest.raises(ValidationError) as exc_info: + GenerateTextRequest(model_name="gpt-4", prompt="Hello", max_tokens=0) + assert "max_tokens" in str(exc_info.value) or "maxTokens" in str(exc_info.value) + class TestGenerateTextRequestBuilder: """Test GenerateTextRequestBuilder.""" @@ -100,6 +116,20 @@ def test_builder_with_tags(self): request = builder.set_prompt("Hello").set_model("gpt-4").set_tags(tags).build() assert request.tags == tags + def test_builder_with_max_tokens(self): + """set_max_tokens propagates to the built request.""" + builder = GenerateTextRequestBuilder() + request = ( + builder.set_prompt("Hello").set_model("gpt-4").set_max_tokens(123).build() + ) + assert request.max_tokens == 123 + + def test_builder_default_max_tokens_is_none(self): + """Omitting set_max_tokens leaves max_tokens as None on the request.""" + builder = GenerateTextRequestBuilder() + request = builder.set_prompt("Hello").set_model("gpt-4").build() + assert request.max_tokens is None + def test_builder_validates_on_build(self): """Test builder validates request on build.""" builder = GenerateTextRequestBuilder() diff --git a/tests/test_spark_llm_gateway.py b/tests/test_spark_llm_gateway.py new file mode 100644 index 0000000..7403092 --- /dev/null +++ b/tests/test_spark_llm_gateway.py @@ -0,0 +1,193 @@ +from __future__ import 
annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from datacustomcode.llm_gateway import DefaultSparkLLMGateway +from datacustomcode.llm_gateway.spark_default import ( + _build_underlying_gateway, + _invoke_llm_gateway, +) +from datacustomcode.llm_gateway.types.generate_text_response import GenerateTextResponse + + +def _success_response(text: str = "ok") -> GenerateTextResponse: + return GenerateTextResponse( + status_code=200, data={"generation": {"generatedText": text}} + ) + + +class TestDefaultSparkLLMGatewayConstruction: + """Construction wires an underlying ``LLMGateway``.""" + + def test_uses_injected_llm_gateway_when_provided(self): + injected = MagicMock() + gateway = DefaultSparkLLMGateway(llm_gateway=injected) + assert gateway._llm_gateway is injected + + @patch("datacustomcode.llm_gateway.spark_default._build_underlying_gateway") + def test_falls_back_to_config_when_no_gateway_injected(self, mock_build): + config_built = MagicMock() + mock_build.return_value = config_built + + gateway = DefaultSparkLLMGateway() + + mock_build.assert_called_once_with() + assert gateway._llm_gateway is config_built + + +class TestBuildUnderlyingGateway: + """``_build_underlying_gateway`` resolves the config-defined ``LLMGateway``.""" + + def test_returns_object_from_config(self): + with patch( + "datacustomcode.llm_gateway_config.llm_gateway_config" + ) as mock_obj_config: + mock_gateway = MagicMock() + mock_obj_config.llm_gateway_config.to_object.return_value = mock_gateway + + assert _build_underlying_gateway() is mock_gateway + mock_obj_config.llm_gateway_config.to_object.assert_called_once_with() + + def test_raises_when_config_missing(self): + with patch( + "datacustomcode.llm_gateway_config.llm_gateway_config" + ) as mock_obj_config: + mock_obj_config.llm_gateway_config = None + with pytest.raises(RuntimeError, match="llm_gateway_config"): + _build_underlying_gateway() + + +class TestDefaultSparkLLMGatewayGenerateText: + + def 
test_forwards_prompt_model_and_max_tokens(self): + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _success_response("hello back") + gateway = DefaultSparkLLMGateway(llm_gateway=mock_inner) + + result = gateway.llm_gateway_generate_text( + "hello", model_id="m1", max_tokens=42 + ) + + assert result == "hello back" + sent = mock_inner.generate_text.call_args.args[0] + assert sent.prompt == "hello" + assert sent.model_name == "m1" + assert sent.max_tokens == 42 + + def test_applies_defaults_when_model_and_tokens_omitted(self): + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _success_response("ok") + gateway = DefaultSparkLLMGateway(llm_gateway=mock_inner) + + gateway.llm_gateway_generate_text("just a prompt") + + sent = mock_inner.generate_text.call_args.args[0] + assert sent.model_name == "sfdc_ai__DefaultGPT4Omni" + assert sent.max_tokens == 200 + + +class TestDefaultSparkLLMGatewayGenerateTextCol: + + @patch("pyspark.sql.functions.udf") + @patch("pyspark.sql.functions.struct") + def test_dict_values_built_into_struct_and_wrapped_in_udf( + self, mock_struct, mock_udf + ): + sentinel_struct_col = MagicMock(name="struct_col") + mock_struct.return_value = sentinel_struct_col + sentinel_udf = MagicMock(name="udf") + sentinel_applied = MagicMock(name="udf_applied") + sentinel_udf.return_value = sentinel_applied + mock_udf.return_value = sentinel_udf + + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _success_response("row-out") + gateway = DefaultSparkLLMGateway(llm_gateway=mock_inner) + + name_col, city_col = MagicMock(name="name_col"), MagicMock(name="city_col") + name_aliased, city_aliased = ( + MagicMock(name="name_aliased"), + MagicMock(name="city_aliased"), + ) + name_col.alias.return_value = name_aliased + city_col.alias.return_value = city_aliased + + result = gateway.llm_gateway_generate_text_col( + "Greet {name} from {city}.", + {"name": name_col, "city": city_col}, + model_id="test-model", + 
max_tokens=5, + ) + + name_col.alias.assert_called_once_with("name") + city_col.alias.assert_called_once_with("city") + mock_struct.assert_called_once_with(name_aliased, city_aliased) + mock_udf.assert_called_once() + sentinel_udf.assert_called_once_with(sentinel_struct_col) + assert result is sentinel_applied + + udf_fn = mock_udf.call_args.args[0] + row = MagicMock() + row.asDict.return_value = {"name": "Ada", "city": "London"} + out = udf_fn(row) + + assert out == "row-out" + sent = mock_inner.generate_text.call_args.args[0] + assert sent.prompt == "Greet Ada from London." + assert sent.model_name == "test-model" + assert sent.max_tokens == 5 + + @patch("pyspark.sql.functions.udf") + @patch("pyspark.sql.functions.struct") + def test_column_values_passed_through_without_struct(self, mock_struct, mock_udf): + from pyspark.sql import Column + + existing_col = MagicMock(spec=Column) + sentinel_udf = MagicMock(name="udf") + sentinel_udf.return_value = MagicMock(name="udf_applied") + mock_udf.return_value = sentinel_udf + + gateway = DefaultSparkLLMGateway(llm_gateway=MagicMock()) + + gateway.llm_gateway_generate_text_col("Greet {name}", existing_col) + + mock_struct.assert_not_called() + sentinel_udf.assert_called_once_with(existing_col) + + @patch("pyspark.sql.functions.udf") + @patch("pyspark.sql.functions.struct") + def test_udf_returns_empty_for_null_row(self, mock_struct, mock_udf): + mock_struct.return_value = MagicMock() + mock_udf.return_value = MagicMock() + mock_inner = MagicMock() + gateway = DefaultSparkLLMGateway(llm_gateway=mock_inner) + + gateway.llm_gateway_generate_text_col("template", {"placeholder": MagicMock()}) + + udf_fn = mock_udf.call_args.args[0] + assert udf_fn(None) == "" + mock_inner.generate_text.assert_not_called() + + +class TestInvokeLLMGateway: + + def test_returns_response_text(self): + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _success_response("done") + + assert _invoke_llm_gateway(mock_inner, "prompt", 
"model", 7) == "done" + sent = mock_inner.generate_text.call_args.args[0] + assert sent.prompt == "prompt" + assert sent.model_name == "model" + assert sent.max_tokens == 7 + + def test_uses_defaults_when_model_and_tokens_none(self): + mock_inner = MagicMock() + mock_inner.generate_text.return_value = _success_response("ok") + + _invoke_llm_gateway(mock_inner, "prompt", None, None) + sent = mock_inner.generate_text.call_args.args[0] + assert sent.model_name == "sfdc_ai__DefaultGPT4Omni" + assert sent.max_tokens == 200