diff --git a/__init__.py b/__init__.py index bdce8c8..09a926f 100644 --- a/__init__.py +++ b/__init__.py @@ -11,6 +11,7 @@ ) from .gemma_api_conditioning import GemmaAPITextEncode from .gemma_encoder import LTXVGemmaCLIPModelLoader, LTXVGemmaEnhancePrompt +from .minimax_prompt_enhancer import MiniMaxPromptEnhancer from .guide import LTXVAddGuideAdvanced, LTXVAddGuideAdvancedAttention from .guiders import GuiderParametersNode, MultimodalGuiderNode from .iclora import ( @@ -99,6 +100,7 @@ "LTXVGemmaCLIPModelLoader": LTXVGemmaCLIPModelLoader, "LTXVGemmaEnhancePrompt": LTXVGemmaEnhancePrompt, "GemmaAPITextEncode": GemmaAPITextEncode, + "MiniMaxPromptEnhancer": MiniMaxPromptEnhancer, "DynamicConditioning": DynamicConditioning, "LowVRAMCheckpointLoader": LowVRAMCheckpointLoader, "LowVRAMAudioVAELoader": LowVRAMAudioVAELoader, diff --git a/minimax_prompt_enhancer.py b/minimax_prompt_enhancer.py new file mode 100644 index 0000000..c3698b7 --- /dev/null +++ b/minimax_prompt_enhancer.py @@ -0,0 +1,182 @@ +""" +API-based prompt enhancement using MiniMax. + +Provides an API-based alternative to the local LLM prompt enhancer, allowing +users to enhance their video generation prompts via MiniMax without requiring +local GPU resources. +""" + +import logging +import re + +import requests + +from .nodes_registry import comfy_node +from .prompt_enhancer_utils import I2V_CINEMATIC_PROMPT, T2V_CINEMATIC_PROMPT + +logger = logging.getLogger(__name__) + +MINIMAX_API_BASE_URL = "https://api.minimax.io/v1" +MINIMAX_MODELS = ["MiniMax-M2.7", "MiniMax-M2.7-highspeed"] + +_THINK_TAG_RE = re.compile(r".*?", re.DOTALL) + + +def enhance_prompt_via_minimax( + prompt: str, + system_prompt: str, + api_key: str, + model: str = "MiniMax-M2.7", + base_url: str = MINIMAX_API_BASE_URL, + max_tokens: int = 512, +) -> str: + """Call MiniMax chat completions API to enhance a video generation prompt.""" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + } + payload = { + "model": model, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + "max_tokens": max_tokens, + "temperature": 1.0, + } + response = requests.post( + f"{base_url}/chat/completions", + json=payload, + headers=headers, + timeout=60, + ) + if response.status_code == 401: + raise RuntimeError( + "Invalid API key. Please provide a valid MINIMAX_API_KEY." + ) + if response.status_code != 200: + raise RuntimeError( + f"MiniMax API request failed with status {response.status_code}: {response.text}" + ) + data = response.json() + content = data["choices"][0]["message"]["content"] + # Strip reasoning tokens emitted by MiniMax-M2.7 + content = _THINK_TAG_RE.sub("", content).strip() + return content + + +@comfy_node(name="MiniMaxPromptEnhancer") +class MiniMaxPromptEnhancer: + """ + Enhances video generation prompts using the MiniMax API. + + An API-based alternative to the local LLM prompt enhancer. No local GPU + resources are required — prompts are sent to the MiniMax chat completions + endpoint and returned as cinematic descriptions suitable for LTX-Video. + + Inputs: + - api_key: MiniMax API key (MINIMAX_API_KEY). Get one at https://platform.minimax.io/ + - prompt: Text prompt to enhance + - model: MiniMax model to use for enhancement + - mode: Enhancement mode — T2V (text-to-video) or I2V (image-to-video) + - max_tokens: Maximum number of tokens in the enhanced prompt + + Returns: + - enhanced_prompt: Cinematically enhanced prompt string + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "api_key": ( + "STRING", + { + "default": "", + "placeholder": "MINIMAX_API_KEY", + "multiline": False, + "tooltip": "MiniMax API key. Get one at https://platform.minimax.io/", + }, + ), + "prompt": ( + "STRING", + { + "multiline": True, + "default": "", + "tooltip": "Text prompt to enhance for video generation", + }, + ), + "model": ( + MINIMAX_MODELS, + { + "default": "MiniMax-M2.7", + "tooltip": "MiniMax model to use for prompt enhancement", + }, + ), + "mode": ( + ["T2V", "I2V"], + { + "default": "T2V", + "tooltip": ( + "Enhancement mode: T2V for text-to-video, " + "I2V for image-to-video" + ), + }, + ), + "max_tokens": ( + "INT", + { + "default": 512, + "min": 64, + "max": 1024, + "step": 64, + "tooltip": "Maximum number of tokens in the enhanced prompt", + }, + ), + }, + } + + RETURN_TYPES = ("STRING",) + RETURN_NAMES = ("enhanced_prompt",) + FUNCTION = "enhance" + CATEGORY = "api node/text/Lightricks" + TITLE = "MiniMax Prompt Enhancer" + OUTPUT_NODE = True + DESCRIPTION = ( + "Enhances video generation prompts using the MiniMax API. " + "An API-based alternative to the local LLM enhancer — no local GPU required." + ) + + def enhance( + self, + api_key: str, + prompt: str, + model: str = "MiniMax-M2.7", + mode: str = "T2V", + max_tokens: int = 512, + ) -> tuple: + if not api_key.strip(): + raise ValueError( + "MiniMax API key is required. " + "Get one at https://platform.minimax.io/" + ) + if not prompt.strip(): + raise ValueError("Prompt cannot be empty") + + system_prompt = T2V_CINEMATIC_PROMPT if mode == "T2V" else I2V_CINEMATIC_PROMPT + + logger.info( + "Enhancing prompt via MiniMax API (model=%s, mode=%s): %s...", + model, + mode, + prompt[:60], + ) + enhanced = enhance_prompt_via_minimax( + prompt=prompt, + system_prompt=system_prompt, + api_key=api_key, + model=model, + max_tokens=max_tokens, + ) + logger.info("Enhanced prompt: %s...", enhanced[:60]) + return (enhanced,) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..876a527 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = ["-p", "no:flaky", "--import-mode=importlib"] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d422070 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,93 @@ +""" +Pytest configuration: sets up ComfyUI stubs so nodes can be imported +without a running ComfyUI environment, and handles the hyphenated package +name by creating a proper package alias. +""" + +import importlib.util +import os +import sys +import types + +REPO_ROOT = os.path.dirname(os.path.dirname(__file__)) +PKG_NAME = "ComfyUI_LTXVideo" # Python-safe alias for the repo folder + +# --------------------------------------------------------------------------- +# Stub heavy ComfyUI / ML dependencies +# --------------------------------------------------------------------------- + +comfy_pkg = types.ModuleType("comfy") +comfy_mm = types.ModuleType("comfy.model_management") +comfy_mp = types.ModuleType("comfy.model_patcher") +comfy_pkg.model_management = comfy_mm +comfy_pkg.model_patcher = comfy_mp +sys.modules.setdefault("comfy", comfy_pkg) +sys.modules.setdefault("comfy.model_management", comfy_mm) +sys.modules.setdefault("comfy.model_patcher", comfy_mp) + +fp = types.ModuleType("folder_paths") +fp.models_dir = "/tmp/models" +fp.get_filename_list = lambda _: [] +fp.get_full_path_or_raise = lambda *a: "/tmp/fake" +sys.modules.setdefault("folder_paths", fp) + +torch_stub = types.ModuleType("torch") +torch_stub.Tensor = object # minimal stub +sys.modules.setdefault("torch", torch_stub) + +PIL_stub = types.ModuleType("PIL") +PIL_Image_stub = types.ModuleType("PIL.Image") +PIL_Image_stub.Image = object # minimal stub +PIL_stub.Image = PIL_Image_stub +sys.modules.setdefault("PIL", PIL_stub) +sys.modules.setdefault("PIL.Image", PIL_Image_stub) + +for _mod in [ + "transformers", + "safetensors", + "safetensors.torch", + "einops", + "kornia", + "kornia.geometry", + "huggingface_hub", +]: + sys.modules.setdefault(_mod, types.ModuleType(_mod)) + +# --------------------------------------------------------------------------- +# Register the repo root as a Python package so relative imports work +# --------------------------------------------------------------------------- + +pkg = types.ModuleType(PKG_NAME) +# Explicitly do NOT set __path__ or __file__ to prevent pytest from +# traversing the repo root and trying to import its __init__.py +pkg.__package__ = PKG_NAME +pkg.__spec__ = None +sys.modules[PKG_NAME] = pkg + +# With --import-mode=importlib, pytest falls back to module_name_from_path() +# which produces "__init__" as the module name for the repo's __init__.py. +# Pre-register it so pytest finds it in sys.modules and skips exec'ing it. +_init_stub = types.ModuleType("__init__") +_init_stub.__package__ = PKG_NAME +sys.modules["__init__"] = _init_stub + + +def _load_submodule(name: str) -> types.ModuleType: + """Load a .py file from REPO_ROOT as PKG_NAME.name.""" + full_name = f"{PKG_NAME}.{name}" + if full_name in sys.modules: + return sys.modules[full_name] + path = os.path.join(REPO_ROOT, f"{name}.py") + spec = importlib.util.spec_from_file_location(full_name, path) + mod = importlib.util.module_from_spec(spec) + mod.__package__ = PKG_NAME + sys.modules[full_name] = mod + spec.loader.exec_module(mod) + setattr(pkg, name, mod) + return mod + + +# Pre-load the modules that tests depend on +_load_submodule("nodes_registry") +_load_submodule("prompt_enhancer_utils") +_load_submodule("minimax_prompt_enhancer") diff --git a/tests/test_minimax_prompt_enhancer.py b/tests/test_minimax_prompt_enhancer.py new file mode 100644 index 0000000..61e989d --- /dev/null +++ b/tests/test_minimax_prompt_enhancer.py @@ -0,0 +1,231 @@ +"""Unit tests for MiniMaxPromptEnhancer node.""" + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +# conftest.py pre-loads all modules with proper package context. +# Retrieve them from sys.modules directly. +PKG_NAME = "ComfyUI_LTXVideo" +minimax_mod = sys.modules[f"{PKG_NAME}.minimax_prompt_enhancer"] + +MINIMAX_API_BASE_URL = minimax_mod.MINIMAX_API_BASE_URL +MINIMAX_MODELS = minimax_mod.MINIMAX_MODELS +MiniMaxPromptEnhancer = minimax_mod.MiniMaxPromptEnhancer +enhance_prompt_via_minimax = minimax_mod.enhance_prompt_via_minimax + + +# --------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + + +def _make_response(content: str, status_code: int = 200): + mock = MagicMock() + mock.status_code = status_code + mock.json.return_value = {"choices": [{"message": {"content": content}}]} + mock.text = content + return mock + + +# --------------------------------------------------------------------------- +# Tests for the helper function +# --------------------------------------------------------------------------- + + +class TestEnhancePromptViaMinimax: + def test_returns_enhanced_content(self): + expected = "A cinematic shot of a car driving." + with patch(f"{minimax_mod.__name__}.requests.post") as mock_post: + mock_post.return_value = _make_response(expected) + result = enhance_prompt_via_minimax( + prompt="A car drives.", + system_prompt="Enhance cinematically.", + api_key="test-key", + ) + assert result == expected + + def test_uses_correct_default_base_url(self): + with patch(f"{minimax_mod.__name__}.requests.post") as mock_post: + mock_post.return_value = _make_response("enhanced") + enhance_prompt_via_minimax( + prompt="test", + system_prompt="sys", + api_key="key", + ) + call_url = mock_post.call_args[0][0] + assert call_url.startswith(MINIMAX_API_BASE_URL) + + def test_uses_custom_base_url(self): + custom_url = "https://custom.minimax.io/v1" + with patch(f"{minimax_mod.__name__}.requests.post") as mock_post: + mock_post.return_value = _make_response("enhanced") + enhance_prompt_via_minimax( + prompt="test", + system_prompt="sys", + api_key="key", + base_url=custom_url, + ) + call_url = mock_post.call_args[0][0] + assert call_url.startswith(custom_url) + + def test_sends_api_key_in_header(self): + api_key = "my-secret-key" + with patch(f"{minimax_mod.__name__}.requests.post") as mock_post: + mock_post.return_value = _make_response("enhanced") + enhance_prompt_via_minimax( + prompt="test", + system_prompt="sys", + api_key=api_key, + ) + headers = mock_post.call_args[1]["headers"] + assert headers["Authorization"] == f"Bearer {api_key}" + + def test_temperature_is_not_zero(self): + """MiniMax requires temperature in (0.0, 1.0].""" + with patch(f"{minimax_mod.__name__}.requests.post") as mock_post: + mock_post.return_value = _make_response("enhanced") + enhance_prompt_via_minimax( + prompt="test", + system_prompt="sys", + api_key="key", + ) + payload = mock_post.call_args[1]["json"] + assert payload["temperature"] > 0.0 + assert payload["temperature"] <= 1.0 + + def test_raises_on_401(self): + with patch(f"{minimax_mod.__name__}.requests.post") as mock_post: + mock_post.return_value = _make_response("Unauthorized", 401) + with pytest.raises(RuntimeError, match="Invalid API key"): + enhance_prompt_via_minimax( + prompt="test", + system_prompt="sys", + api_key="bad-key", + ) + + def test_raises_on_non_200(self): + with patch(f"{minimax_mod.__name__}.requests.post") as mock_post: + mock_post.return_value = _make_response("Server Error", 500) + with pytest.raises(RuntimeError, match="500"): + enhance_prompt_via_minimax( + prompt="test", + system_prompt="sys", + api_key="key", + ) + + def test_default_model_is_minimax_m2_7(self): + with patch(f"{minimax_mod.__name__}.requests.post") as mock_post: + mock_post.return_value = _make_response("enhanced") + enhance_prompt_via_minimax( + prompt="test", + system_prompt="sys", + api_key="key", + ) + payload = mock_post.call_args[1]["json"] + assert payload["model"] == "MiniMax-M2.7" + + def test_strips_whitespace_from_response(self): + padded = " enhanced prompt with spaces " + with patch(f"{minimax_mod.__name__}.requests.post") as mock_post: + mock_post.return_value = _make_response(padded) + result = enhance_prompt_via_minimax( + prompt="test", + system_prompt="sys", + api_key="key", + ) + assert result == padded.strip() + + def test_strips_think_tags(self): + """MiniMax-M2.7 returns reasoning blocks; strip them.""" + with_think = "\nsome reasoning\n\n\nActual enhanced prompt." + with patch(f"{minimax_mod.__name__}.requests.post") as mock_post: + mock_post.return_value = _make_response(with_think) + result = enhance_prompt_via_minimax( + prompt="test", + system_prompt="sys", + api_key="key", + ) + assert result == "Actual enhanced prompt." + assert "" not in result + + +# --------------------------------------------------------------------------- +# Tests for the ComfyUI node class +# --------------------------------------------------------------------------- + + +class TestMiniMaxPromptEnhancerNode: + def test_node_metadata(self): + assert MiniMaxPromptEnhancer.RETURN_TYPES == ("STRING",) + assert MiniMaxPromptEnhancer.RETURN_NAMES == ("enhanced_prompt",) + assert MiniMaxPromptEnhancer.FUNCTION == "enhance" + assert MiniMaxPromptEnhancer.CATEGORY == "api node/text/Lightricks" + + def test_model_choices_include_m2_7(self): + input_types = MiniMaxPromptEnhancer.INPUT_TYPES() + model_choices = input_types["required"]["model"][0] + assert "MiniMax-M2.7" in model_choices + assert "MiniMax-M2.7-highspeed" in model_choices + + def test_model_list_matches_constant(self): + input_types = MiniMaxPromptEnhancer.INPUT_TYPES() + model_choices = input_types["required"]["model"][0] + assert model_choices == MINIMAX_MODELS + + def test_enhance_returns_tuple(self): + node = MiniMaxPromptEnhancer() + with patch(f"{minimax_mod.__name__}.enhance_prompt_via_minimax") as mock_fn: + mock_fn.return_value = "A beautiful scene." + result = node.enhance( + api_key="test-key", + prompt="A scene.", + model="MiniMax-M2.7", + mode="T2V", + max_tokens=512, + ) + assert isinstance(result, tuple) + assert result[0] == "A beautiful scene." + + def test_enhance_raises_on_empty_api_key(self): + node = MiniMaxPromptEnhancer() + with pytest.raises(ValueError, match="API key"): + node.enhance(api_key="", prompt="test") + + def test_enhance_raises_on_whitespace_api_key(self): + node = MiniMaxPromptEnhancer() + with pytest.raises(ValueError, match="API key"): + node.enhance(api_key=" ", prompt="test") + + def test_enhance_raises_on_empty_prompt(self): + node = MiniMaxPromptEnhancer() + with pytest.raises(ValueError, match="Prompt"): + node.enhance(api_key="key", prompt="") + + def test_t2v_mode_uses_t2v_system_prompt(self): + node = MiniMaxPromptEnhancer() + with patch(f"{minimax_mod.__name__}.enhance_prompt_via_minimax") as mock_fn: + mock_fn.return_value = "enhanced" + node.enhance(api_key="key", prompt="test", mode="T2V") + sys_prompt = mock_fn.call_args[1]["system_prompt"] + assert "cinematic director" in sys_prompt.lower() + + def test_i2v_mode_uses_i2v_system_prompt(self): + node = MiniMaxPromptEnhancer() + with patch(f"{minimax_mod.__name__}.enhance_prompt_via_minimax") as mock_fn: + mock_fn.return_value = "enhanced" + node.enhance(api_key="key", prompt="test", mode="I2V") + sys_prompt = mock_fn.call_args[1]["system_prompt"] + assert "cinematic director" in sys_prompt.lower() + + def test_highspeed_model_forwarded_to_api(self): + node = MiniMaxPromptEnhancer() + with patch(f"{minimax_mod.__name__}.enhance_prompt_via_minimax") as mock_fn: + mock_fn.return_value = "enhanced" + node.enhance( + api_key="key", + prompt="test", + model="MiniMax-M2.7-highspeed", + ) + assert mock_fn.call_args[1]["model"] == "MiniMax-M2.7-highspeed"