Skip to content

Commit 213e8ff

Browse files
Transurgeonclaude
andcommitted
Bindings for diag_mat, upper_tri, and vstack atoms (#7)
* Bindings for diag_mat, kron_left, upper_tri, and vstack atoms Add Python C extension bindings for four new SparseDiffEngine atoms: - diag_mat: extract diagonal from square matrix - upper_tri: extract strict upper triangular elements - kron_left: Kronecker product kron(C, X) with constant sparse C - vstack: vertical stack of expressions (via transpose-hstack composition) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Update SparseDiffEngine submodule with main merge Merges SparseDiffEngine main into adds-more-affine-atoms branch to pick up folder restructuring and parameter support while preserving diag_mat, upper_tri, and vstack atoms. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * pin diffengine to tag --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 966b397 commit 213e8ff

File tree

5 files changed

+123
-1
lines changed

5 files changed

+123
-1
lines changed

SparseDiffEngine

Submodule SparseDiffEngine updated 98 files
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#ifndef ATOM_DIAG_MAT_H
2+
#define ATOM_DIAG_MAT_H
3+
4+
#include "common.h"
5+
6+
static PyObject *py_make_diag_mat(PyObject *self, PyObject *args)
7+
{
8+
PyObject *child_capsule;
9+
10+
if (!PyArg_ParseTuple(args, "O", &child_capsule))
11+
{
12+
return NULL;
13+
}
14+
15+
expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
16+
if (!child)
17+
{
18+
return NULL;
19+
}
20+
21+
expr *node = new_diag_mat(child);
22+
if (!node)
23+
{
24+
PyErr_SetString(PyExc_RuntimeError, "failed to create diag_mat node");
25+
return NULL;
26+
}
27+
28+
expr_retain(node);
29+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
30+
}
31+
32+
#endif /* ATOM_DIAG_MAT_H */
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#ifndef ATOM_UPPER_TRI_H
2+
#define ATOM_UPPER_TRI_H
3+
4+
#include "common.h"
5+
6+
static PyObject *py_make_upper_tri(PyObject *self, PyObject *args)
7+
{
8+
PyObject *child_capsule;
9+
10+
if (!PyArg_ParseTuple(args, "O", &child_capsule))
11+
{
12+
return NULL;
13+
}
14+
15+
expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
16+
if (!child)
17+
{
18+
return NULL;
19+
}
20+
21+
expr *node = new_upper_tri(child);
22+
if (!node)
23+
{
24+
PyErr_SetString(PyExc_RuntimeError, "failed to create upper_tri node");
25+
return NULL;
26+
}
27+
28+
expr_retain(node);
29+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
30+
}
31+
32+
#endif /* ATOM_UPPER_TRI_H */
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#ifndef ATOM_VSTACK_H
2+
#define ATOM_VSTACK_H
3+
4+
#include "common.h"
5+
6+
static PyObject *py_make_vstack(PyObject *self, PyObject *args)
7+
{
8+
(void) self;
9+
PyObject *list_obj;
10+
if (!PyArg_ParseTuple(args, "O", &list_obj))
11+
{
12+
return NULL;
13+
}
14+
if (!PyList_Check(list_obj))
15+
{
16+
PyErr_SetString(PyExc_TypeError,
17+
"First argument must be a list of expr capsules");
18+
return NULL;
19+
}
20+
Py_ssize_t n_args = PyList_Size(list_obj);
21+
if (n_args == 0)
22+
{
23+
PyErr_SetString(PyExc_ValueError, "List of expr capsules cannot be empty");
24+
return NULL;
25+
}
26+
expr **expr_args = (expr **) calloc(n_args, sizeof(expr *));
27+
for (Py_ssize_t i = 0; i < n_args; ++i)
28+
{
29+
PyObject *item = PyList_GetItem(list_obj, i);
30+
expr *e = (expr *) PyCapsule_GetPointer(item, EXPR_CAPSULE_NAME);
31+
if (!e)
32+
{
33+
free(expr_args);
34+
PyErr_SetString(PyExc_ValueError, "Invalid expr capsule in list");
35+
return NULL;
36+
}
37+
expr_args[i] = e;
38+
}
39+
int n_vars = expr_args[0]->n_vars;
40+
expr *node = new_vstack(expr_args, (int) n_args, n_vars);
41+
free(expr_args);
42+
if (!node)
43+
{
44+
PyErr_SetString(PyExc_RuntimeError, "failed to create vstack node");
45+
return NULL;
46+
}
47+
expr_retain(node);
48+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
49+
}
50+
51+
#endif // ATOM_VSTACK_H

sparsediffpy/_bindings/bindings.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "atoms/atanh.h"
99
#include "atoms/broadcast.h"
1010
#include "atoms/cos.h"
11+
#include "atoms/diag_mat.h"
1112
#include "atoms/diag_vec.h"
1213
#include "atoms/entr.h"
1314
#include "atoms/exp.h"
@@ -40,8 +41,10 @@
4041
#include "atoms/tanh.h"
4142
#include "atoms/trace.h"
4243
#include "atoms/transpose.h"
44+
#include "atoms/upper_tri.h"
4345
#include "atoms/variable.h"
4446
#include "atoms/vector_mult.h"
47+
#include "atoms/vstack.h"
4548
#include "atoms/xexp.h"
4649

4750
/* Include problem bindings */
@@ -80,6 +83,8 @@ static PyMethodDef DNLPMethods[] = {
8083
{"make_hstack", py_make_hstack, METH_VARARGS,
8184
"Create hstack node from list of expr capsules and n_vars (make_hstack([e1, "
8285
"e2, ...], n_vars))"},
86+
{"make_vstack", py_make_vstack, METH_VARARGS,
87+
"Create vstack node from list of expr capsules (make_vstack([e1, e2, ...]))"},
8388
{"make_sum", py_make_sum, METH_VARARGS, "Create sum node"},
8489
{"make_neg", py_make_neg, METH_VARARGS, "Create neg node"},
8590
{"make_normal_cdf", py_make_normal_cdf, METH_VARARGS, "Create normal_cdf node"},
@@ -100,12 +105,14 @@ static PyMethodDef DNLPMethods[] = {
100105
"Create prod_axis_one node"},
101106
{"make_sin", py_make_sin, METH_VARARGS, "Create sin node"},
102107
{"make_cos", py_make_cos, METH_VARARGS, "Create cos node"},
108+
{"make_diag_mat", py_make_diag_mat, METH_VARARGS, "Create diag_mat node"},
103109
{"make_diag_vec", py_make_diag_vec, METH_VARARGS, "Create diag_vec node"},
104110
{"make_tan", py_make_tan, METH_VARARGS, "Create tan node"},
105111
{"make_sinh", py_make_sinh, METH_VARARGS, "Create sinh node"},
106112
{"make_tanh", py_make_tanh, METH_VARARGS, "Create tanh node"},
107113
{"make_asinh", py_make_asinh, METH_VARARGS, "Create asinh node"},
108114
{"make_atanh", py_make_atanh, METH_VARARGS, "Create atanh node"},
115+
{"make_upper_tri", py_make_upper_tri, METH_VARARGS, "Create upper_tri node"},
109116
{"make_broadcast", py_make_broadcast, METH_VARARGS, "Create broadcast node"},
110117
{"make_entr", py_make_entr, METH_VARARGS, "Create entr node"},
111118
{"make_logistic", py_make_logistic, METH_VARARGS, "Create logistic node"},

0 commit comments

Comments
 (0)