Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,53 @@ Options:
- `--function-invoke-opt TEXT`: Currently we support only `UnstructuredChunking` for functions.


## Testing LLM Gateway

You can use AI models configured in Salesforce to generate responses while transforming your data. Below is a sample code example:

```python
from pyspark.sql.functions import col

from datacustomcode.client import Client, llm_gateway_generate_text_col
from datacustomcode.io.writer.base import WriteMode


def main():
    client = Client()
    df = client.read_dlo("Input__dll")

    # Per-row generation: the {name}/{city} placeholders in the template are
    # filled from the mapped columns for every row of the DataFrame.
    df_generated = df.withColumn(
        "greeting__c",
        llm_gateway_generate_text_col(
            "In one sentence, greet {name} from {city}.",
            {"name": col("name__c"), "city": col("homecity__c")},
            model_id="sfdc_ai__DefaultGPT4Omni",  # An AI model in your org
            max_tokens=100,
        ),
    )

    # Write the DataFrame that carries the generated column.
    dlo_name = "Output__dll"
    client.write_to_dlo(dlo_name, df_generated, write_mode=WriteMode.APPEND)

    # One-shot generation: a single gateway call, not one per row.
    greeting = client.llm_gateway_generate_text(
        "In one sentence, generate a greeting message", "sfdc_ai__DefaultGPT52"
    )
    print(greeting)


if __name__ == "__main__":
    main()
```

To test this code on your local machine before deploying it to Data Cloud, you must first set up an External Client App (ECA) that allows access to the Agent API. Follow [this guide](https://developer.salesforce.com/docs/ai/agentforce/guide/agent-api-get-started.html#create-a-salesforce-app) to create the ECA, and use `http://localhost:1717/OauthRedirect` as its callback URL.

Once the ECA is set up, log in to your org using this ECA:
```
sf org login web \
--alias myorg \
--instance-url https://{MY_DOMAIN_URL} \
--client-id {CONSUMER_KEY} \
--scopes "sfap_api api"
```

Then you can test your code using the `myorg` alias:
```
datacustomcode run ./payload/entrypoint.py --sf-cli-org myorg
```


## Docker usage

The SDK provides Docker-based development options that allow you to test your code in an environment that closely resembles Data Cloud's execution environment.
Expand Down
15 changes: 15 additions & 0 deletions src/datacustomcode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
"AuthType",
"Client",
"Credentials",
"DefaultSparkLLMGateway",
"PrintDataCloudWriter",
"QueryAPIDataCloudReader",
"SparkLLMGateway",
"llm_gateway_generate_text_col",
]


Expand All @@ -44,4 +47,16 @@ def __getattr__(name: str):
from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader

return QueryAPIDataCloudReader
elif name == "SparkLLMGateway":
from datacustomcode.llm_gateway import SparkLLMGateway

return SparkLLMGateway
elif name == "DefaultSparkLLMGateway":
from datacustomcode.llm_gateway import DefaultSparkLLMGateway

return DefaultSparkLLMGateway
elif name == "llm_gateway_generate_text_col":
from datacustomcode.client import llm_gateway_generate_text_col

return llm_gateway_generate_text_col
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
103 changes: 102 additions & 1 deletion src/datacustomcode/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,80 @@
from typing import (
TYPE_CHECKING,
ClassVar,
Dict,
Optional,
Union,
)

from datacustomcode.config import config
from datacustomcode.file.path.default import DefaultFindFilePath
from datacustomcode.io.reader.base import BaseDataCloudReader
from datacustomcode.llm_gateway_config import spark_llm_gateway_config
from datacustomcode.spark.default import DefaultSparkSessionProvider

if TYPE_CHECKING:
from pathlib import Path

from pyspark.sql import DataFrame as PySparkDataFrame
from pyspark.sql import Column, DataFrame as PySparkDataFrame

from datacustomcode.io.reader.base import BaseDataCloudReader
from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
from datacustomcode.llm_gateway.spark_base import SparkLLMGateway
from datacustomcode.spark.base import BaseSparkSessionProvider


def _build_spark_llm_gateway() -> "SparkLLMGateway":
    """Create the :class:`SparkLLMGateway` described by the loaded SDK config.

    Reads the ``spark_llm_gateway_config`` entry from the module-level
    ``spark_llm_gateway_config`` object and asks that entry to build the
    gateway.

    Returns:
        Whatever the config entry's ``to_object()`` produces.

    Raises:
        RuntimeError: If the ``spark_llm_gateway_config`` entry is ``None``.
    """
    gateway_cfg = spark_llm_gateway_config.spark_llm_gateway_config
    if gateway_cfg is not None:
        return gateway_cfg.to_object()

    # No config entry was loaded: tell the user where to add it.
    message = (
        "spark_llm_gateway_config is not configured. Add a "
        "'spark_llm_gateway_config' section to config.yaml."
    )
    raise RuntimeError(message)


