diff --git a/README.md b/README.md index 49ea0b9..f5b27c6 100644 --- a/README.md +++ b/README.md @@ -229,12 +229,18 @@ Instead of using `datacustomcode configure`, you can also set credentials via en | `SFDC_REFRESH_TOKEN` | OAuth refresh token | | `SFDC_ACCESS_TOKEN` | (Optional) OAuth core/access token | +**Einstein Platform API Environment (Optional):** +| Variable | Description | +|----------|-------------| +| `SFDC_EINSTEIN_API_ENV` | Einstein Platform API environment: `dev`, `test`, `stage`, or `prod`. If not set, automatically inferred from login URL. Set this explicitly if auto-detection fails. | + Example usage: ```bash export SFDC_LOGIN_URL="https://login.salesforce.com" export SFDC_CLIENT_ID="your_client_id" export SFDC_CLIENT_SECRET="your_client_secret" export SFDC_REFRESH_TOKEN="your_refresh_token" +export SFDC_EINSTEIN_API_ENV="test" # optional datacustomcode run ./payload/entrypoint.py ``` diff --git a/src/datacustomcode/einstein_platform_client.py b/src/datacustomcode/einstein_platform_client.py index 761f80a..8ff546e 100644 --- a/src/datacustomcode/einstein_platform_client.py +++ b/src/datacustomcode/einstein_platform_client.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import ( Any, Dict, @@ -30,10 +31,6 @@ class EinsteinPlatformClient: - EINSTEIN_PLATFORM_MODELS_URL = ( - "https://api.salesforce.com/einstein/platform/v1/models" - ) - def __init__( self, credentials_profile: Optional[str] = None, @@ -48,8 +45,34 @@ def __init__( self._token_provider = CredentialsTokenProvider(profile) logger.debug(f"Using credentials token provider with profile: {profile}") self.token_response = None + self._einstein_url_cache: Optional[str] = None super().__init__(**kwargs) + def _get_einstein_platform_url(self) -> str: + if self._einstein_url_cache is not None: + return self._einstein_url_cache + + env = os.environ.get("SFDC_EINSTEIN_API_ENV", "prod").lower() + if env not in ("dev", "test", "stage", "prod"): + logger.warning( + f"Unknown SFDC_EINSTEIN_API_ENV value '{env}', defaulting to prod" + ) + env = "prod" + + base_url = self._get_base_url_for_env(env) + logger.info(f"Using Einstein Platform API endpoint: {base_url} (env={env})") + self._einstein_url_cache = f"{base_url}/einstein/platform/v1/models" + return self._einstein_url_cache + + def _get_base_url_for_env(self, env: str) -> str: + env_map = { + "dev": "https://dev.api.salesforce.com", + "test": "https://test.api.salesforce.com", + "stage": "https://stage.api.salesforce.com", + "prod": "https://api.salesforce.com", + } + return env_map.get(env, "https://api.salesforce.com") + def _get_headers(self): if self.token_response is None: self.token_response = self._token_provider.get_token() diff --git a/src/datacustomcode/einstein_predictions/impl/default.py b/src/datacustomcode/einstein_predictions/impl/default.py index 28e51f0..c95db71 100644 --- a/src/datacustomcode/einstein_predictions/impl/default.py +++ b/src/datacustomcode/einstein_predictions/impl/default.py @@ -48,7 +48,7 @@ def predict(self, request: PredictionRequest) -> PredictionResponse: ) api_url = ( - f"{self.EINSTEIN_PLATFORM_MODELS_URL}/{request.model_api_name}/{endpoint}" + f"{self._get_einstein_platform_url()}/{request.model_api_name}/{endpoint}" ) prediction_columns: List[Dict[str, Any]] = [] diff --git a/src/datacustomcode/llm_gateway/default.py b/src/datacustomcode/llm_gateway/default.py index 88374e3..54f5105 100644 --- a/src/datacustomcode/llm_gateway/default.py +++ b/src/datacustomcode/llm_gateway/default.py @@ -29,7 +29,7 @@ class DefaultLLMGateway(EinsteinPlatformClient, LLMGateway): def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse: api_url = ( - f"{self.EINSTEIN_PLATFORM_MODELS_URL}/{request.model_name}/generations" + f"{self._get_einstein_platform_url()}/{request.model_name}/generations" ) payload: Dict[str, Any] = {"prompt": request.prompt}