def llm_gateway_generate_text_col(
    template: str,
    values: Union[Dict[str, "Column"], "Column"],
    model_id: Optional[str] = None,
    max_tokens: Optional[int] = None,
) -> "Column":
    """Return a Spark Column that invokes the LLM Gateway for each row.

    This is a module-level convenience wrapper: it obtains a ``Client()``
    (constructed with no arguments), resolves that client's Spark LLM
    gateway, and forwards every argument to the gateway's method of the
    same name.

    Example:

        >>> df.withColumn(
        ...     "greeting__c",
        ...     llm_gateway_generate_text_col(
        ...         "In one sentence, greet {name} from {city}.",
        ...         {"name": col("name__c"), "city": col("homecity__c")},
        ...         model_id="sfdc_ai__DefaultGPT4Omni",
        ...         max_tokens=100,
        ...     ),
        ... )

    Args:
        template: Prompt template whose ``{field}`` placeholders match the
            keys of ``values``. Substitution uses ``str.format``.
        values: Either a mapping of placeholder name to Spark ``Column``, or
            a single ``Column`` whose value is already a struct.
        model_id: LLM model id. Defaults to ``sfdc_ai__DefaultGPT4Omni``.
        max_tokens: Maximum tokens to generate. Defaults to 200.

    Returns:
        A Spark ``Column`` that yields the generated text when evaluated.
    """
    client = Client()
    spark_gateway = client._get_spark_llm_gateway()
    generated = spark_gateway.llm_gateway_generate_text_col(
        template,
        values,
        model_id=model_id,
        max_tokens=max_tokens,
    )
    return generated


class DataCloudObjectType(Enum):
DLO = "dlo"
DMO = "dmo"
Expand Down Expand Up @@ -94,18 +150,23 @@ class Client:
finder: Find a file path
reader: A custom reader to use for reading Data Cloud objects.
writer: A custom writer to use for writing Data Cloud objects.
spark_llm_gateway: Optional custom :class:`SparkLLMGateway`. When
omitted, the gateway is lazily resolved from
``spark_llm_gateway_config``.

Example:
>>> client = Client()
>>> file_path = client.find_file_path("data.csv")
>>> dlo = client.read_dlo("my_dlo")
>>> client.write_to_dmo("my_dmo", dlo)
>>> answer = client.llm_gateway_generate_text("Generate a greeting message")
"""

_instance: ClassVar[Optional[Client]] = None
_reader: BaseDataCloudReader
_writer: BaseDataCloudWriter
_file: DefaultFindFilePath
_spark_llm_gateway: Optional[SparkLLMGateway]
_data_layer_history: dict[DataCloudObjectType, set[str]]
_code_type: str

Expand All @@ -114,11 +175,13 @@ def __new__(
reader: Optional[BaseDataCloudReader] = None,
writer: Optional[BaseDataCloudWriter] = None,
spark_provider: Optional[BaseSparkSessionProvider] = None,
spark_llm_gateway: Optional[SparkLLMGateway] = None,
code_type: str = "script",
) -> Client:

if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._spark_llm_gateway = spark_llm_gateway
# Initialize Readers and Writers from config
# and/or provided reader and writer
if reader is None or writer is None:
Expand Down Expand Up @@ -225,6 +288,44 @@ def find_file_path(self, file_name: str) -> Path:

return self._file.find_file_path(file_name) # type: ignore[no-any-return]

def llm_gateway_generate_text(
    self,
    prompt: str,
    model_id: Optional[str] = None,
    max_tokens: Optional[int] = None,
) -> str:
    """Send a single prompt to the LLM Gateway and return the reply.

    This is the scalar counterpart to
    :func:`llm_gateway_generate_text_col`: the gateway is called **once**,
    not once per row. Reach for the column helper when the prompt should be
    fanned out across every row of a DataFrame.

    Example:

        >>> response = Client().llm_gateway_generate_text(
        ...     "Generate a greeting message"
        ... )

    Args:
        prompt: The literal prompt to send. Plain text; no ``{field}``
            substitution is performed on this string.
        model_id: LLM model id to target. Defaults to
            ``sfdc_ai__DefaultGPT4Omni`` when ``None``.
        max_tokens: Hard upper bound on the number of tokens the model may
            generate. Defaults to 200 when ``None``.

    Returns:
        The generated text as a plain Python ``str``; empty when the
        gateway response carries no generated text.
    """
    # Resolve the gateway first, then delegate with keyword options intact.
    spark_gateway = self._get_spark_llm_gateway()
    generated_text = spark_gateway.llm_gateway_generate_text(
        prompt,
        model_id=model_id,
        max_tokens=max_tokens,
    )
    return generated_text

def _get_spark_llm_gateway(self) -> SparkLLMGateway:
    """Return this client's Spark LLM gateway, building it on first use.

    When ``self._spark_llm_gateway`` is ``None`` the gateway is created via
    ``_build_spark_llm_gateway()`` and stored on the instance, so later
    calls return the stored object.
    """
    gateway = self._spark_llm_gateway
    if gateway is None:
        gateway = _build_spark_llm_gateway()
        self._spark_llm_gateway = gateway
    return gateway

def _validate_data_layer_history_does_not_contain(
self, data_cloud_object_type: DataCloudObjectType
) -> None:
Expand Down
3 changes: 3 additions & 0 deletions src/datacustomcode/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ llm_gateway_config:
type_config_name: DefaultLLMGateway
options:
credentials_profile: default

spark_llm_gateway_config:
type_config_name: DefaultSparkLLMGateway
12 changes: 8 additions & 4 deletions src/datacustomcode/einstein_platform_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,23 @@

from typing import (
ClassVar,
Generic,
Optional,
Type,
cast,
TypeVar,
)

from datacustomcode.common_config import BaseObjectConfig

_T = TypeVar("_T")

class CredentialsObjectConfig(BaseObjectConfig):

class CredentialsObjectConfig(BaseObjectConfig, Generic[_T]):
type_to_create: ClassVar[Type]
credentials_profile: Optional[str] = None
sf_cli_org: Optional[str] = None

def to_object(self):
def to_object(self) -> _T:
"""Create an object instance, automatically including credentials in options"""

options = self.options.copy()
Expand All @@ -38,4 +41,5 @@ def to_object(self):
options["sf_cli_org"] = self.sf_cli_org

type_ = self.type_to_create.subclass_from_config_name(self.type_config_name)
return cast(type_, type_(**options))
instance: _T = type_(**options)
return instance
2 changes: 1 addition & 1 deletion src/datacustomcode/einstein_predictions_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
_E = TypeVar("_E", bound=EinsteinPredictions)


class EinsteinPredictionsObjectConfig(CredentialsObjectConfig, Generic[_E]):
class EinsteinPredictionsObjectConfig(CredentialsObjectConfig[_E], Generic[_E]):
type_to_create: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions # type: ignore[type-abstract]


Expand Down
1 change: 1 addition & 0 deletions src/datacustomcode/function/feature_types/chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""
Pydantic models for Search Index Chunking V1
"""

from enum import Enum
from typing import (
Dict,
Expand Down
4 changes: 4 additions & 0 deletions src/datacustomcode/llm_gateway/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@

from datacustomcode.llm_gateway.base import LLMGateway
from datacustomcode.llm_gateway.default import DefaultLLMGateway
from datacustomcode.llm_gateway.spark_base import SparkLLMGateway
from datacustomcode.llm_gateway.spark_default import DefaultSparkLLMGateway

__all__ = [
"DefaultLLMGateway",
"DefaultSparkLLMGateway",
"LLMGateway",
"SparkLLMGateway",
]
2 changes: 2 additions & 0 deletions src/datacustomcode/llm_gateway/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse:

payload: Dict[str, Any] = {"prompt": request.prompt}

if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.localization:
payload["localization"] = request.localization
if request.tags:
Expand Down
55 changes: 55 additions & 0 deletions src/datacustomcode/llm_gateway/spark_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Dict,
Optional,
Union,
)

from datacustomcode.mixin import UserExtendableNamedConfigMixin

if TYPE_CHECKING:
from pyspark.sql import Column


class SparkLLMGateway(ABC, UserExtendableNamedConfigMixin):
    """Abstract interface for LLM Gateway access from Spark code.

    Concrete subclasses provide a one-shot text generation call and a
    per-row variant that is expressed as a Spark ``Column``.
    """

    # Declared without a value here.
    # NOTE(review): presumably the name used to select a subclass from
    # config; confirm against UserExtendableNamedConfigMixin.
    CONFIG_NAME: str

    def __init__(self, **kwargs: Any) -> None:
        """Accept arbitrary keyword options; the base class ignores them."""

    @abstractmethod
    def llm_gateway_generate_text(
        self,
        prompt: str,
        model_id: Optional[str] = None,
        max_tokens: Optional[int] = None,
    ) -> str:
        """Make a single LLM Gateway call and return the generated text.

        Args:
            prompt: The prompt to send.
            model_id: Optional LLM model id.
            max_tokens: Optional cap on the number of generated tokens.

        Returns:
            The generated text.
        """

    @abstractmethod
    def llm_gateway_generate_text_col(
        self,
        template: str,
        values: Union[Dict[str, "Column"], "Column"],
        model_id: Optional[str] = None,
        max_tokens: Optional[int] = None,
    ) -> "Column":
        """Build a Spark ``Column`` that calls the LLM Gateway for each row.

        Args:
            template: The prompt template.
            values: A mapping of name to ``Column``, or a single ``Column``.
            model_id: Optional LLM model id.
            max_tokens: Optional cap on the number of generated tokens.

        Returns:
            A Spark ``Column`` for the generated text.
        """
Loading
Loading