From 3d40cb79b3defacfdc013e638fcb0f3c38cd9841 Mon Sep 17 00:00:00 2001 From: David Choi Date: Tue, 25 Nov 2025 16:45:04 -0800 Subject: [PATCH 01/50] Clean up e2e, prepare more test cases, Triton model repository is now created by endly. --- .gitignore | 1 + .../sli_1/1/model.savedmodel/saved_model.pb | Bin 82272 -> 0 bytes .../variables/variables.data-00000-of-00001 | Bin 1445 -> 0 bytes .../variables/variables.index | Bin 319 -> 0 bytes .../sli_1/config.pbtxt | 29 --------- .../sli_2/1/model.savedmodel/saved_model.pb | Bin 82272 -> 0 bytes .../variables/variables.data-00000-of-00001 | Bin 1445 -> 0 bytes .../variables/variables.index | Bin 319 -> 0 bytes .../sli_2/config.pbtxt | 27 -------- .../expect-slft_batch-maker.json | 5 ++ .../001_e2e_client/expect-slft_batch.json | 10 +++ .../cases/001_e2e_client/expect-sli.json | 7 +++ .../regression/cases/001_e2e_client/test.yaml | 40 ++---------- .../regression/cases/002_sls_client/test.yaml | 8 +-- .../regression/cases/003_vec_client/test.yaml | 14 ++--- .../expect-batch-cache.json | 2 +- .../004_slf_transform_batch/expect-batch.json | 12 ++-- .../cases/004_slf_transform_batch/expect.json | 2 +- .../cases/004_slf_transform_batch/test.yaml | 11 +--- .../cases/005_lookup_transform/test.yaml | 14 ++--- .../regression/cases/006_metrics/test.yaml | 8 +-- .../regression/cases/007_health/metrics.json | 3 - .../e2e/regression/cases/007_health/test.yaml | 10 +-- .../expect.json | 2 +- .../regression/cases/010_triton/request.json | 6 ++ .../e2e/regression/cases/010_triton/test.yaml | 22 +++++++ .../regression/cases/010_triton_sli/test.yaml | 13 ---- .../cases/011_aux_cache/expect-cache.json | 6 ++ .../cases/011_aux_cache/expect.json | 6 ++ .../regression/cases/011_aux_cache/test.yaml | 23 +++++++ .../e2e/regression/cases/012_router/test.yaml | 23 +++++++ example/e2e/regression/regression.yaml | 2 +- example/e2e/regression/reset/fli.json | 1 + example/e2e/regression/reset/flinc.json | 1 + example/e2e/run.yaml | 13 ++-- 
example/e2e/system.yaml | 59 +++++++++++++++--- example/server/etc/config.yaml | 51 ++++++++++++++- 37 files changed, 257 insertions(+), 174 deletions(-) delete mode 100644 example/e2e/data/triton_model_repository/sli_1/1/model.savedmodel/saved_model.pb delete mode 100644 example/e2e/data/triton_model_repository/sli_1/1/model.savedmodel/variables/variables.data-00000-of-00001 delete mode 100644 example/e2e/data/triton_model_repository/sli_1/1/model.savedmodel/variables/variables.index delete mode 100644 example/e2e/data/triton_model_repository/sli_1/config.pbtxt delete mode 100644 example/e2e/data/triton_model_repository/sli_2/1/model.savedmodel/saved_model.pb delete mode 100644 example/e2e/data/triton_model_repository/sli_2/1/model.savedmodel/variables/variables.data-00000-of-00001 delete mode 100644 example/e2e/data/triton_model_repository/sli_2/1/model.savedmodel/variables/variables.index delete mode 100644 example/e2e/data/triton_model_repository/sli_2/config.pbtxt create mode 100644 example/e2e/regression/cases/001_e2e_client/expect-slft_batch-maker.json create mode 100644 example/e2e/regression/cases/001_e2e_client/expect-slft_batch.json create mode 100644 example/e2e/regression/cases/001_e2e_client/expect-sli.json delete mode 100644 example/e2e/regression/cases/007_health/metrics.json rename example/e2e/regression/cases/{010_triton_sli => 010_triton}/expect.json (94%) create mode 100644 example/e2e/regression/cases/010_triton/request.json create mode 100644 example/e2e/regression/cases/010_triton/test.yaml delete mode 100644 example/e2e/regression/cases/010_triton_sli/test.yaml create mode 100644 example/e2e/regression/cases/011_aux_cache/expect-cache.json create mode 100644 example/e2e/regression/cases/011_aux_cache/expect.json create mode 100644 example/e2e/regression/cases/011_aux_cache/test.yaml create mode 100644 example/e2e/regression/cases/012_router/test.yaml create mode 100644 example/e2e/regression/reset/fli.json create mode 100644 
example/e2e/regression/reset/flinc.json diff --git a/.gitignore b/.gitignore index 9eea4ed..8408ccc 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ /vendor /example/e2e/logs +/example/e2e/data/triton_model_repository # binaries /example/client/mlyc/mlyc diff --git a/example/e2e/data/triton_model_repository/sli_1/1/model.savedmodel/saved_model.pb b/example/e2e/data/triton_model_repository/sli_1/1/model.savedmodel/saved_model.pb deleted file mode 100644 index cb4a79b07f3325adbee16c00df05346ed42b1193..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 82272 zcmeHw3v3+ec^GGx$)k5 z6hR-TTf5)?&tw0Y{bxA4JJfT1hL$t)KfnL`zyEvxNs0X6(>?I+2>Fl0kcarGLV@&x zbR#WYp}#lrSH|DB@i)B}@MU-H=7ydhhC!u&a8!zQ_x8k7Ns@r8kX)#gm$lkU3K@jn zYg%bjH?USlxZ*wD&_ zr;F+br|7#J6$XnM13I)%)a7Rie<731`7I$wppsFtRn;sE}C3lS@xVpxY5!KDR3>> z%X-I4u4$fK_nW4>>*o5K8f?m>Vx>_j})IzapknW{Qr6k7=OB9B`DPcdGY;~U1jH_(j zUQ*1G18UZ1AWF-t*RH)vbL_fK0OS(HVs`ZR|@zK zepG2etmh1&D?I`!!~m7yxdrY~=wnO6zobLOeR834aRZ)*?n~P8ns+z4k&@qk62Je3 z%stEYrW~8-f^SGLbV;isjw9go!h%*R(RI_RtA;|6fZ~R^uN4zFMorv+Qqia#ru*}; z!%}+8!V&jd++-`Hdu6k{j0JiTek;XQ7ldT!H{arlRcaO2x4kf+-_VygYn7@>yWu~= z0`y%fY0LUmTvol>NHW|M0(!kKCP)O-S&I5`wZ?De!xn?12$n%*|znN4Teo z%Si786mAz?Iei0h3TcjwQ(fr;4ynuUf>C8eip9HodXuTX{(->}GP-*_fC5$^wTMVu zsUp_nn0ih8M(n2O$S5B}BZ}{1Fw7>TB1KV+5*_430_jM?EN3Xa45bY0&H9VDmL|)v zzXZLPwQDW*Mw;#IN$d^x;&P1)u+3CpC^rH0K%$2R{nhc;jmdRF4bQ^RW!ympy>Pl% z(%E+FhvbrGu)h&1DYnej!bBW*2b4G0*^Vn16LA>}u2_{O67W3<#xK{jn!d7Gy62Fe z5`(8;jE=v!yr{40H%@PsO3!Px<*O6|?C?WmKm;=P?F1m1No;7fnqDp2y9h1-TxL}4 zIV7#PBE760ViBI~T-}80D3>T5H1rLvsv(-%Z8197W9ZdudZtG?HgjxR8NWtC z<>=}OWoc<2#9iKHw>!gjOem% z=(z15mhZ|8@ia7+O=R}ENpe2{^Y~TfPR?dV+~hTuO_`%9-Tu+c&WT2I#^QsB$qxBW zyU%8Va0(e>vq-7=L^4IX77<6zLV__=YOd|3{W^qq13D))7yTyEAJZ z^_(C3$X-Z5eBsiI!xH`D4F1uj#GdOqDScY%%B1Oev~O`Q>`}3m6}_sLvAvI#wDqNe z_Jle;JDY|7BEg7YL0l-R%Q))|l>uDd<&~OPXPjN8uoOxKvrrKNU3G%wz*u z0Z3B*fbj@ywgI#uqS*%Y(EMzJ8VjE7=lg7$p;$B{Jup?!uOWYt$4!;b`tDVZ8YMD+ 
zA7wC8V*7oV|0X08q-*}`06&$WFDa2kNxP|6k7n^9WhhF<(79SG!!*P5-8P^=n3D|>#b5~Vg-7F!O z{h^0Rl#G6{V6* zBwx^l|7Bq`IY@>K%6;Irup5Hp5a~4!KWucEZ-v$v5!E1EjZwbFBF<&Ze=fWH=Ca#8 zmvO$$Asp6(_r)Im1?^S#?{%+VuaoXx_wj8~h;_l&DdIdt66-p0EXxLrB=?g66Ibk4 zdy@yqco{`M(JJjvK1}wS<;Cj@$u#L@1&k=rfR6bjISxIH2FgdxuH9%oWiIyAu~S#B zES{Zz@yauor&w_|<85l^f@?c4JIKTAkw%z38U(XL+rjK39L(BAH*=+(ksszq{#fJ4 zGeIMNe7lh!;YQvz%Gx7G1r|~_DT_ZbBIWQWP9z0?x`{N6KRra6!Jl3t&EijjNQfXv zf20#^#qCz?-%TKtqegXsYHMr^jqp-$<|7R!7mA$u5z?3Znf*VB#F7#{K}-z;JufG1 zSy)sQMBSo7fD^WNq}wmX{RJ{VNlg|2Nbs%{%uL9Z)^XAZI5uUQlSYP}X~vL6PRhw!hM0UzItJAr;EY)OMx}_9V!xb%y z%0xHEO(3YO}BBV{#CX{C6CwN(J|yr9TUnr z398GbGG$nrt@QJq$*Y|C$T1|zqxJ6E=?YIE+cuTnSBEMYvESAkPs4zzG_mz&BQTU7 zaIp#bbIC4p1`alG6pcmU0F0?<*Qlz-hQ3@}DcYuwB(We0--!v>ZlL2R3FvN$yMpRP>kVVG>GL(kWWpW8961MPT5rqd z)0pTxkzxb$QjK;7)^b~(fw^}fs*|^*Iq$5eE(jsWYj92@1Yc5UUC+Qm*Ok8hxC!Dw<=3{)JJKhs}-X*Rn<4pZlmS1DE1j^_~X*E*veF7*ugrU zo1#P?Z3D3KdvsY^T~zV1cv?fFXbT(EA=ziVqYK}jYcJe^MezV4{9gHC9QNT6i>?@G z?n2$s^pr^_W&D;2NXj&eY8&*dn1RNI|`Tpjr8= zXzOTrPpGqyd?;|74sRh(Hq@WG>3nUPgNiAmf)UAoTY@8YKg>9X9A%wr=ziMnv#ukZ^oIO@^)6w9bs&`QO( zsd{qe4QKU_PiX)Nz>Gq=W3EvnTDn5%;r-HnSB*yj!6Q(1G_9 zJj(MPCXwKoj+HV3d~VZu|a z0V|qN8#O^OZ@Vt%`$_5~o8z%CSeHXFQbDy}FT{)zs%Pl0hQH`quvuO!SFV>Q#;6CN z^`K@^Q^Gyw++P(`6S|a!OkLq?E_k;UI}(ml#xctF_~iQ7>VM|ye|nYZ_mmarMatw{ zTD%aYdhW!Wvd7ncCx21Q@f`56200r^o_N!zh)sBRW{X9{lqkLs@wG`vK2w&TqQN?Ci^c=4=H$P|{Q6NTMVvSxM@v^% zqKnt&A9tn zdH5g&e<8tMo&m5eh4iSOotdT4H`c3yY6T$RvCg7$8s6vdxW9-Vq#of1Ekcg>w`G*E zdokU^nbD@mVvQQiHZBShXB@&4;XB}M8g?*~{G%AO*#zMoTb*{JoVYB2y5LI_#>eRM{KeciTJnQOF{;r-*!T#zsjeib|FpL zMCkl^R69Hi6AgHPdVS_AeZ+?y86yHj#nz})o?2Uu zT?KU=0})H;lQBEpAe_AcpAblOUX5dTD;uNMb+h`L@M@&ZyNq~vp!f{md&ab6bVPa7 zXZS5Q?-{-eIaBQTI^PJcDer28HXUwB!sa}~58a&ngE$Bq^aON;Zct<~-Z^w+qW%t| zV>G92Sg}n(K|V@c@m<_ju_g54dOB#0*5KlnWZ~kr+`MbE-W`kUnM>%6#tf+fY%zr; z+pU0;)!LR8(U96qfAUXb@R5)eIDzv;L6w%NrQ^0B3!7~f_Vs3G1*7#{0ha{ZsQAE*OQFlr;mgwYPv3D1w}>hbwDdm32F#t>{PHh|Iqr! 
zm~trG4CFtM;K_i|gz*t*d()n3dkFI9(T3O5G%O#9!vU2X++%JN&jyqX10}&?hVrDC zYI0)o-;zKyE6d*hOtk*bnectd|4Y298U{wV>^k!Croiq1SwA2c6ParsH$#30FX;zhof^)vgn2uQ#Xz{pLrOJD0<#|UF z45r-I5GBW#ww0n|JLBN1d@Oi|ByhFDhx^PP=be9DTO%cR!hI;s2$`GaPRN-}DcWy) zxX(6}_=KAX-M5tZGkf8mVZH@kKh%(TiVQ&?&5q{WHb7p2MVgih{gZ=emOC#E8kq2w z2rEG@pPALG7ba7e`PFsj!Oi#`AQ*_^WLih<=bOfKeroDiM%gXy7^4@= zfK$t6AGQ})n!Vk2?90GBO>WyY-*apOFLbB}zm)8xzIGz;mb%$%cSGifSvP>b{_7i4 ztgn>Qh5QA)s=~_iZ~e;qXdIU}SAE@(>UGQ>USKPmb4qx}YbEAN_^kpB4>5-z*Ia4t zxs>;=)s=_L;z2!{&Nn_G%?f7zy6mLWmtW;?A86Yj1KY0OWq=VkyF9lbOktMAYC(uX zp%ONWVxhZ~ai4zVuVFHH&j^V3g*;k^h67ajAx|}P;PvmL^=UdU-D~eVA9|8|T9^Mv z3`Q@r`z@yU>oa@+B{%rBBNpW`lrGuhPN4d^n7tGJpEUn}B-{(E(7e)An%|@J=AZeQ z=IWSPd4xUrrv&!zDWHzk)1@y4JE?qo(;mWD`^UtUk=bw9K;fi`@@zkYfX<8Wp ze5nuFxb+?}b3^AYDQOq>`)tsU%@YUkUGvs|Ovhj@xVK~Z%iYdpC)TAWvg|>ayq)V5 z1&K+)#2lQI5BN;4USbCr0eMH50NEL7bO&{Pq|vph=46%&wtnz?DT6RTDKqL^fP-h0 z(u4Ue?aUML$d*7BCopeON#xfj%stw2kOHsSgY29`#`XU*W&* zg73s&f|tP!B-$#AI?iPYN`f4jMs^A%<=5er2t{vCc379DtwHK^Nuge;blM}O&Ul_@ zC(O97NYk7DW$ZTG)U#voZZgycMPlDOfhHSRdps=|!KIP4&qoR+`3m%)r+QQI5%g5E zF3fRcc*eC;8J;yJFG3vuFasPsXfF@4BK5;&Tk#n4@?jN}QWN;2=9(Pqggw@w`2nT~ zNE|h#1qC@Qt3mk{HgG(BSvE=Uh=;9WO^n%u1ag7DA~qAyo>^eP`9KpufGUfR@qWg- zyVa(U;rNb=Zf%aAdcrMEj@aN#&i68KcGWdEA4oo&j8|oHn#a#cAU8?G=i@FIL*y9#ZdZAU?4xhX-uDE-&N9t-YJ3C zLRbW-j)j~~4->`KXFbEsxg%C?zMHEhQQ5?eTDl*L#nHCqB;fl-dW6FGrCLJE$CG6WcjWe|}-??)uaH$LKk}DeIwC^nHCv+S6`cT~| z>fkotcYpz%e)k=C?qx6s>rB0K%>mv3Y~1VD9Bja@hH&mkZXH8Yxv*iXa2{#+;)~jbnUrPtUhC@Y~@!Pcw9cWtQlHv!4WB2u=Ax0-c8CP z?O7zEgHp#LK}{TY)*=zzYTiQ@iQ748UJ0HHvPeWXD9$1gU8D8o4W|VTZIRfCA%;=P z-xU^#J2m-0i$rvTs^=C&i$rwI+PK)UNHA|wgLo_w(LtI7r-k?DSwvh`WEP1Z?S@{m8>EAa^8k%xG zllq_T7?=G{5f|>%l{p9N1-#R93VHo<4>!10QO{Q1r`7sf?1aK&&BPw}c<~0U;-|RfJ-y|gH(61<^rSs3Y4_F?@65J!R^%P~f57|qY~$H~$7gD@ z&;GuHE}_-sX_Mo8Wa@E{nSEX#8El~*Mqu6%)7J9h9WiayzZK2 zN1=6ZM7iU-rscC!HO*MVpMri3-CgnvmtM?g{ltc*mc?`QcU;#T5Bhnpa2Fajb)9#| zb*)my3wYfwhK|Cf2 z-QHTI>9$rN^xaWS<%^5x?)r}F8qHf{5KM5BZT;7x9xM2N*Q 
zqJ~x~zRi;Jo^a>kxNlQJ{`w?*CCd0bMLE%M4Ra?Z@1Bjx`(nHe`CAkb0DglS3{m!TIHv{I?@j-eaOI?|>^fH#G7w6j)^m22@74(A>1C-`iaWX1$8CLd-i`O?=CseQ#DyWr(1WD@+iYP z+R0Tb>>;%8OGi5?GRra#IFE`pG)U!+c2c#XT!q?-SoGS8OH^BNZt-k1j_*0VHNK;r ztOYNkf_a~6E63Kg+SOw;U-9(RDa-?oIgO9lZ>xG?6LnCiZlYf=;m;I%MSTHi-WsYc zY$CC+*y-q4bN$`xdB-~1$!N8c*^YKnFQeh?=622^5lLAOnm5?}DIrFnokYVxM?3kz zX(zuk2_Ho2_=Jhc0W6O_R4ep@9TSL|Habd%wR={{z=DW8PD`N$FtJSQX{|+%8wbB0 zNJpsmsApQMaJgv33EVk3`gP4@OqizV6R6FVh8ncti5c^)9)O;eDB3+A;M6cbo z#dSxB#{`!hXDFW-r;fUz?B`(FafV`-QwK!7bey3aZ?BHJ!8an1w9aAWFX9DK=OC=t zARO?yQ7TC+ehM3uqHS<887Jrv*x=XL4Kd6<6nuy82Pyaq3HI`lXXxozt5ofaTu+kS zK&)2<)yko!W>FLrz0Xag$i;0g>@*FGb}y#C1$*bD6yr^hnC+2DpaoN-}((JjhCwfK0MS(IrpO-N@W*(bD3*g!dz37P#M=7;F&w`7DfgL29u4d8BjC z`rkDUM|=p|7!e>UwkAc^S}xs_{GIOMe)&8M(<`Q{824SmdV8mCcgJ zvs}$#qQ^acAC;4_^nMtntrgVuNULy!G1gAcv-!z?}|?1$s{ zkhC98z#f#uc2)J&&5~9vzO6Cas2L><<4GW+wJI((%O7wJ3(!9d7)bV!6#v*yQfqou zGgNJ(2!qK10!DGQtYP>ymd+%Pl0(ZF6UDQq^bx}WJb~HJpFBpAs>({Mus1nH#%*k1 z$fY{%rvLdd7*T60s@ZW_uU)TH*GP9#8kXd59s6XO2G)GdJL9u7 z^x`~4IyxM#LD9&(l~uL2xq%#_O6wMD`a0{v9DnYOckrb%XA_3BoUx%*4GIT)7FBI^ zRWEQa9Jp&~+6MPXyD2YKHZg3!j1-1-NBi*(?5>_w$vrEEYN2Lznc?C2N(B|>p4FM~Ws7AzNNU_du4~9YR@U?~T__8!%bV+JrE<;otiTUB93y?hO5$_+ zwHGg(dxMUA9kWx_U(+{wdAU-+okfdn)>e+5bi>la!qOTIkbUj7w4vdbEc^c_-+AK= z5e9{UwMQpRdwYOjEttI!))c>kVZBh#=0su05g{9`@l`$FlMkF7Qp1r~?M9KJ-q9RS zM;aOv4!aSRoJkxxfGE^BedoInp{|_1xN%-vUn)@C6V1`3;A~=!E7qAsw}kEHiRHYcA=V*eczh_CiqYYo)BGan}gQZs<9>e6HG=E7y>4&fIr(nsy|ZyB9w8 z+%4qcr_VL#1A@RMHo3>v=IcG;#Ak1M=K3?Iod3N*|DylQKTZEyzCJ_$&Yn6`nO`|I zzjo^77Z%=n;khfP7S^7=vbi=-%ba`RS#AEz=^5>n=Qmz2pQ=?B3TKwI1?8ogWo?CnMQJw=4IP8$BL(^OTvPT}wR<}_+mM6PXf9hD=LbbRt1o^;d~un~`5OWTOV zVF^Pzo-n+@a1ePeGYt{rU??@gLD6FPnLD1|prx&$)>*jMD4&UGhwLg0W-NqOHIVr&qEe{9V;r>JU#TSN2}R&lmgvvv>S|f zL*OEGy}dE!?d}{bU5};BVyWO^*&6_U1$mxI^)k(=ML{m&jlI*}whHaMXsD}IZR4tH zNiqm?OZG4Y9Av-@lwa_5S~)gzEK56I(w310qYQSsE`-t#rjX8|Ly$1U6|)Zmv0PL~ zRPPM$YOcH^YtS1x97#9~|J%qSJxKhoVD_Rzp(ATVAlEfA+X6<8J4W$P@saBm>Wr@E zEVQW_5;_+vEW&B~+)Q_5je?NxW?mZX`#`+$*paZbBtib0yx|m 
z`-6o~BO-hv96pUkj!y!vSYUs+xA7_59PclD`p4a9Zy)owx9{?^x9@h_+s8SKL;Lzn z6ZZAlx_v$U-MAm^+lqT>-$Gq@q9>nC7LZbGD^rEVQ%>X(ExWDCv|L!i5j30T2`E~a zp*ysQ8=K4Ybmw%2@4&x_;bAtxp+fqPVtxzy+YM}++oWT!r>*IB#Jo1>4hOTy*3-B^j<$HuS0~xJOdrk7MXO_z+2~>*%nwY`{ozKN(;$lWM=(n>;|q@sdH)RnFr$xIg(Y z*=v>;yBtC?O?p`YBMQ&N@a}|7c_|)3kH)muqY>E&k2hm_61{*|^|FqzQa8~zRIMz% zg>Hn_f*e6v^(4v`Zl~AuE*Nl}9MKYvmuOR$M(^@b_fpRSUxcyb92r)ZF%L6Jzvy7d zaxlYCGEe%@tZ83#0=tqQB_rroiF%AmJy-K$*Q*22!<_0#_i%P12M=vC0T0gk4C!OS zA${CCq$jo=k}@&f?2vY78+@H^gM<7wc%*R~JQ}nO4sEv$KEiE-wgra%B**nPB`+L8 z*`j?+-wu(s`|d%som|z!{Hi|IxT=|;RegNBRXxJ3YTHasf0E-=n&NQ|AwSgqsy4kU z0s_%=8-lfe1$xlGUdrN6j7T~Bi4(kNt_MBirD^=>A<_)~^b%!g{YYL{Z{k{ah zFa2)(8xr)LMyI+{)m7?FMRL%Mu5AW+?EUy5=zXi(&`T?je9H==1b8dS#CUYCreCS^ zH5q^e>c?3vyy$`S(b)Xz+I4zW0~v>;6@p`szW3uvh;4KM{V)QscAehv!0N@{k9Xm< zf`jHOr$+okd`+NFE0-&%t3jtFgV1-wQU=?A_u|Tz;*blE04rs68WV&!U1cuYY?Yau z(sUK&`2DabKarIGEx~?%Q7rD#j`*6+%}$elDZv*I{XXYGzdj577($HH5espRA&LvR z207&N!6}Td^xzaNC_t&q(_8UV@Z|gPkHO@v*d7-6)ff>rk+8pNMELdJO7s6(f(d^1 zG?GoFm^X{NZgBakzPz?kDVA#(Zgg`(N`EN9R}c_?(*uZo79e)LA15&UmZx6eo{JPL z{af)AByPpJNf%E3)3N!VZ42FE!%Q*92^32?Y0{iMHK;9a*BAuLg zXg95ibl!Wm;=3VrE7nUeq!4HB{F#ElHJ(0kzppXg#c}_y@qd?;|0YF~f_G+mnrwyG zFjYrqtIg*%<*Ljei4tqzaDh_d06Wao3%!UeC;}8h$21oC{z9SB<1T8 z?BOS3>H#ZN-15hdlP^i|IUM(Ani%)aac%Lfco%$AQhrC;<~HWLaM-^qDZlMurfgNJ zd4=@d20kA?LI~QQlSuc_ zljJG+!dCnoL?h+0MVk6Begra3x|0L9V*8v2foz;2|1?5|`ZJRJ>Q?+^xU?029?rcV ze-_fWV*Bls$x~U8be?jT$%)IbxjniSPs83@u?dg_O~ocsz!Vf;YnENUwxgq-xE0&g zI4K-0`3)%o=Jj=XdMloV$F@*2$aEx@QjiRgzNCGN=g%Yr4&@_a8xn+oB`wZi+5Qk` zbEMfJb~QD`8)5m`i>*OQynYBF60YatInr#;765{J4op(_2_g;2=eBVFo!W{&0sC&n zWKr5WONrw1ImbRD9&ofjj&t_!8AU1Vne0%{zZ475;QdcD5Dm$Pw&Dk2ax1>(=F^=vFeI~6)}yJP67-M@?jg5b4>`Vv6!C`&>bxja$b$$KRAPn^ z`o3tN!Y^*c7fnv*+=ukkVoAv?jkAwVnU*P$K&^DT>BIO>rHJiGKX1?QrLFj7IKLG? 
z+rBoPo+0!sY33Y>c4Ep2vNIE7R5S_L?cDukhlpzM3Ktd7{ezDDe3sNE>Bo@T1A~b5y$c%7IG~Iy6A#xK| zIYVv$2Z)q&AZk918W#>1naB=?oFI+O{;S6&vaq9ZBH6kfZ>M+F9I1v8hoOo*AIdQ$_zLKFVtz~m0%OuL` zlr5uW^F(Y-7KmguStOFvWQhpZWEqBe9tVB5dDIfV_e&`6u{PduY7Z9Y^YzuYdliH%E1|{ByMX7``69F$`)=@!L zO8Uag5=8VTMW7K;G9$LAmVwT-40fvJ<;_~kX3iqn8oGTjY#s!+jY(&^S$H*bAy z<_tP!!CA04qnP{QD&O0a(9g%P5bcUSs_-{hZqvk)rZmRiVkOci&lmVOWL7q;bA+EvtU=iKy z)XHL@5iCMX9IP6GroSfUPSRr&EKAHOO)ci&gorB$YG|1Hd+v-9mKQW6&L{;-oQ8Qp z4PyWTMi8;CZc+=JbKuj$4GfITxePDCa$v%<$)k5 z6hR-TTf5)?&tw0Y{bxA4JJfT1hL$t)KfnL`zyEvxNs0X6(>?I+2>Fl0kcarGLV@&x zbR#WYp}#lrSH|DB@i)B}@MU-H=7ydhhC!u&a8!zQ_x8k7Ns@r8kX)#gm$lkU3K@jn zYg%bjH?USlxZ*wD&_ zr;F+br|7#J6$XnM13I)%)a7Rie<731`7I$wppsFtRn;sE}C3lS@xVpxY5!KDR3>> z%X-I4u4$fK_nW4>>*o5K8f?m>Vx>_j})IzapknW{Qr6k7=OB9B`DPcdGY;~U1jH_(j zUQ*1G18UZ1AWF-t*RH)vbL_fK0OS(HVs`ZR|@zK zepG2etmh1&D?I`!!~m7yxdrY~=wnO6zobLOeR834aRZ)*?n~P8ns+z4k&@qk62Je3 z%stEYrW~8-f^SGLbV;isjw9go!h%*R(RI_RtA;|6fZ~R^uN4zFMorv+Qqia#ru*}; z!%}+8!V&jd++-`Hdu6k{j0JiTek;XQ7ldT!H{arlRcaO2x4kf+-_VygYn7@>yWu~= z0`y%fY0LUmTvol>NHW|M0(!kKCP)O-S&I5`wZ?De!xn?12$n%*|znN4Teo z%Si786mAz?Iei0h3TcjwQ(fr;4ynuUf>C8eip9HodXuTX{(->}GP-*_fC5$^wTMVu zsUp_nn0ih8M(n2O$S5B}BZ}{1Fw7>TB1KV+5*_430_jM?EN3Xa45bY0&H9VDmL|)v zzXZLPwQDW*Mw;#IN$d^x;&P1)u+3CpC^rH0K%$2R{nhc;jmdRF4bQ^RW!ympy>Pl% z(%E+FhvbrGu)h&1DYnej!bBW*2b4G0*^Vn16LA>}u2_{O67W3<#xK{jn!d7Gy62Fe z5`(8;jE=v!yr{40H%@PsO3!Px<*O6|?C?WmKm;=P?F1m1No;7fnqDp2y9h1-TxL}4 zIV7#PBE760ViBI~T-}80D3>T5H1rLvsv(-%Z8197W9ZdudZtG?HgjxR8NWtC z<>=}OWoc<2#9iKHw>!gjOem% z=(z15mhZ|8@ia7+O=R}ENpe2{^Y~TfPR?dV+~hTuO_`%9-Tu+c&WT2I#^QsB$qxBW zyU%8Va0(e>vq-7=L^4IX77<6zLV__=YOd|3{W^qq13D))7yTyEAJZ z^_(C3$X-Z5eBsiI!xH`D4F1uj#GdOqDScY%%B1Oev~O`Q>`}3m6}_sLvAvI#wDqNe z_Jle;JDY|7BEg7YL0l-R%Q))|l>uDd<&~OPXPjN8uoOxKvrrKNU3G%wz*u z0Z3B*fbj@ywgI#uqS*%Y(EMzJ8VjE7=lg7$p;$B{Jup?!uOWYt$4!;b`tDVZ8YMD+ zA7wC8V*7oV|0X08q-*}`06&$WFDa2kNxP|6k7n^9WhhF<(79SG!!*P5-8P^=n3D|>#b5~Vg-7F!O z{h^0Rl#G6{V6* zBwx^l|7Bq`IY@>K%6;Irup5Hp5a~4!KWucEZ-v$v5!E1EjZwbFBF<&Ze=fWH=Ca#8 
zmvO$$Asp6(_r)Im1?^S#?{%+VuaoXx_wj8~h;_l&DdIdt66-p0EXxLrB=?g66Ibk4 zdy@yqco{`M(JJjvK1}wS<;Cj@$u#L@1&k=rfR6bjISxIH2FgdxuH9%oWiIyAu~S#B zES{Zz@yauor&w_|<85l^f@?c4JIKTAkw%z38U(XL+rjK39L(BAH*=+(ksszq{#fJ4 zGeIMNe7lh!;YQvz%Gx7G1r|~_DT_ZbBIWQWP9z0?x`{N6KRra6!Jl3t&EijjNQfXv zf20#^#qCz?-%TKtqegXsYHMr^jqp-$<|7R!7mA$u5z?3Znf*VB#F7#{K}-z;JufG1 zSy)sQMBSo7fD^WNq}wmX{RJ{VNlg|2Nbs%{%uL9Z)^XAZI5uUQlSYP}X~vL6PRhw!hM0UzItJAr;EY)OMx}_9V!xb%y z%0xHEO(3YO}BBV{#CX{C6CwN(J|yr9TUnr z398GbGG$nrt@QJq$*Y|C$T1|zqxJ6E=?YIE+cuTnSBEMYvESAkPs4zzG_mz&BQTU7 zaIp#bbIC4p1`alG6pcmU0F0?<*Qlz-hQ3@}DcYuwB(We0--!v>ZlL2R3FvN$yMpRP>kVVG>GL(kWWpW8961MPT5rqd z)0pTxkzxb$QjK;7)^b~(fw^}fs*|^*Iq$5eE(jsWYj92@1Yc5UUC+Qm*Ok8hxC!Dw<=3{)JJKhs}-X*Rn<4pZlmS1DE1j^_~X*E*veF7*ugrU zo1#P?Z3D3KdvsY^T~zV1cv?fFXbT(EA=ziVqYK}jYcJe^MezV4{9gHC9QNT6i>?@G z?n2$s^pr^_W&D;2NXj&eY8&*dn1RNI|`Tpjr8= zXzOTrPpGqyd?;|74sRh(Hq@WG>3nUPgNiAmf)UAoTY@8YKg>9X9A%wr=ziMnv#ukZ^oIO@^)6w9bs&`QO( zsd{qe4QKU_PiX)Nz>Gq=W3EvnTDn5%;r-HnSB*yj!6Q(1G_9 zJj(MPCXwKoj+HV3d~VZu|a z0V|qN8#O^OZ@Vt%`$_5~o8z%CSeHXFQbDy}FT{)zs%Pl0hQH`quvuO!SFV>Q#;6CN z^`K@^Q^Gyw++P(`6S|a!OkLq?E_k;UI}(ml#xctF_~iQ7>VM|ye|nYZ_mmarMatw{ zTD%aYdhW!Wvd7ncCx21Q@f`56200r^o_N!zh)sBRW{X9{lqkLs@wG`vK2w&TqQN?Ci^c=4=H$P|{Q6NTMVvSxM@v^% zqKnt&A9tn zdH5g&e<8tMo&m5eh4iSOotdT4H`c3yY6T$RvCg7$8s6vdxW9-Vq#of1Ekcg>w`G*E zdokU^nbD@mVvQQiHZBShXB@&4;XB}M8g?*~{G%AO*#zMoTb*{JoVYB2y5LI_#>eRM{KeciTJnQOF{;r-*!T#zsjeib|FpL zMCkl^R69Hi6AgHPdVS_AeZ+?y86yHj#nz})o?2Uu zT?KU=0})H;lQBEpAe_AcpAblOUX5dTD;uNMb+h`L@M@&ZyNq~vp!f{md&ab6bVPa7 zXZS5Q?-{-eIaBQTI^PJcDer28HXUwB!sa}~58a&ngE$Bq^aON;Zct<~-Z^w+qW%t| zV>G92Sg}n(K|V@c@m<_ju_g54dOB#0*5KlnWZ~kr+`MbE-W`kUnM>%6#tf+fY%zr; z+pU0;)!LR8(U96qfAUXb@R5)eIDzv;L6w%NrQ^0B3!7~f_Vs3G1*7#{0ha{ZsQAE*OQFlr;mgwYPv3D1w}>hbwDdm32F#t>{PHh|Iqr! 
zm~trG4CFtM;K_i|gz*t*d()n3dkFI9(T3O5G%O#9!vU2X++%JN&jyqX10}&?hVrDC zYI0)o-;zKyE6d*hOtk*bnectd|4Y298U{wV>^k!Croiq1SwA2c6ParsH$#30FX;zhof^)vgn2uQ#Xz{pLrOJD0<#|UF z45r-I5GBW#ww0n|JLBN1d@Oi|ByhFDhx^PP=be9DTO%cR!hI;s2$`GaPRN-}DcWy) zxX(6}_=KAX-M5tZGkf8mVZH@kKh%(TiVQ&?&5q{WHb7p2MVgih{gZ=emOC#E8kq2w z2rEG@pPALG7ba7e`PFsj!Oi#`AQ*_^WLih<=bOfKeroDiM%gXy7^4@= zfK$t6AGQ})n!Vk2?90GBO>WyY-*apOFLbB}zm)8xzIGz;mb%$%cSGifSvP>b{_7i4 ztgn>Qh5QA)s=~_iZ~e;qXdIU}SAE@(>UGQ>USKPmb4qx}YbEAN_^kpB4>5-z*Ia4t zxs>;=)s=_L;z2!{&Nn_G%?f7zy6mLWmtW;?A86Yj1KY0OWq=VkyF9lbOktMAYC(uX zp%ONWVxhZ~ai4zVuVFHH&j^V3g*;k^h67ajAx|}P;PvmL^=UdU-D~eVA9|8|T9^Mv z3`Q@r`z@yU>oa@+B{%rBBNpW`lrGuhPN4d^n7tGJpEUn}B-{(E(7e)An%|@J=AZeQ z=IWSPd4xUrrv&!zDWHzk)1@y4JE?qo(;mWD`^UtUk=bw9K;fi`@@zkYfX<8Wp ze5nuFxb+?}b3^AYDQOq>`)tsU%@YUkUGvs|Ovhj@xVK~Z%iYdpC)TAWvg|>ayq)V5 z1&K+)#2lQI5BN;4USbCr0eMH50NEL7bO&{Pq|vph=46%&wtnz?DT6RTDKqL^fP-h0 z(u4Ue?aUML$d*7BCopeON#xfj%stw2kOHsSgY29`#`XU*W&* zg73s&f|tP!B-$#AI?iPYN`f4jMs^A%<=5er2t{vCc379DtwHK^Nuge;blM}O&Ul_@ zC(O97NYk7DW$ZTG)U#voZZgycMPlDOfhHSRdps=|!KIP4&qoR+`3m%)r+QQI5%g5E zF3fRcc*eC;8J;yJFG3vuFasPsXfF@4BK5;&Tk#n4@?jN}QWN;2=9(Pqggw@w`2nT~ zNE|h#1qC@Qt3mk{HgG(BSvE=Uh=;9WO^n%u1ag7DA~qAyo>^eP`9KpufGUfR@qWg- zyVa(U;rNb=Zf%aAdcrMEj@aN#&i68KcGWdEA4oo&j8|oHn#a#cAU8?G=i@FIL*y9#ZdZAU?4xhX-uDE-&N9t-YJ3C zLRbW-j)j~~4->`KXFbEsxg%C?zMHEhQQ5?eTDl*L#nHCqB;fl-dW6FGrCLJE$CG6WcjWe|}-??)uaH$LKk}DeIwC^nHCv+S6`cT~| z>fkotcYpz%e)k=C?qx6s>rB0K%>mv3Y~1VD9Bja@hH&mkZXH8Yxv*iXa2{#+;)~jbnUrPtUhC@Y~@!Pcw9cWtQlHv!4WB2u=Ax0-c8CP z?O7zEgHp#LK}{TY)*=zzYTiQ@iQ748UJ0HHvPeWXD9$1gU8D8o4W|VTZIRfCA%;=P z-xU^#J2m-0i$rvTs^=C&i$rwI+PK)UNHA|wgLo_w(LtI7r-k?DSwvh`WEP1Z?S@{m8>EAa^8k%xG zllq_T7?=G{5f|>%l{p9N1-#R93VHo<4>!10QO{Q1r`7sf?1aK&&BPw}c<~0U;-|RfJ-y|gH(61<^rSs3Y4_F?@65J!R^%P~f57|qY~$H~$7gD@ z&;GuHE}_-sX_Mo8Wa@E{nSEX#8El~*Mqu6%)7J9h9WiayzZK2 zN1=6ZM7iU-rscC!HO*MVpMri3-CgnvmtM?g{ltc*mc?`QcU;#T5Bhnpa2Fajb)9#| zb*)my3wYfwhK|Cf2 z-QHTI>9$rN^xaWS<%^5x?)r}F8qHf{5KM5BZT;7x9xM2N*Q 
zqJ~x~zRi;Jo^a>kxNlQJ{`w?*CCd0bMLE%M4Ra?Z@1Bjx`(nHe`CAkb0DglS3{m!TIHv{I?@j-eaOI?|>^fH#G7w6j)^m22@74(A>1C-`iaWX1$8CLd-i`O?=CseQ#DyWr(1WD@+iYP z+R0Tb>>;%8OGi5?GRra#IFE`pG)U!+c2c#XT!q?-SoGS8OH^BNZt-k1j_*0VHNK;r ztOYNkf_a~6E63Kg+SOw;U-9(RDa-?oIgO9lZ>xG?6LnCiZlYf=;m;I%MSTHi-WsYc zY$CC+*y-q4bN$`xdB-~1$!N8c*^YKnFQeh?=622^5lLAOnm5?}DIrFnokYVxM?3kz zX(zuk2_Ho2_=Jhc0W6O_R4ep@9TSL|Habd%wR={{z=DW8PD`N$FtJSQX{|+%8wbB0 zNJpsmsApQMaJgv33EVk3`gP4@OqizV6R6FVh8ncti5c^)9)O;eDB3+A;M6cbo z#dSxB#{`!hXDFW-r;fUz?B`(FafV`-QwK!7bey3aZ?BHJ!8an1w9aAWFX9DK=OC=t zARO?yQ7TC+ehM3uqHS<887Jrv*x=XL4Kd6<6nuy82Pyaq3HI`lXXxozt5ofaTu+kS zK&)2<)yko!W>FLrz0Xag$i;0g>@*FGb}y#C1$*bD6yr^hnC+2DpaoN-}((JjhCwfK0MS(IrpO-N@W*(bD3*g!dz37P#M=7;F&w`7DfgL29u4d8BjC z`rkDUM|=p|7!e>UwkAc^S}xs_{GIOMe)&8M(<`Q{824SmdV8mCcgJ zvs}$#qQ^acAC;4_^nMtntrgVuNULy!G1gAcv-!z?}|?1$s{ zkhC98z#f#uc2)J&&5~9vzO6Cas2L><<4GW+wJI((%O7wJ3(!9d7)bV!6#v*yQfqou zGgNJ(2!qK10!DGQtYP>ymd+%Pl0(ZF6UDQq^bx}WJb~HJpFBpAs>({Mus1nH#%*k1 z$fY{%rvLdd7*T60s@ZW_uU)TH*GP9#8kXd59s6XO2G)GdJL9u7 z^x`~4IyxM#LD9&(l~uL2xq%#_O6wMD`a0{v9DnYOckrb%XA_3BoUx%*4GIT)7FBI^ zRWEQa9Jp&~+6MPXyD2YKHZg3!j1-1-NBi*(?5>_w$vrEEYN2Lznc?C2N(B|>p4FM~Ws7AzNNU_du4~9YR@U?~T__8!%bV+JrE<;otiTUB93y?hO5$_+ zwHGg(dxMUA9kWx_U(+{wdAU-+okfdn)>e+5bi>la!qOTIkbUj7w4vdbEc^c_-+AK= z5e9{UwMQpRdwYOjEttI!))c>kVZBh#=0su05g{9`@l`$FlMkF7Qp1r~?M9KJ-q9RS zM;aOv4!aSRoJkxxfGE^BedoInp{|_1xN%-vUn)@C6V1`3;A~=!E7qAsw}kEHiRHYcA=V*eczh_CiqYYo)BGan}gQZs<9>e6HG=E7y>4&fIr(nsy|ZyB9w8 z+%4qcr_VL#1A@RMHo3>v=IcG;#Ak1M=K3?Iod3N*|DylQKTZEyzCJ_$&Yn6`nO`|I zzjo^77Z%=n;khfP7S^7=vbi=-%ba`RS#AEz=^5>n=Qmz2pQ=?B3TKwI1?8ogWo?CnMQJw=4IP8$BL(^OTvPT}wR<}_+mM6PXf9hD=LbbRt1o^;d~un~`5OWTOV zVF^Pzo-n+@a1ePeGYt{rU??@gLD6FPnLD1|prx&$)>*jMD4&UGhwLg0W-NqOHIVr&qEe{9V;r>JU#TSN2}R&lmgvvv>S|f zL*OEGy}dE!?d}{bU5};BVyWO^*&6_U1$mxI^)k(=ML{m&jlI*}whHaMXsD}IZR4tH zNiqm?OZG4Y9Av-@lwa_5S~)gzEK56I(w310qYQSsE`-t#rjX8|Ly$1U6|)Zmv0PL~ zRPPM$YOcH^YtS1x97#9~|J%qSJxKhoVD_Rzp(ATVAlEfA+X6<8J4W$P@saBm>Wr@E zEVQW_5;_+vEW&B~+)Q_5je?NxW?mZX`#`+$*paZbBtib0yx|m 
z`-6o~BO-hv96pUkj!y!vSYUs+xA7_59PclD`p4a9Zy)owx9{?^x9@h_+s8SKL;Lzn z6ZZAlx_v$U-MAm^+lqT>-$Gq@q9>nC7LZbGD^rEVQ%>X(ExWDCv|L!i5j30T2`E~a zp*ysQ8=K4Ybmw%2@4&x_;bAtxp+fqPVtxzy+YM}++oWT!r>*IB#Jo1>4hOTy*3-B^j<$HuS0~xJOdrk7MXO_z+2~>*%nwY`{ozKN(;$lWM=(n>;|q@sdH)RnFr$xIg(Y z*=v>;yBtC?O?p`YBMQ&N@a}|7c_|)3kH)muqY>E&k2hm_61{*|^|FqzQa8~zRIMz% zg>Hn_f*e6v^(4v`Zl~AuE*Nl}9MKYvmuOR$M(^@b_fpRSUxcyb92r)ZF%L6Jzvy7d zaxlYCGEe%@tZ83#0=tqQB_rroiF%AmJy-K$*Q*22!<_0#_i%P12M=vC0T0gk4C!OS zA${CCq$jo=k}@&f?2vY78+@H^gM<7wc%*R~JQ}nO4sEv$KEiE-wgra%B**nPB`+L8 z*`j?+-wu(s`|d%som|z!{Hi|IxT=|;RegNBRXxJ3YTHasf0E-=n&NQ|AwSgqsy4kU z0s_%=8-lfe1$xlGUdrN6j7T~Bi4(kNt_MBirD^=>A<_)~^b%!g{YYL{Z{k{ah zFa2)(8xr)LMyI+{)m7?FMRL%Mu5AW+?EUy5=zXi(&`T?je9H==1b8dS#CUYCreCS^ zH5q^e>c?3vyy$`S(b)Xz+I4zW0~v>;6@p`szW3uvh;4KM{V)QscAehv!0N@{k9Xm< zf`jHOr$+okd`+NFE0-&%t3jtFgV1-wQU=?A_u|Tz;*blE04rs68WV&!U1cuYY?Yau z(sUK&`2DabKarIGEx~?%Q7rD#j`*6+%}$elDZv*I{XXYGzdj577($HH5espRA&LvR z207&N!6}Td^xzaNC_t&q(_8UV@Z|gPkHO@v*d7-6)ff>rk+8pNMELdJO7s6(f(d^1 zG?GoFm^X{NZgBakzPz?kDVA#(Zgg`(N`EN9R}c_?(*uZo79e)LA15&UmZx6eo{JPL z{af)AByPpJNf%E3)3N!VZ42FE!%Q*92^32?Y0{iMHK;9a*BAuLg zXg95ibl!Wm;=3VrE7nUeq!4HB{F#ElHJ(0kzppXg#c}_y@qd?;|0YF~f_G+mnrwyG zFjYrqtIg*%<*Ljei4tqzaDh_d06Wao3%!UeC;}8h$21oC{z9SB<1T8 z?BOS3>H#ZN-15hdlP^i|IUM(Ani%)aac%Lfco%$AQhrC;<~HWLaM-^qDZlMurfgNJ zd4=@d20kA?LI~QQlSuc_ zljJG+!dCnoL?h+0MVk6Begra3x|0L9V*8v2foz;2|1?5|`ZJRJ>Q?+^xU?029?rcV ze-_fWV*Bls$x~U8be?jT$%)IbxjniSPs83@u?dg_O~ocsz!Vf;YnENUwxgq-xE0&g zI4K-0`3)%o=Jj=XdMloV$F@*2$aEx@QjiRgzNCGN=g%Yr4&@_a8xn+oB`wZi+5Qk` zbEMfJb~QD`8)5m`i>*OQynYBF60YatInr#;765{J4op(_2_g;2=eBVFo!W{&0sC&n zWKr5WONrw1ImbRD9&ofjj&t_!8AU1Vne0%{zZ475;QdcD5Dm$Pw&Dk2ax1>(=F^=vFeI~6)}yJP67-M@?jg5b4>`Vv6!C`&>bxja$b$$KRAPn^ z`o3tN!Y^*c7fnv*+=ukkVoAv?jkAwVnU*P$K&^DT>BIO>rHJiGKX1?QrLFj7IKLG? 
z+rBoPo+0!sY33Y>c4Ep2vNIE7R5S_L?cDukhlpzM3Ktd7{ezDDe3sNE>Bo@T1A~b5y$c%7IG~Iy6A#xK| zIYVv$2Z)q&AZk918W#>1naB=?oFI+O{;S6&vaq9ZBH6kfZ>M+F9I1v8hoOo*AIdQ$_zLKFVtz~m0%OuL` zlr5uW^F(Y-7KmguStOFvWQhpZWEqBe9tVB5dDIfV_e&`6u{PduY7Z9Y^YzuYdliH%E1|{ByMX7``69F$`)=@!L zO8Uag5=8VTMW7K;G9$LAmVwT-40fvJ<;_~kX3iqn8oGTjY#s!+jY(&^S$H*bAy z<_tP!!CA04qnP{QD&O0a(9g%P5bcUSs_-{hZqvk)rZmRiVkOci&lmVOWL7q;bA+EvtU=iKy z)XHL@5iCMX9IP6GroSfUPSRr&EKAHOO)ci&gorB$YG|1Hd+v-9mKQW6&L{;-oQ8Qp z4PyWTMi8;CZc+=JbKuj$4GfITxePDCa$v%< Date: Tue, 25 Nov 2025 16:58:35 -0800 Subject: [PATCH 02/50] Update comments while considering Field and IO design. --- service/domain/input.go | 1 + service/domain/output.go | 11 ++++++++--- service/domain/signature.go | 1 + service/platform/evaluator.go | 4 ++-- shared/config/datastores.go | 11 ++++++++--- shared/field.go | 1 - 6 files changed, 20 insertions(+), 9 deletions(-) diff --git a/service/domain/input.go b/service/domain/input.go index 1dfe28e..f1101b9 100644 --- a/service/domain/input.go +++ b/service/domain/input.go @@ -21,6 +21,7 @@ type Input struct { Vocab bool // Auxiliary is true if this input isn't part of the model + // TODO redesign model IO vs server IO concerns Auxiliary bool Type reflect.Type diff --git a/service/domain/output.go b/service/domain/output.go index 7156336..2b7b38c 100644 --- a/service/domain/output.go +++ b/service/domain/output.go @@ -8,14 +8,19 @@ import ( // Output represents model output type Output struct { + // Used in Stream Name string - // Primarily shown in config + // Shown in config DataType string - // DataTypeKind is used only for GBQ tool + // Only for GBQ tool DataTypeKind reflect.Kind - Index int + + // Used to extract output from *tf.Operation. + // Eventually becomes part of tf.Session.Run() parameter fetches ([]tf.Output). 
+ Index int + *tf.Operation goType reflect.Type diff --git a/service/domain/signature.go b/service/domain/signature.go index 12745b7..22a9f21 100644 --- a/service/domain/signature.go +++ b/service/domain/signature.go @@ -4,6 +4,7 @@ package domain // Contains information required to extract vocabularies, unmarshal requests, and validate request inputs. // TODO document and address issues if reloaded model IO changes. type Signature struct { + // Method is unused. Method string Inputs []Input diff --git a/service/platform/evaluator.go b/service/platform/evaluator.go index 0ebf621..470e7ba 100644 --- a/service/platform/evaluator.go +++ b/service/platform/evaluator.go @@ -34,9 +34,9 @@ type PlatformEvaluator interface { // Close releases resources Close() error - // ReloadIfNeeded will update models as needed, and check their health. + // ReloadIfNeeded will update models as needed, check their health, and consolidate signatures, if implemented. // For in-process models (TensorFlow), this will check if the underlying models need to be updated. - // For external models (Triton), this will check Triton models' health. + // For external models (Triton), this will use the Model Control API to load, unload, and check the health of Triton models. 
ReloadIfNeeded(ctx context.Context) error } diff --git a/shared/config/datastores.go b/shared/config/datastores.go index c25d5d3..65be57e 100644 --- a/shared/config/datastores.go +++ b/shared/config/datastores.go @@ -2,16 +2,17 @@ package config import ( "fmt" + "github.com/viant/mly/shared/config/datastore" ) -//DatastoreList represents datastore list +// DatastoreList represents datastore list type DatastoreList struct { Connections []*datastore.Connection Datastores []*Datastore } -//Init initialises list +// Init initialises list func (d *DatastoreList) Init() { if len(d.Connections) > 0 { for i := range d.Connections { @@ -25,14 +26,16 @@ func (d *DatastoreList) Init() { } } -//Validate checks if datastore list is valid +// Validate checks if datastore list is valid func (d *DatastoreList) Validate() error { if len(d.Connections) == 0 && len(d.Datastores) == 0 { return nil } + if len(d.Connections) > 0 && len(d.Datastores) == 0 { return fmt.Errorf("item were empty, but item defined") } + if len(d.Connections) > 0 { for _, item := range d.Connections { if err := item.Validate(); err != nil { @@ -40,10 +43,12 @@ func (d *DatastoreList) Validate() error { } } } + for _, item := range d.Datastores { if err := item.Validate(); err != nil { return err } } + return nil } diff --git a/shared/field.go b/shared/field.go index f498040..da935f2 100644 --- a/shared/field.go +++ b/shared/field.go @@ -135,7 +135,6 @@ func (m *MetaInput) FieldByName() map[string]*Field { // On the server, it is called after reading the configuration file. // On the client, it is called after fetching the configuration from the server, which will have already processed it via reconcileIOFromSignature(). 
func (m *MetaInput) Init() { - // TODO assess why this approach was taken - this condition could be improved by having a map to see if the field by name already exists if len(m.Inputs) == 0 { // Add KeyFields to Inputs if len(m.KeyFields) > 0 { From de616a623a3de9b1b51028ab028af8c246a6f998 Mon Sep 17 00:00:00 2001 From: David Choi Date: Tue, 25 Nov 2025 17:11:19 -0800 Subject: [PATCH 03/50] Refactored out intermediate step in gRPC type handling. --- service/triton/client.go | 1 + service/triton/grpc.go | 332 +++++++++++++-------------------------- 2 files changed, 107 insertions(+), 226 deletions(-) diff --git a/service/triton/client.go b/service/triton/client.go index 3d01add..133601b 100644 --- a/service/triton/client.go +++ b/service/triton/client.go @@ -16,6 +16,7 @@ type TritonClient interface { ServerReady(ctx context.Context) error // inputs is expected to be [numInputs]([batchSize][1]T) (see service/request.Request.Feeds) + // inputs will never be empty ModelInfer(ctx context.Context, modelName string, inputs []interface{}, indexToName map[int]string) ([]interface{}, error) ModelReady(ctx context.Context, modelName string) (bool, error) diff --git a/service/triton/grpc.go b/service/triton/grpc.go index 820238b..e601274 100644 --- a/service/triton/grpc.go +++ b/service/triton/grpc.go @@ -24,26 +24,13 @@ func NewGRPCClient(grpcConn *grpc.ClientConn) *GRPCClient { } } -// preparedInput represents processed input data ready for gRPC transport -type preparedInput struct { - name string - datatype string // Triton datatype: "BYTES", "INT32", "INT64", "FP32", "FP64" - shape []int64 // Shape in int64 for gRPC compatibility - data interface{} // Flattened data: []string, []int32, []int64, []float32, []float64 -} - func (c *GRPCClient) ServerReady(ctx context.Context) error { _, err := c.grpcClient.ServerReady(ctx, &triton.ServerReadyRequest{}) return err } func (c *GRPCClient) ModelInfer(ctx context.Context, modelName string, inputs []interface{}, indexToName 
map[int]string) ([]interface{}, error) { - preparedInputs, err := prepareInputs(indexToName, inputs) - if err != nil { - return nil, err - } - - grpcRequest, err := buildGRPCRequest(modelName, preparedInputs) + grpcRequest, err := toGRPCRequest(modelName, inputs, indexToName) if err != nil { return nil, err } @@ -98,44 +85,105 @@ func (c *GRPCClient) Close() error { return c.grpcConn.Close() } -func buildGRPCRequest(modelName string, preparedInputs []preparedInput) (*triton.ModelInferRequest, error) { +func toGRPCRequest(modelName string, params []interface{}, indexToName map[int]string) (*triton.ModelInferRequest, error) { req := &triton.ModelInferRequest{ ModelName: modelName, - Inputs: make([]*triton.ModelInferRequest_InferInputTensor, len(preparedInputs)), + Inputs: make([]*triton.ModelInferRequest_InferInputTensor, len(params)), } - for i, input := range preparedInputs { - tensor := &triton.ModelInferRequest_InferInputTensor{ - Name: input.name, - Datatype: input.datatype, - Shape: input.shape, - Contents: &triton.InferTensorContents{}, + for i, param := range params { + inputName, exists := indexToName[i] + if !exists { + return nil, fmt.Errorf("no input name found for index %d", i) } - switch data := input.data.(type) { - case []string: - tensor.Contents.BytesContents = make([][]byte, len(data)) - for j, s := range data { - tensor.Contents.BytesContents[j] = []byte(s) + inputContents := &triton.InferTensorContents{} + + inputTensor := &triton.ModelInferRequest_InferInputTensor{ + Name: inputName, + Contents: inputContents, + } + + var batchSize int + var datatype string + + switch v := param.(type) { + case [][]string: + if len(v) > 0 { + batchSize = len(v) + datatype = "BYTES" + + inputContents.BytesContents = make([][]byte, batchSize) + for j := range batchSize { + inputContents.BytesContents[j] = []byte(v[j][0]) + } + } + case [][]int: + if len(v) > 0 { + batchSize := len(v) + datatype = "INT32" + + inputContents.IntContents = make([]int32, batchSize) + 
for j := range batchSize { + inputContents.IntContents[j] = int32(v[j][0]) + } + } + case [][]int32: + if len(v) > 0 { + batchSize := len(v) + datatype = "INT32" + + inputContents.IntContents = make([]int32, batchSize) + for j := range batchSize { + inputContents.IntContents[j] = v[j][0] + } + + } + case [][]int64: + if len(v) > 0 { + batchSize := len(v) + datatype = "INT64" + + inputContents.Int64Contents = make([]int64, batchSize) + for j := range batchSize { + inputContents.Int64Contents[j] = v[j][0] + } + + } + case [][]float32: + if len(v) > 0 { + batchSize := len(v) + datatype = "FP32" + + inputContents.Fp32Contents = make([]float32, batchSize) + for j := range batchSize { + inputContents.Fp32Contents[j] = v[j][0] + } + } + case [][]float64: + if len(v) > 0 { + batchSize := len(v) + datatype = "FP64" + + inputContents.Fp64Contents = make([]float64, batchSize) + for j := range batchSize { + inputContents.Fp64Contents[j] = v[j][0] + } } - case []int32: - tensor.Contents.IntContents = data - case []int64: - tensor.Contents.Int64Contents = data - case []float32: - tensor.Contents.Fp32Contents = data - case []float64: - tensor.Contents.Fp64Contents = data default: - return nil, fmt.Errorf("unsupported input data type %T for %s", data, input.name) + return nil, fmt.Errorf("unsupported input type for %s at index %d: %T", inputName, i, param) } - req.Inputs[i] = tensor + inputTensor.Datatype = datatype + inputTensor.Shape = []int64{int64(batchSize), 1} + + req.Inputs[i] = inputTensor } return req, nil } +// parseRawOutput if output is provided in raw format. 
func parseRawOutput(rawData []byte, datatype string, batchSize int) (interface{}, error) { switch datatype { case "INT64": @@ -207,124 +255,6 @@ func parseRawOutput(rawData []byte, datatype string, batchSize int) (interface{} } } -func prepareInputs(indexToName map[int]string, params []interface{}) ([]preparedInput, error) { - if len(params) == 0 { - return nil, fmt.Errorf("no input parameters provided") - } - - var inputs []preparedInput - - for i, param := range params { - inputName, exists := indexToName[i] - if !exists { - return nil, fmt.Errorf("no input name found for index %d", i) - } - - switch v := param.(type) { - case [][]string: - if len(v) > 0 { - batchSize := len(v) - data := make([]string, batchSize) - for j := 0; j < batchSize; j++ { - if len(v[j]) > 0 { - data[j] = v[j][0] - } - } - inputs = append(inputs, preparedInput{ - name: inputName, - shape: []int64{int64(batchSize), 1}, - datatype: "BYTES", - data: data, - }) - } - case [][]int: - if len(v) > 0 { - batchSize := len(v) - data := make([]int32, batchSize) - for j := 0; j < batchSize; j++ { - if len(v[j]) > 0 { - data[j] = int32(v[j][0]) - } - } - inputs = append(inputs, preparedInput{ - name: inputName, - shape: []int64{int64(batchSize), 1}, - datatype: "INT32", - data: data, - }) - } - case [][]int32: - if len(v) > 0 { - batchSize := len(v) - data := make([]int32, batchSize) - for j := 0; j < batchSize; j++ { - if len(v[j]) > 0 { - data[j] = v[j][0] - } - } - inputs = append(inputs, preparedInput{ - name: inputName, - shape: []int64{int64(batchSize), 1}, - datatype: "INT32", - data: data, - }) - } - case [][]int64: - if len(v) > 0 { - batchSize := len(v) - data := make([]int64, batchSize) - for j := 0; j < batchSize; j++ { - if len(v[j]) > 0 { - data[j] = v[j][0] - } - } - inputs = append(inputs, preparedInput{ - name: inputName, - shape: []int64{int64(batchSize), 1}, - datatype: "INT64", - data: data, - }) - } - case [][]float32: - if len(v) > 0 { - batchSize := len(v) - data := 
make([]float32, batchSize) - for j := 0; j < batchSize; j++ { - if len(v[j]) > 0 { - data[j] = v[j][0] - } - } - inputs = append(inputs, preparedInput{ - name: inputName, - shape: []int64{int64(batchSize), 1}, - datatype: "FP32", - data: data, - }) - } - case [][]float64: - if len(v) > 0 { - batchSize := len(v) - data := make([]float64, batchSize) - for j := 0; j < batchSize; j++ { - if len(v[j]) > 0 { - data[j] = v[j][0] - } - } - inputs = append(inputs, preparedInput{ - name: inputName, - shape: []int64{int64(batchSize), 1}, - datatype: "FP64", - data: data, - }) - } - default: - return nil, fmt.Errorf("unsupported input type for %s at index %d: %T", inputName, i, param) - } - } - - return inputs, nil -} - func convertGRPCResponse(response *triton.ModelInferResponse) ([]interface{}, error) { if len(response.Outputs) == 0 { return nil, fmt.Errorf("no outputs in response") @@ -358,90 +288,40 @@ func convertGRPCResponse(response *triton.ModelInferResponse) ([]interface{}, er switch output.Datatype { case "FP32": - if len(output.Contents.Fp32Contents) == 0 { - converted := make([][]float32, batchSize) - for j := 0; j < batchSize; j++ { - converted[j] = []float32{0.0} - } - result[i] = converted - } else { - converted := make([][]float32, batchSize) - for j := 0; j < batchSize && j < len(output.Contents.Fp32Contents); j++ { - converted[j] = []float32{output.Contents.Fp32Contents[j]} - } - result[i] = converted + converted := make([][]float32, batchSize) + for j := 0; j < batchSize && j < len(output.Contents.Fp32Contents); j++ { + converted[j] = []float32{output.Contents.Fp32Contents[j]} } - + result[i] = converted case "FP64": - if len(output.Contents.Fp64Contents) == 0 { - converted := make([][]float64, batchSize) - for j := 0; j < batchSize; j++ { - converted[j] = []float64{0.0} - } - result[i] = converted - } else { - converted := make([][]float64, batchSize) - for j := 0; j < batchSize && j < len(output.Contents.Fp64Contents); j++ { - converted[j] = 
[]float64{output.Contents.Fp64Contents[j]} - } - result[i] = converted + converted := make([][]float64, batchSize) + for j := 0; j < batchSize && j < len(output.Contents.Fp64Contents); j++ { + converted[j] = []float64{output.Contents.Fp64Contents[j]} } + result[i] = converted case "INT32": - if len(output.Contents.IntContents) == 0 { - converted := make([][]int32, batchSize) - for j := 0; j < batchSize; j++ { - converted[j] = []int32{0} - } - result[i] = converted - } else { - converted := make([][]int32, batchSize) - for j := 0; j < batchSize && j < len(output.Contents.IntContents); j++ { - converted[j] = []int32{output.Contents.IntContents[j]} - } - result[i] = converted + converted := make([][]int32, batchSize) + for j := 0; j < batchSize && j < len(output.Contents.IntContents); j++ { + converted[j] = []int32{output.Contents.IntContents[j]} } - + result[i] = converted case "INT64": - if len(output.Contents.Int64Contents) == 0 { - converted := make([][]int64, batchSize) - for j := 0; j < batchSize; j++ { - converted[j] = []int64{0} - } - result[i] = converted - } else { - converted := make([][]int64, batchSize) - for j := 0; j < batchSize && j < len(output.Contents.Int64Contents); j++ { - converted[j] = []int64{output.Contents.Int64Contents[j]} - } - result[i] = converted + converted := make([][]int64, batchSize) + for j := 0; j < batchSize && j < len(output.Contents.Int64Contents); j++ { + converted[j] = []int64{output.Contents.Int64Contents[j]} } + result[i] = converted case "BYTES": - if len(output.Contents.BytesContents) == 0 { - converted := make([][]string, batchSize) - for j := 0; j < batchSize; j++ { - converted[j] = []string{""} - } - result[i] = converted - } else { - converted := make([][]string, batchSize) - for j := 0; j < batchSize && j < len(output.Contents.BytesContents); j++ { - converted[j] = []string{string(output.Contents.BytesContents[j])} - } - result[i] = converted + converted := make([][]string, batchSize) + for j := 0; j < batchSize && j 
< len(output.Contents.BytesContents); j++ { + converted[j] = []string{string(output.Contents.BytesContents[j])} } + result[i] = converted default: - if len(output.Contents.Fp32Contents) > 0 { - converted := make([][]float32, batchSize) - for j := 0; j < batchSize && j < len(output.Contents.Fp32Contents); j++ { - converted[j] = []float32{output.Contents.Fp32Contents[j]} - } - result[i] = converted - } else { - return nil, fmt.Errorf("unsupported output datatype %s for %s", output.Datatype, output.Name) - } + return nil, fmt.Errorf("unsupported output datatype %s for %s", output.Datatype, output.Name) } } From e692430c9fcf20fe31f15e9ebe88718b27bfcedf Mon Sep 17 00:00:00 2001 From: David Choi Date: Tue, 25 Nov 2025 17:18:55 -0800 Subject: [PATCH 04/50] Add model IO detection to Triton models. --- service/triton/client.go | 16 +- service/triton/{triton.go => evaluator.go} | 205 ++++++----- service/triton/evaluator_test.go | 206 +++++++++++ service/triton/grpc.go | 35 ++ service/triton/grpc_test.go | 400 +++++---------------- service/triton/http.go | 225 ++++++------ service/triton/metrics.go | 24 ++ service/triton/triton_test.go | 136 ------- service/triton/types.go | 23 ++ 9 files changed, 635 insertions(+), 635 deletions(-) rename service/triton/{triton.go => evaluator.go} (55%) create mode 100644 service/triton/evaluator_test.go delete mode 100644 service/triton/triton_test.go create mode 100644 service/triton/types.go diff --git a/service/triton/client.go b/service/triton/client.go index 133601b..ec93f3c 100644 --- a/service/triton/client.go +++ b/service/triton/client.go @@ -25,9 +25,24 @@ type TritonClient interface { ModelUnload(ctx context.Context, modelName string) error + ModelMetadata(ctx context.Context, modelName string) (*ModelMetadata, error) + Close() error } +// https://github.com/kserve/kserve/blob/master/docs/predict-api/v2/required_api.md#model-metadata-response-json-object `$metadata_tensor` +type MetadataTensor struct { + Name string `json:"name"` 
+ Datatype string `json:"datatype"` + Shape []int64 `json:"shape"` +} + +// stripped down version of https://github.com/kserve/kserve/blob/master/docs/predict-api/v2/required_api.md#model-metadata-response-json-object +type ModelMetadata struct { + Inputs []MetadataTensor `json:"inputs"` + Outputs []MetadataTensor `json:"outputs"` +} + // NewClient creates either an HTTP or gRPC client. func NewClient(server config.TritonServer) (TritonClient, error) { if server.GRPCBaseURL != "" { @@ -54,7 +69,6 @@ func NewClient(server config.TritonServer) (TritonClient, error) { } // HTTP options seem a bit bare - // TODO see if DRY with return NewMeteredTritonClient(&HTTPClient{ httpClient: &http.Client{ Timeout: time.Duration(server.HTTPClientTimeoutMs) * time.Millisecond, diff --git a/service/triton/triton.go b/service/triton/evaluator.go similarity index 55% rename from service/triton/triton.go rename to service/triton/evaluator.go index a4c8347..010fbd0 100644 --- a/service/triton/triton.go +++ b/service/triton/evaluator.go @@ -3,8 +3,8 @@ package triton import ( "context" "fmt" + "log" "net/http" - "reflect" "time" "github.com/viant/mly/service/config" @@ -25,9 +25,16 @@ type TritonEvaluator struct { timeout time.Duration - signature *domain.Signature + modelID string + debug bool + + signature *domain.Signature + + // maps feeds index to input name indexToName map[int]string + configuredInputs []*shared.Field + inputs map[string]*domain.Input } @@ -65,9 +72,10 @@ func NewTritonEvaluator(config *config.Model, tritonClients map[string]TritonCli repositoryExplicit: !isPrivateClient || config.Triton.RepositoryExplicit, } - if err := evaluator.handleIO(&config.MetaInput); err != nil { - return nil, err - } + evaluator.configuredInputs = config.MetaInput.Inputs + + evaluator.modelID = config.ID + evaluator.debug = config.Debug return evaluator, nil } @@ -84,6 +92,10 @@ func NewRoutedTritonEvaluator(modelName string, client TritonClient, timeoutMs i // Predict performs inference 
via Triton Inference Server func (t *TritonEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { + if len(params) == 0 { + return nil, fmt.Errorf("no input parameters") + } + requestCtx := ctx if _, hasDeadline := ctx.Deadline(); !hasDeadline { var cancel context.CancelFunc @@ -94,87 +106,12 @@ func (t *TritonEvaluator) Predict(ctx context.Context, params []interface{}) ([] return t.client.ModelInfer(requestCtx, t.modelName, params, t.indexToName) } -func (t *TritonEvaluator) handleIO(io *shared.MetaInput) error { - var inputs []domain.Input - var outputs []domain.Output - - indexToName := make(map[int]string) - - mappedInputs := make(map[string]*domain.Input) - - if len(io.Inputs) > 0 { - for _, input := range io.Inputs { - if !input.Auxiliary { - inputs = append(inputs, domain.Input{ - Name: input.Name, - Index: input.Index, - }) - - indexToName[input.Index] = input.Name - } - - inputType := reflect.TypeOf("") - if input.DataType != "" { - switch input.DataType { - case "string": - inputType = reflect.TypeOf("") - case "int": - inputType = reflect.TypeOf(0) - case "int32": - inputType = reflect.TypeOf(int32(0)) - case "int64": - inputType = reflect.TypeOf(int64(0)) - case "float32": - inputType = reflect.TypeOf(float32(0)) - case "float64": - inputType = reflect.TypeOf(float64(0)) - } - } - - mappedInputs[input.Name] = &domain.Input{ - Name: input.Name, - Index: input.Index, - Type: inputType, - Vocab: false, - Auxiliary: input.Auxiliary, - } - } - } else { - return fmt.Errorf("missing input configuration for Triton evaluator. " + - "Add 'inputs' section to your model configuration YAML with field definitions") - } - - if len(io.Outputs) > 0 { - for i, output := range io.Outputs { - outputs = append(outputs, domain.Output{ - Name: output.Name, - Index: i, - DataType: output.DataType, - }) - } - } else { - return fmt.Errorf("missing output configuration for Triton evaluator. 
" + - "Add 'outputs' section to your model configuration YAML with field definitions") - } - - t.indexToName = indexToName - - t.signature = &domain.Signature{ - Inputs: inputs, - Outputs: outputs, - Output: outputs[0], - } - - t.inputs = mappedInputs - - return nil -} - func (t *TritonEvaluator) Signature() *domain.Signature { return t.signature } func (t *TritonEvaluator) Dictionary() *common.Dictionary { + // no dictionary return nil } @@ -195,34 +132,116 @@ func (t *TritonEvaluator) Close() error { return nil } -// ReloadIfNeeded for independent Triton models, reloading is not supported. +// For independent Triton server models, reloading is not supported. func (t *TritonEvaluator) ReloadIfNeeded(ctx context.Context) error { ready, err := t.client.ModelReady(ctx, t.modelName) if err != nil { return fmt.Errorf("failed to check Triton model %s health: %w", t.modelName, err) } - if ready { + if ready && t.signature != nil { + // only a health check return nil } - if !t.repositoryExplicit { - return fmt.Errorf("model %s not ready and Triton is not in EXPLICIT Model Control Mode: %w", t.modelName, err) - } + if !ready { + if !t.repositoryExplicit { + return fmt.Errorf("model %s not ready and Triton is not in EXPLICIT Model Control Mode: %w", t.modelName, err) + } - err = t.client.ModelLoad(ctx, t.modelName) - if err != nil { - return fmt.Errorf("failed to load Triton model %s: %w", t.modelName, err) - } + err = t.client.ModelLoad(ctx, t.modelName) + if err != nil { + return fmt.Errorf("failed to load Triton model %s: %w", t.modelName, err) + } - ready, err = t.client.ModelReady(ctx, t.modelName) - if err != nil { - return fmt.Errorf("failed to check Triton model %s health after loading: %w", t.modelName, err) + ready, err = t.client.ModelReady(ctx, t.modelName) + if err != nil { + return fmt.Errorf("failed to check Triton model %s health after loading: %w", t.modelName, err) + } } if !ready { return fmt.Errorf("model %s is not ready after loading", t.modelName) } + 
// we need to get the model metadata and consolidate the signature + metadata, err := t.client.ModelMetadata(ctx, t.modelName) + if err != nil || metadata == nil { + return fmt.Errorf("failed to get Triton model %s metadata: %w", t.modelName, err) + } + + mappedInputs := make(map[string]*domain.Input) + indexedInputNames := make(map[int]string) + + signatureInputs := make([]domain.Input, len(metadata.Inputs)) + for i, input := range metadata.Inputs { + goType := TritonToGoType(input.Datatype) + di := domain.Input{ + Name: input.Name, + // for now, since the request provides a []interface{}, we populate the Index + Index: i, + Type: goType, + Vocab: false, + Auxiliary: false, + } + + if t.debug { + log.Printf("[%s] Triton[%s] input:%s index:%d datatype:%s goType:%s", + t.modelID, t.modelName, input.Name, di.Index, input.Datatype, goType.Name()) + } + + signatureInputs[i] = di + mappedInputs[input.Name] = &di + indexedInputNames[i] = input.Name + } + + t.indexToName = indexedInputNames + + for _, input := range t.configuredInputs { + iName := input.Name + if _, ok := mappedInputs[iName]; !ok { + goType, err := common.DataType(input.DataType) + if err != nil { + return fmt.Errorf("failed to get data type for %s: %w", iName, err) + } + + mappedInputs[iName] = &domain.Input{ + Name: iName, + Index: len(mappedInputs), + Type: goType, + Vocab: false, + Auxiliary: true, + } + + if t.debug { + log.Printf("[%s] Triton[%s] auxiliary input:%s goType:%s", + t.modelID, t.modelName, iName, goType.Name()) + } + } + } + + t.inputs = mappedInputs + + outputs := make([]domain.Output, len(metadata.Outputs)) + for i, output := range metadata.Outputs { + o := domain.Output{ + Name: output.Name, + Index: len(outputs), + } + + goType := TritonToGoType(output.Datatype) + o.SetType(goType) + o.DataType = goType.Name() + o.DataTypeKind = goType.Kind() + + outputs[i] = o + } + + t.signature = &domain.Signature{ + Inputs: signatureInputs, + Outputs: outputs, + Output: outputs[0], + } + return 
nil } diff --git a/service/triton/evaluator_test.go b/service/triton/evaluator_test.go new file mode 100644 index 0000000..eecc394 --- /dev/null +++ b/service/triton/evaluator_test.go @@ -0,0 +1,206 @@ +package triton + +import ( + "context" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/mly/service/config" + "github.com/viant/mly/shared" +) + +type mockTritonClient struct { + mu sync.Mutex + + readyState map[string]bool + + unloadCh chan string + modelLoadErr map[string]error + + metadata *ModelMetadata +} + +func (m *mockTritonClient) ServerReady(ctx context.Context) error { return nil } + +func (m *mockTritonClient) ModelInfer(ctx context.Context, modelName string, inputs []interface{}, indexToName map[int]string) ([]interface{}, error) { + return nil, nil +} + +func (m *mockTritonClient) ModelReady(ctx context.Context, modelName string) (bool, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if ready, ok := m.readyState[modelName]; ok { + return ready, nil + } + + return true, nil +} + +func (m *mockTritonClient) ModelLoad(ctx context.Context, modelName string) error { + m.mu.Lock() + defer m.mu.Unlock() + + if err := m.modelLoadErr[modelName]; err != nil { + return err + } + + if m.readyState == nil { + m.readyState = make(map[string]bool) + } + + m.readyState[modelName] = true + + return nil +} + +func (m *mockTritonClient) ModelUnload(ctx context.Context, modelName string) error { + ch := m.unloadCh + if ch != nil { + ch <- modelName + } + return nil +} + +func (m *mockTritonClient) ModelMetadata(ctx context.Context, modelName string) (*ModelMetadata, error) { + if m.metadata == nil { + m.metadata = &ModelMetadata{ + Inputs: []MetadataTensor{ + {Name: "input1", Datatype: "BYTES"}, + {Name: "input2", Datatype: "BYTES"}, + }, + Outputs: []MetadataTensor{ + {Name: "output1", Datatype: "FP32"}, + }, + } + } + + return m.metadata, nil +} + +func (m *mockTritonClient) Close() error { return 
nil } + +func newTritonEvaluator(cfg *config.Model, mockClient *mockTritonClient) *TritonEvaluator { + cfg.Triton.Init() + + evaluator := &TritonEvaluator{ + modelName: cfg.Triton.ModelName, + isPrivateClient: true, + repositoryExplicit: false, + client: mockClient, + configuredInputs: cfg.MetaInput.Inputs, + } + + evaluator.ReloadIfNeeded(context.Background()) + + return evaluator +} + +func TestTritonEvaluator_Signature(t *testing.T) { + cfg := &config.Model{ + ID: "test_model", + Triton: &config.TritonConfig{ + ModelName: "test_model", + }, + } + + evaluator := newTritonEvaluator(cfg, &mockTritonClient{}) + + defer evaluator.Close() + + sig := evaluator.Signature() + require.NotNil(t, sig) + assert.Equal(t, 2, len(sig.Inputs)) + assert.Equal(t, 1, len(sig.Outputs)) + assert.Equal(t, "input1", sig.Inputs[0].Name) + assert.Equal(t, "output1", sig.Outputs[0].Name) +} + +func TestTritonEvaluator_Inputs(t *testing.T) { + cfg := &config.Model{ + ID: "test_model", + Triton: &config.TritonConfig{ + ModelName: "test_model", + }, + MetaInput: shared.MetaInput{ + Inputs: []*shared.Field{ + {Name: "input1", Index: 0, DataType: "string"}, + {Name: "input2", Index: 0, DataType: "string"}, + {Name: "input_aux", Index: 0, DataType: "string", Auxiliary: true}, + }, + }, + } + + evaluator := newTritonEvaluator(cfg, &mockTritonClient{}) + + defer evaluator.Close() + + inputMap := evaluator.Inputs() + require.NotNil(t, inputMap) + assert.Equal(t, 3, len(inputMap)) + + i1 := inputMap["input1"] + assert.Equal(t, "input1", i1.Name) + assert.Equal(t, false, i1.Auxiliary) + assert.Equal(t, 0, i1.Index) + + i2 := inputMap["input2"] + assert.Equal(t, "input2", i2.Name) + assert.Equal(t, false, i2.Auxiliary) + assert.Equal(t, 1, i2.Index) + + iAux := inputMap["input_aux"] + assert.Equal(t, "input_aux", iAux.Name) + assert.Equal(t, true, iAux.Auxiliary) + assert.Equal(t, 2, iAux.Index) +} + +func TestTritonEvaluator_SignatureWithAuxiliaryInputs(t *testing.T) { + cfg := &config.Model{ + ID: 
"test_model", + Triton: &config.TritonConfig{ + ModelName: "test_model", + }, + MetaInput: shared.MetaInput{ + Inputs: []*shared.Field{ + {Name: "input_aux", Index: 0, DataType: "string", Auxiliary: true}, + }, + }, + } + + evaluator := newTritonEvaluator(cfg, &mockTritonClient{ + metadata: &ModelMetadata{ + Inputs: []MetadataTensor{ + {Name: "input1", Datatype: "BYTES"}, + }, + Outputs: []MetadataTensor{ + {Name: "output1", Datatype: "FP32"}, + }, + }, + }) + + defer evaluator.Close() + + sig := evaluator.Signature() + require.NotNil(t, sig) + assert.Equal(t, 1, len(sig.Inputs)) + assert.Equal(t, "input1", sig.Inputs[0].Name) +} + +func TestTritonEvaluator_PredictEmptyBatch(t *testing.T) { + cfg := &config.Model{ + ID: "test_model", + Triton: &config.TritonConfig{ + ModelName: "test_model", + }, + } + + evaluator := newTritonEvaluator(cfg, &mockTritonClient{}) + defer evaluator.Close() + + _, err := evaluator.Predict(context.Background(), []interface{}{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no input parameters") +} diff --git a/service/triton/grpc.go b/service/triton/grpc.go index e601274..0b92cdf 100644 --- a/service/triton/grpc.go +++ b/service/triton/grpc.go @@ -81,6 +81,41 @@ func (c *GRPCClient) ModelUnload(ctx context.Context, modelName string) error { return nil } +func (c *GRPCClient) ModelMetadata(ctx context.Context, modelName string) (*ModelMetadata, error) { + grpcResponse, err := c.grpcClient.ModelMetadata(ctx, &triton.ModelMetadataRequest{ + Name: modelName, + }) + if err != nil { + return nil, err + } + return convertGRPCModelMetadataResponse(grpcResponse), nil +} + +func convertGRPCModelMetadataResponse(response *triton.ModelMetadataResponse) *ModelMetadata { + inputs := make([]MetadataTensor, len(response.Inputs)) + for i, input := range response.Inputs { + inputs[i] = MetadataTensor{ + Name: input.Name, + Datatype: input.Datatype, + Shape: input.Shape, + } + } + + outputs := make([]MetadataTensor, len(response.Outputs)) + 
for i, output := range response.Outputs { + outputs[i] = MetadataTensor{ + Name: output.Name, + Datatype: output.Datatype, + Shape: output.Shape, + } + } + + return &ModelMetadata{ + Inputs: inputs, + Outputs: outputs, + } +} + func (c *GRPCClient) Close() error { return c.grpcConn.Close() } diff --git a/service/triton/grpc_test.go b/service/triton/grpc_test.go index 57e9c0c..836defc 100644 --- a/service/triton/grpc_test.go +++ b/service/triton/grpc_test.go @@ -9,8 +9,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" triton "github.com/viant/mly/proto/triton" - "github.com/viant/mly/service/config" - "github.com/viant/mly/shared" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" @@ -28,73 +26,90 @@ func createMockTritonConn(ctx context.Context, t *testing.T, listener *bufconn.L return conn } -func createMockGRPCTritonEvaluator(t *testing.T, cfg *config.Model, grpcConn *grpc.ClientConn) *TritonEvaluator { - grpcClient := triton.NewGRPCInferenceServiceClient(grpcConn) - cfg.Triton.Init() - evaluator, err := NewTritonEvaluator(cfg, map[string]TritonClient{ - cfg.Triton.ServerID: &GRPCClient{ - grpcConn: grpcConn, - grpcClient: grpcClient, - }, - }) +// mockTritonServer implements triton.GRPCInferenceServiceServer for testing +type mockTritonServer struct { + triton.UnimplementedGRPCInferenceServiceServer - require.NoError(t, err) - return evaluator + modelReady bool + responses map[string]*triton.ModelInferResponse } -func TestTritonEvaluator_ReloadAndSupportsReload(t *testing.T) { - ctx := context.Background() +func (m *mockTritonServer) RepositoryModelLoad(ctx context.Context, req *triton.RepositoryModelLoadRequest) (*triton.RepositoryModelLoadResponse, error) { + return &triton.RepositoryModelLoadResponse{}, nil +} - // Set up mock server - mock := &mockTritonServer{ - modelReady: true, - responses: map[string]*triton.ModelInferResponse{ - "test_model": { - ModelName: 
"test_model", - Outputs: []*triton.ModelInferResponse_InferOutputTensor{ - { - Name: "output", - Datatype: "INT64", - Shape: []int64{2, 1}, - Contents: &triton.InferTensorContents{ - Int64Contents: []int64{42, 100}, - }, - }, +func (m *mockTritonServer) ModelReady(ctx context.Context, req *triton.ModelReadyRequest) (*triton.ModelReadyResponse, error) { + return &triton.ModelReadyResponse{Ready: m.modelReady}, nil +} + +func (m *mockTritonServer) ModelInfer(ctx context.Context, req *triton.ModelInferRequest) (*triton.ModelInferResponse, error) { + if resp, ok := m.responses[req.ModelName]; ok { + return resp, nil + } + + // Return a default response + return &triton.ModelInferResponse{ + ModelName: req.ModelName, + Outputs: []*triton.ModelInferResponse_InferOutputTensor{ + { + Name: "output", + Datatype: "FP32", + Shape: []int64{1, 1}, + Contents: &triton.InferTensorContents{ + Fp32Contents: []float32{0.5}, }, }, }, - } + }, nil +} +// startMockGRPCServer starts an in-memory gRPC server for testing +func startMockGRPCServer(t *testing.T, mock *mockTritonServer) (*grpc.Server, *bufconn.Listener) { + buffer := 1024 * 1024 + listener := bufconn.Listen(buffer) + + server := grpc.NewServer() + triton.RegisterGRPCInferenceServiceServer(server, mock) + + go func() { + if err := server.Serve(listener); err != nil { + t.Fatalf("Server exited with error: %v", err) + } + }() + + return server, listener +} + +func createClient(t *testing.T, ctx context.Context, mock *mockTritonServer) (func(), *GRPCClient) { server, listener := startMockGRPCServer(t, mock) - defer server.Stop() - - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output1", Index: 0, DataType: "float32"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } grpcConn := createMockTritonConn(ctx, t, 
listener) - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() + client := &GRPCClient{ + grpcConn: grpcConn, + grpcClient: triton.NewGRPCInferenceServiceClient(grpcConn), + } - // ReloadIfNeeded should be a no-op - err := evaluator.ReloadIfNeeded(context.Background()) - assert.NoError(t, err) + return func() { + server.Stop() + }, client } -func TestTritonEvaluator_PredictWithMockServer(t *testing.T) { +func TestGRPCClient_ModelLoad(t *testing.T) { + ctx := context.Background() + + // Set up mock server + mock := &mockTritonServer{ + modelReady: true, + } + + stopper, client := createClient(t, ctx, mock) + defer stopper() + + err := client.ModelLoad(ctx, "test_model") + require.NoError(t, err) +} + +func TestGRPCClient_ModelInfer(t *testing.T) { ctx := context.Background() // Set up mock server @@ -117,36 +132,15 @@ func TestTritonEvaluator_PredictWithMockServer(t *testing.T) { }, } - server, listener := startMockGRPCServer(t, mock) - defer server.Stop() - - // Create evaluator with mock connection - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "int64"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - grpcConn := createMockTritonConn(ctx, t, listener) - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() + stopper, client := createClient(t, ctx, mock) + defer stopper() // Test prediction params := []interface{}{ [][]string{{"value1"}, {"value2"}}, // 2 batch items } - results, err := evaluator.Predict(ctx, params) + results, err := client.ModelInfer(ctx, "test_model", params, map[int]string{0: "input1"}) require.NoError(t, err) require.Len(t, results, 1) @@ -158,7 +152,7 @@ func TestTritonEvaluator_PredictWithMockServer(t *testing.T) { 
assert.Equal(t, []int64{100}, output[1]) } -func TestTritonEvaluator_PredictWithRawOutputContents(t *testing.T) { +func TestGRPCClient_ModelInferWithRawOutputContents(t *testing.T) { ctx := context.Background() // Set up mock server with raw output contents @@ -181,35 +175,14 @@ func TestTritonEvaluator_PredictWithRawOutputContents(t *testing.T) { }, } - server, listener := startMockGRPCServer(t, mock) - defer server.Stop() - - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "float32"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - - grpcConn := createMockTritonConn(ctx, t, listener) - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() + stopper, client := createClient(t, ctx, mock) + defer stopper() params := []interface{}{ [][]string{{"value1"}, {"value2"}}, } - results, err := evaluator.Predict(ctx, params) + results, err := client.ModelInfer(ctx, "test_model", params, map[int]string{0: "input1"}) require.NoError(t, err) require.Len(t, results, 1) @@ -220,7 +193,7 @@ func TestTritonEvaluator_PredictWithRawOutputContents(t *testing.T) { assert.InDelta(t, 50.0, output[1][0], 0.01) } -func TestTritonEvaluator_PredictAllInputTypes(t *testing.T) { +func TestGRPCClient_ModelInferAllInputTypes(t *testing.T) { ctx := context.Background() testCases := []struct { @@ -315,31 +288,10 @@ func TestTritonEvaluator_PredictAllInputTypes(t *testing.T) { responses: map[string]*triton.ModelInferResponse{"test_model": tc.expectedResp}, } - server, listener := startMockGRPCServer(t, mock) - grpcConn := createMockTritonConn(ctx, t, listener) - defer server.Stop() + stopper, client := createClient(t, ctx, mock) + defer stopper() - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - 
MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: tc.inputType}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: tc.inputType}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() - - results, err := evaluator.Predict(ctx, []interface{}{tc.inputData}) + results, err := client.ModelInfer(ctx, "test_model", []interface{}{tc.inputData}, map[int]string{0: "input1"}) require.NoError(t, err) require.Len(t, results, 1) assert.NotNil(t, results[0]) @@ -347,55 +299,7 @@ func TestTritonEvaluator_PredictAllInputTypes(t *testing.T) { } } -// mockTritonServer implements triton.GRPCInferenceServiceServer for testing -type mockTritonServer struct { - triton.UnimplementedGRPCInferenceServiceServer - modelReady bool - responses map[string]*triton.ModelInferResponse -} - -func (m *mockTritonServer) ModelReady(ctx context.Context, req *triton.ModelReadyRequest) (*triton.ModelReadyResponse, error) { - return &triton.ModelReadyResponse{Ready: m.modelReady}, nil -} - -func (m *mockTritonServer) ModelInfer(ctx context.Context, req *triton.ModelInferRequest) (*triton.ModelInferResponse, error) { - if resp, ok := m.responses[req.ModelName]; ok { - return resp, nil - } - // Return a default response - return &triton.ModelInferResponse{ - ModelName: req.ModelName, - Outputs: []*triton.ModelInferResponse_InferOutputTensor{ - { - Name: "output", - Datatype: "FP32", - Shape: []int64{1, 1}, - Contents: &triton.InferTensorContents{ - Fp32Contents: []float32{0.5}, - }, - }, - }, - }, nil -} - -// startMockGRPCServer starts an in-memory gRPC server for testing -func startMockGRPCServer(t *testing.T, mock *mockTritonServer) (*grpc.Server, *bufconn.Listener) { - buffer := 1024 * 1024 - listener := bufconn.Listen(buffer) - - server := grpc.NewServer() - 
triton.RegisterGRPCInferenceServiceServer(server, mock) - - go func() { - if err := server.Serve(listener); err != nil { - t.Logf("Server exited with error: %v", err) - } - }() - - return server, listener -} - -func TestTritonEvaluator_PredictBytesOutput(t *testing.T) { +func TestGRPCClient_ModelInferBytesOutput(t *testing.T) { ctx := context.Background() mock := &mockTritonServer{ @@ -417,35 +321,14 @@ func TestTritonEvaluator_PredictBytesOutput(t *testing.T) { }, } - server, listener := startMockGRPCServer(t, mock) - defer server.Stop() - - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "string"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - - grpcConn := createMockTritonConn(ctx, t, listener) - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() + stopper, client := createClient(t, ctx, mock) + defer stopper() params := []interface{}{ [][]string{{"input1"}, {"input2"}}, } - results, err := evaluator.Predict(ctx, params) + results, err := client.ModelInfer(ctx, "test_model", params, map[int]string{0: "input1"}) require.NoError(t, err) require.Len(t, results, 1) @@ -456,7 +339,7 @@ func TestTritonEvaluator_PredictBytesOutput(t *testing.T) { assert.Equal(t, []string{"result2"}, output[1]) } -func TestTritonEvaluator_PredictDifferentBatchSizes(t *testing.T) { +func TestGRPCClient_ModelInferDifferentBatchSizes(t *testing.T) { ctx := context.Background() testCases := []struct { @@ -496,29 +379,8 @@ func TestTritonEvaluator_PredictDifferentBatchSizes(t *testing.T) { }, } - server, listener := startMockGRPCServer(t, mock) - defer server.Stop() - - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: 
"input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "int64"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - - grpcConn := createMockTritonConn(ctx, t, listener) - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() + stopper, client := createClient(t, ctx, mock) + defer stopper() // Generate input batch inputBatch := make([][]string, tc.batchSize) @@ -526,7 +388,7 @@ func TestTritonEvaluator_PredictDifferentBatchSizes(t *testing.T) { inputBatch[i] = []string{fmt.Sprintf("input_%d", i)} } - results, err := evaluator.Predict(ctx, []interface{}{inputBatch}) + results, err := client.ModelInfer(ctx, "test_model", []interface{}{inputBatch}, map[int]string{0: "input1"}) require.NoError(t, err) require.Len(t, results, 1) @@ -537,7 +399,7 @@ func TestTritonEvaluator_PredictDifferentBatchSizes(t *testing.T) { } } -func TestTritonEvaluator_PredictUnsupportedType(t *testing.T) { +func TestGRPCClient_ModelInferUnsupportedType(t *testing.T) { ctx := context.Background() mock := &mockTritonServer{ @@ -560,66 +422,19 @@ func TestTritonEvaluator_PredictUnsupportedType(t *testing.T) { }, } - server, listener := startMockGRPCServer(t, mock) - defer server.Stop() - - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "string"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - - grpcConn := createMockTritonConn(ctx, t, listener) - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() + stopper, client := createClient(t, ctx, mock) + defer stopper() params := []interface{}{ [][]string{{"test"}}, } - _, err := evaluator.Predict(ctx, params) + _, err := 
client.ModelInfer(ctx, "test_model", params, map[int]string{0: "input1"}) require.Error(t, err) assert.Contains(t, err.Error(), "unsupported") } -func TestTritonEvaluator_PredictEmptyBatch(t *testing.T) { - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "int64"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - - evaluator := createMockGRPCTritonEvaluator(t, cfg, nil) - defer evaluator.Close() - - _, err := evaluator.Predict(context.Background(), []interface{}{}) - require.Error(t, err) - assert.Contains(t, err.Error(), "no input parameters") -} - -func TestTritonEvaluator_PredictMissingOutput(t *testing.T) { +func TestGRPCClient_ModelInferMissingOutput(t *testing.T) { ctx := context.Background() mock := &mockTritonServer{ @@ -639,35 +454,14 @@ func TestTritonEvaluator_PredictMissingOutput(t *testing.T) { }, } - server, listener := startMockGRPCServer(t, mock) - defer server.Stop() - - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "int64"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - - grpcConn := createMockTritonConn(ctx, t, listener) - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() + stopper, client := createClient(t, ctx, mock) + defer stopper() params := []interface{}{ [][]string{{"test"}}, } - _, err := evaluator.Predict(ctx, params) + _, err := client.ModelInfer(ctx, "test_model", params, map[int]string{0: "input1"}) require.Error(t, err) assert.Contains(t, err.Error(), "missing contents") } diff --git 
a/service/triton/http.go b/service/triton/http.go index 4195a81..6de7227 100644 --- a/service/triton/http.go +++ b/service/triton/http.go @@ -22,30 +22,6 @@ type HTTPClient struct { debug bool } -func (c *HTTPClient) sendRequestCheckStatus(ctx context.Context, method, path string) (*http.Response, error) { - if c.debug { - log.Printf("Sending request %s %s\n", method, path) - } - - httpReq, err := http.NewRequestWithContext(ctx, method, c.serverURL+path, nil) - if err != nil { - return nil, fmt.Errorf("failed to create HTTP request: %w", err) - } - - resp, err := c.handleRequestWithRetry(ctx, httpReq, nil) - - if err != nil { - return nil, err - } - - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return resp, fmt.Errorf("triton server http status code: %d for %s %s", resp.StatusCode, method, path) - } - - return resp, nil -} - func (c *HTTPClient) ServerReady(ctx context.Context) error { path := "/v2/health/ready" _, err := c.sendRequestCheckStatus(ctx, "GET", path) @@ -58,7 +34,7 @@ func (c *HTTPClient) ModelInfer(ctx context.Context, modelName string, inputs [] return nil, err } - tritonResponse, err := c.sendRequest(ctx, modelName, tritonRequest) + tritonResponse, err := c.sendInferRequest(ctx, modelName, tritonRequest) if err != nil { return nil, err } @@ -96,94 +72,61 @@ func (c *HTTPClient) ModelUnload(ctx context.Context, modelName string) error { return err } -func (c *HTTPClient) Close() error { - return nil -} +func (c *HTTPClient) ModelMetadata(ctx context.Context, modelName string) (*ModelMetadata, error) { + url := c.serverURL + "/v2/models/" + modelName -// TritonInput represents a single input tensor for Triton -type TritonInput struct { - Name string `json:"name"` - Shape []int `json:"shape"` - DataType string `json:"datatype"` - Data interface{} `json:"data"` -} - -type TritonOutput struct { - Name string `json:"name"` - Data interface{} `json:"data"` -} + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != 
nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } -// TritonRequest represents the HTTP request format to Triton -type TritonRequest struct { - Inputs []TritonInput `json:"inputs"` -} + resp, err := c.handleRequestWithRetry(ctx, req, nil) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } -// TritonResponse represents the HTTP response format from Triton -type TritonResponse struct { - Outputs []TritonOutput `json:"outputs"` -} + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("triton server http status code: %d for %s", resp.StatusCode, url) + } -func (t *TritonInput) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("name", t.Name) - enc.ArrayKey("shape", gojay.EncodeArrayFunc(func(enc *gojay.Encoder) { - for _, v := range t.Shape { - enc.AddInt(v) - } - })) - enc.StringKey("datatype", t.DataType) + modelMetadata := new(ModelMetadata) + if err := json.NewDecoder(resp.Body).Decode(modelMetadata); err != nil { + return nil, fmt.Errorf("failed to parse Triton response: %w", err) + } - enc.ArrayKey("data", gojay.EncodeArrayFunc(func(enc *gojay.Encoder) { - switch data := t.Data.(type) { - case []string: - for _, v := range data { - enc.AddString(v) - } - case []int: - for _, v := range data { - enc.AddInt(v) - } - case []float32: - for _, v := range data { - enc.AddFloat32(v) - } - case []float64: - for _, v := range data { - enc.AddFloat64(v) - } - default: - for i := 0; i < reflect.ValueOf(data).Len(); i++ { - val := reflect.ValueOf(data).Index(i).Interface() - enc.AddInterface(val) - } - } - })) + return modelMetadata, nil } -func (t *TritonInput) IsNil() bool { - return t == nil +func (c *HTTPClient) Close() error { + return nil } -func (t *TritonRequest) MarshalJSONObject(enc *gojay.Encoder) { - enc.ArrayKey("inputs", (*TritonInputs)(&t.Inputs)) -} +func (c *HTTPClient) sendRequestCheckStatus(ctx context.Context, method, path string) (*http.Response, error) { 
+ if c.debug { + log.Printf("Sending request %s %s\n", method, path) + } -func (t *TritonRequest) IsNil() bool { - return t == nil -} + httpReq, err := http.NewRequestWithContext(ctx, method, c.serverURL+path, nil) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } -type TritonInputs []TritonInput + resp, err := c.handleRequestWithRetry(ctx, httpReq, nil) -func (t *TritonInputs) MarshalJSONArray(enc *gojay.Encoder) { - for i := range *t { - enc.AddObject(&(*t)[i]) + if err != nil { + return nil, err } -} -func (t *TritonInputs) IsNil() bool { - return t == nil || len(*t) == 0 + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return resp, fmt.Errorf("triton server http status code: %d for %s %s", resp.StatusCode, method, path) + } + + return resp, nil } -func (c *HTTPClient) sendRequest(ctx context.Context, modelName string, request *TritonRequest) (*TritonResponse, error) { +func (c *HTTPClient) sendInferRequest(ctx context.Context, modelName string, request *TritonRequest) (*TritonResponse, error) { url := c.serverURL + "/v2/models/" + modelName + "/infer" buf := bytes.NewBuffer(make([]byte, 0, 1024)) @@ -193,7 +136,6 @@ func (c *HTTPClient) sendRequest(ctx context.Context, modelName string, request } jsonData := buf.Bytes() - // Create HTTP request httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) if err != nil { return nil, fmt.Errorf("failed to create HTTP request: %w", err) @@ -256,12 +198,91 @@ func (c *HTTPClient) handleRequestWithRetry(ctx context.Context, httpReq *http.R return resp, nil } -// convertToTritonRequest converts Feeds ([numInputs]([batchSize][1]T)) to Triton request format -func convertToTritonRequest(params []interface{}, indexToName map[int]string) (*TritonRequest, error) { - if len(params) == 0 { - return nil, fmt.Errorf("no input parameters provided") +// TritonInput represents a single input tensor for Triton +type TritonInput struct { + Name 
string `json:"name"` + Shape []int `json:"shape"` + DataType string `json:"datatype"` + Data interface{} `json:"data"` +} + +type TritonOutput struct { + Name string `json:"name"` + Data interface{} `json:"data"` +} + +// TritonRequest represents the HTTP request format to Triton +type TritonRequest struct { + Inputs []TritonInput `json:"inputs"` +} + +// TritonResponse represents the HTTP response format from Triton +type TritonResponse struct { + Outputs []TritonOutput `json:"outputs"` +} + +func (t *TritonInput) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("name", t.Name) + enc.ArrayKey("shape", gojay.EncodeArrayFunc(func(enc *gojay.Encoder) { + for _, v := range t.Shape { + enc.AddInt(v) + } + })) + enc.StringKey("datatype", t.DataType) + + enc.ArrayKey("data", gojay.EncodeArrayFunc(func(enc *gojay.Encoder) { + switch data := t.Data.(type) { + case []string: + for _, v := range data { + enc.AddString(v) + } + case []int: + for _, v := range data { + enc.AddInt(v) + } + case []float32: + for _, v := range data { + enc.AddFloat32(v) + } + case []float64: + for _, v := range data { + enc.AddFloat64(v) + } + default: + for i := 0; i < reflect.ValueOf(data).Len(); i++ { + val := reflect.ValueOf(data).Index(i).Interface() + enc.AddInterface(val) + } + } + })) +} + +func (t *TritonInput) IsNil() bool { + return t == nil +} + +func (t *TritonRequest) MarshalJSONObject(enc *gojay.Encoder) { + enc.ArrayKey("inputs", (*TritonInputs)(&t.Inputs)) +} + +func (t *TritonRequest) IsNil() bool { + return t == nil +} + +type TritonInputs []TritonInput + +func (t *TritonInputs) MarshalJSONArray(enc *gojay.Encoder) { + for i := range *t { + enc.AddObject(&(*t)[i]) } +} +func (t *TritonInputs) IsNil() bool { + return t == nil || len(*t) == 0 +} + +// convertToTritonRequest converts Feeds ([numInputs]([batchSize][1]T)) to Triton request format +func convertToTritonRequest(params []interface{}, indexToName map[int]string) (*TritonRequest, error) { var inputs []TritonInput 
// Convert each parameter to Triton input format diff --git a/service/triton/metrics.go b/service/triton/metrics.go index 73cfe49..22043a0 100644 --- a/service/triton/metrics.go +++ b/service/triton/metrics.go @@ -76,6 +76,16 @@ var ( []string{"model"}, ) + modelMetadataDurationMicrosSummary = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "triton", + Name: "model_metadata_duration_summary_us", + Help: "Duration of Triton ModelMetadata RPCs, labeled by model name, successful only.", + Objectives: buckets.CommonSummaryObjectives, + }, + []string{"model"}, + ) ) func init() { @@ -88,6 +98,7 @@ func init() { prometheus.MustRegister(modelReadyDurationMicrosSummary) prometheus.MustRegister(modelLoadDurationMicrosSummary) prometheus.MustRegister(modelUnloadDurationMicrosSummary) + prometheus.MustRegister(modelMetadataDurationMicrosSummary) } type MeteredTritonClient struct { @@ -169,6 +180,19 @@ func (c *MeteredTritonClient) ModelUnload(ctx context.Context, modelName string) return err } +func (c *MeteredTritonClient) ModelMetadata(ctx context.Context, modelName string) (*ModelMetadata, error) { + var metadata *ModelMetadata + var err error + err = withGatherers(func() error { + metadata, err = c.client.ModelMetadata(ctx, modelName) + return err + }, func(duration float64) { + modelMetadataDurationMicrosSummary.WithLabelValues(modelName).Observe(duration) + }) + + return metadata, err +} + func (c *MeteredTritonClient) Close() error { return c.client.Close() } diff --git a/service/triton/triton_test.go b/service/triton/triton_test.go deleted file mode 100644 index cd16312..0000000 --- a/service/triton/triton_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package triton - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/viant/mly/service/config" - "github.com/viant/mly/shared" -) - -func newTritonEvaluator(t *testing.T, cfg *config.Model) *TritonEvaluator { - 
cfg.Triton.Init() - evaluator, err := NewTritonEvaluator(cfg, nil) - require.NoError(t, err) - return evaluator -} - -func TestTritonEvaluator_Signature(t *testing.T) { - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - URL: "http://localhost:8000", - - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - {Name: "input2", Index: 1, DataType: "int64"}, - }, - Outputs: []*shared.Field{ - {Name: "output1", Index: 0, DataType: "float32"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - }, - } - - evaluator := newTritonEvaluator(t, cfg) - defer evaluator.Close() - - sig := evaluator.Signature() - require.NotNil(t, sig) - assert.Equal(t, 2, len(sig.Inputs)) - assert.Equal(t, 1, len(sig.Outputs)) - assert.Equal(t, "input1", sig.Inputs[0].Name) - assert.Equal(t, "output1", sig.Outputs[0].Name) -} - -func TestTritonEvaluator_SignatureWithAuxiliaryInputs(t *testing.T) { - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - URL: "http://localhost:8000", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - {Name: "auxiliary_input", Index: 1, DataType: "int64", Auxiliary: true}, - }, - Outputs: []*shared.Field{ - {Name: "output1", Index: 0, DataType: "float32"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - }, - } - - evaluator := newTritonEvaluator(t, cfg) - defer evaluator.Close() - - sig := evaluator.Signature() - require.NotNil(t, sig) - // Auxiliary inputs should be excluded from signature - assert.Equal(t, 1, len(sig.Inputs)) - assert.Equal(t, "input1", sig.Inputs[0].Name) -} - -func TestTritonEvaluator_Dictionary(t *testing.T) { - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - URL: "http://localhost:8000", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output1", 
Index: 0, DataType: "float32"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - }, - } - - evaluator := newTritonEvaluator(t, cfg) - defer evaluator.Close() - - dict := evaluator.Dictionary() - assert.Nil(t, dict) -} - -func TestTritonEvaluator_InputsMapping(t *testing.T) { - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - URL: "http://localhost:8000", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "string_input", Index: 0, DataType: "string"}, - {Name: "int64_input", Index: 1, DataType: "int64"}, - {Name: "int32_input", Index: 2, DataType: "int32"}, - {Name: "float32_input", Index: 3, DataType: "float32"}, - {Name: "float64_input", Index: 4, DataType: "float64"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "float32"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - }, - } - - evaluator := newTritonEvaluator(t, cfg) - defer evaluator.Close() - - inputs := evaluator.Inputs() - assert.Len(t, inputs, 5) - assert.Contains(t, inputs, "string_input") - assert.Contains(t, inputs, "int64_input") - assert.Contains(t, inputs, "int32_input") - assert.Contains(t, inputs, "float32_input") - assert.Contains(t, inputs, "float64_input") -} diff --git a/service/triton/types.go b/service/triton/types.go new file mode 100644 index 0000000..d94901a --- /dev/null +++ b/service/triton/types.go @@ -0,0 +1,23 @@ +package triton + +import ( + "fmt" + "reflect" +) + +func TritonToGoType(datatype string) reflect.Type { + switch datatype { + case "INT64": + return reflect.TypeOf(int64(0)) + case "INT32": + return reflect.TypeOf(int32(0)) + case "FP32": + return reflect.TypeOf(float32(0)) + case "FP64": + return reflect.TypeOf(float64(0)) + case "BYTES": + return reflect.TypeOf("") + default: + panic(fmt.Sprintf("unsupported Triton datatype: %s", datatype)) + } +} From 29169d7aec13fff2057861281487c1528639bbcb Mon Sep 17 00:00:00 2001 From: David Choi Date: Tue, 25 Nov 2025 17:27:05 
-0800 Subject: [PATCH 05/50] Undo set_sdk change. --- example/e2e/deps.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/example/e2e/deps.yaml b/example/e2e/deps.yaml index d17e368..d29b893 100644 --- a/example/e2e/deps.yaml +++ b/example/e2e/deps.yaml @@ -3,10 +3,10 @@ init: pipeline: deploy: - # set_sdk: - # action: sdk.set - # target: $target - # sdk: go:${goVersion} + set_sdk: + action: sdk.set + target: $target + sdk: go:${goVersion} install_dependencies: action: exec:run From e0335fd3e08ce5b69dbf742f23ee14c1967aaa6d Mon Sep 17 00:00:00 2001 From: David Choi Date: Tue, 25 Nov 2025 17:33:56 -0800 Subject: [PATCH 06/50] Fix test to conform with metadata interface. --- service/platform/router/router_test.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/service/platform/router/router_test.go b/service/platform/router/router_test.go index a5baa16..4495bca 100644 --- a/service/platform/router/router_test.go +++ b/service/platform/router/router_test.go @@ -92,6 +92,16 @@ func (m *mockTritonClient) ModelUnload(ctx context.Context, modelName string) er } return nil } +func (m *mockTritonClient) ModelMetadata(ctx context.Context, modelName string) (*tricli.ModelMetadata, error) { + return &tricli.ModelMetadata{ + Inputs: []tricli.MetadataTensor{ + {Name: "input1", Datatype: "INT32"}, + }, + Outputs: []tricli.MetadataTensor{ + {Name: "output1", Datatype: "FP32"}, + }, + }, nil +} func (m *mockTritonClient) Close() error { return nil } func (m *mockTritonClient) snapshotLoadCalls() []string { From f27c1a7ebb35646441768efaa268cf1cebe3e9dc Mon Sep 17 00:00:00 2001 From: David Choi Date: Tue, 25 Nov 2025 17:34:42 -0800 Subject: [PATCH 07/50] Fix bug if there is a load error in Router. 
--- service/platform/router/router.go | 27 +++++++++++++++++++++++++- service/platform/router/router_test.go | 1 + service/platform/router/worker.go | 1 + 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/service/platform/router/router.go b/service/platform/router/router.go index f4d1fc6..a4fc2c5 100644 --- a/service/platform/router/router.go +++ b/service/platform/router/router.go @@ -50,6 +50,8 @@ type Router struct { indexToName map[int]string inputs map[string]*domain.Input + debug bool + // router input offset is the index of the router input in the inputs array routerInputOffset int } @@ -75,6 +77,7 @@ func NewRouter(cfg *config.Model, fs afs.Service, tritonClients map[string]tricl routerName: cfg.ID, modelConfig: cfg, tritonClient: tritonClient, + debug: cfg.Debug, } if err := r.handleIO(cfg); err != nil { @@ -557,12 +560,20 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router newModelMapping := make(map[int]string) for _, entity := range newConfig.EntityMapping { + if r.debug { + log.Printf("router: add mapping: %d -> %s", entity.EntityID, entity.ModelName) + } + newModelMapping[entity.EntityID] = entity.ModelName delete(modelsToUnload, entity.ModelName) } globalModelName := newConfig.GlobalModelName if globalModelName != "" { + if r.debug { + log.Printf("router: global model: %s", globalModelName) + } + delete(modelsToUnload, globalModelName) } @@ -616,7 +627,7 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router } wg := sync.WaitGroup{} - errCh := make(chan error, 1) + errCh := make(chan error, len(newRoutingTable)+1) if globalEvaluator != nil { wg.Add(1) go func() { @@ -631,11 +642,25 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router wg.Add(1) go func(model string) { defer wg.Done() + + if r.debug { + log.Printf("router: reload model: %s", model) + } + if err := newRoutingTable[model].ReloadIfNeeded(ctx); err != nil { + if r.debug { + 
log.Printf("router: failed to reload model: %s: %v", model, err) + } + errCh <- fmt.Errorf("failed to reload model %s: %w", model, err) } }(model) } + + if r.debug { + log.Printf("router: wait for reloads") + } + wg.Wait() close(errCh) diff --git a/service/platform/router/router_test.go b/service/platform/router/router_test.go index 4495bca..d24da28 100644 --- a/service/platform/router/router_test.go +++ b/service/platform/router/router_test.go @@ -338,6 +338,7 @@ func TestRouter_applyRouterConfig_LoadsAndSwaps(t *testing.T) { "modelB": &mockPredictOnly{}, }, globalModel: &mockPredictOnly{}, + debug: true, } reusedModelB := router.routingTable["modelB"] diff --git a/service/platform/router/worker.go b/service/platform/router/worker.go index 4acc056..09bc073 100644 --- a/service/platform/router/worker.go +++ b/service/platform/router/worker.go @@ -50,6 +50,7 @@ func handleWorkRequests(workCh chan *workRequest, observer prometheus.Observer) } if request.modelOutputEnabled { + // TODO fix ordering results = append(results, [][]string{{request.routingValueString}}) } From 96ffff98761085731fabc132f21d97871f1f94f1 Mon Sep 17 00:00:00 2001 From: David Choi Date: Tue, 25 Nov 2025 17:49:50 -0800 Subject: [PATCH 08/50] Fix test to correctly detect auxiliary handling bug. --- service/request/request.go | 43 +++++++++++++++++++++------------ service/request/request_test.go | 29 +++++++++++++--------- 2 files changed, 44 insertions(+), 28 deletions(-) diff --git a/service/request/request.go b/service/request/request.go index 92182f4..02122c5 100644 --- a/service/request/request.go +++ b/service/request/request.go @@ -21,6 +21,8 @@ type Request struct { // Passed through to Evaluator. // This is expected to be [numInputs]([batchSize][1]T). + // TODO: This shape is fixed, and should be addressed. + // Also, the fact that it is a []interface{} is a TensorFlow concern; ideally it should be map[string]interface{}. 
Feeds []interface{} supplied map[string]struct{} // used to check if the required inputs were provided @@ -57,45 +59,47 @@ func (r *Request) Put(key string, value string) error { return nil } + inputIndex := input.Index + switch input.Type.Kind() { case reflect.String: - r.Feeds[input.Index] = [][]string{{value}} + r.Feeds[inputIndex] = [][]string{{value}} case reflect.Bool: val, err := strconv.ParseBool(value) if err != nil { return fmt.Errorf("failed to parse bool: '%v' for %v, %w", val, key, err) } - r.Feeds[input.Index] = [][]bool{{val}} + r.Feeds[inputIndex] = [][]bool{{val}} case reflect.Int: val, err := strconv.ParseInt(value, 10, 64) if err != nil { return fmt.Errorf("failed to parse int: '%v' for %v, %w", val, key, err) } - r.Feeds[input.Index] = [][]int{{int(val)}} + r.Feeds[inputIndex] = [][]int{{int(val)}} case reflect.Int32: val, err := strconv.ParseInt(value, 10, 32) if err != nil { return fmt.Errorf("failed to parse int32: '%v' for %v, %w", val, key, err) } - r.Feeds[input.Index] = [][]int32{{int32(val)}} + r.Feeds[inputIndex] = [][]int32{{int32(val)}} case reflect.Int64: val, err := strconv.ParseInt(value, 10, 64) if err != nil { return fmt.Errorf("failed to parse int64: '%v' for %v, %w", val, key, err) } - r.Feeds[input.Index] = [][]int64{{val}} + r.Feeds[inputIndex] = [][]int64{{val}} case reflect.Float64: val, err := strconv.ParseFloat(value, 64) if err != nil { return fmt.Errorf("failed to parse float64: '%v' for %v, %w", val, key, err) } - r.Feeds[input.Index] = [][]float64{{val}} + r.Feeds[inputIndex] = [][]float64{{val}} case reflect.Float32: val, err := strconv.ParseFloat(value, 32) if err != nil { return fmt.Errorf("failed to parse float32: '%v' for %v, %w", val, key, err) } - r.Feeds[input.Index] = [][]float32{{float32(val)}} + r.Feeds[inputIndex] = [][]float32{{float32(val)}} default: // TODO add more type support return fmt.Errorf("unsupported input type: %T", reflect.New(input.Type).Interface()) @@ -143,7 +147,14 @@ func (r *Request) 
UnmarshalJSONObject(dec *gojay.Decoder, key string) error { } r.supplied[key] = exists - inputValue, err := r.Input.SetAt(input.Index, input.Name, input.Type.Kind()) + + inputIndex := input.Index + + if inputIndex >= len(r.Feeds) && !input.Auxiliary { + return fmt.Errorf("non-aux input %s index %d is out of range for %d feeds", input.Name, inputIndex, len(r.Feeds)) + } + + inputValue, err := r.Input.SetAt(inputIndex, input.Name, input.Type.Kind()) if err != nil { return err } @@ -153,7 +164,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { return err } if !input.Auxiliary { - r.Feeds[input.Index] = inputValue.Feed(r.Input.BatchSize) + r.Feeds[inputIndex] = inputValue.Feed(r.Input.BatchSize) } return nil } @@ -168,7 +179,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { } _ = inputValue.Set(value) if !input.Auxiliary { - r.Feeds[input.Index] = [][]string{{value}} + r.Feeds[inputIndex] = [][]string{{value}} } case reflect.Bool: var value bool @@ -176,7 +187,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { return err } if !input.Auxiliary { - r.Feeds[input.Index] = [][]bool{{value}} + r.Feeds[inputIndex] = [][]bool{{value}} } _ = inputValue.Set(value) case reflect.Int: @@ -185,7 +196,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { return err } if !input.Auxiliary { - r.Feeds[input.Index] = [][]int{{value}} + r.Feeds[inputIndex] = [][]int{{value}} } _ = inputValue.Set(value) case reflect.Int32: @@ -194,7 +205,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { return err } if !input.Auxiliary { - r.Feeds[input.Index] = [][]int32{{value}} + r.Feeds[inputIndex] = [][]int32{{value}} } _ = inputValue.Set(value) case reflect.Int64: @@ -203,7 +214,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { return err } if !input.Auxiliary { - r.Feeds[input.Index] = [][]int64{{value}} + 
r.Feeds[inputIndex] = [][]int64{{value}} } _ = inputValue.Set(value) case reflect.Float64: @@ -212,7 +223,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { return err } if !input.Auxiliary { - r.Feeds[input.Index] = [][]float64{{value}} + r.Feeds[inputIndex] = [][]float64{{value}} } _ = inputValue.Set(value) case reflect.Float32: @@ -221,7 +232,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { return err } if !input.Auxiliary { - r.Feeds[input.Index] = [][]float32{{float32(value)}} + r.Feeds[inputIndex] = [][]float32{{float32(value)}} } _ = inputValue.Set(value) default: diff --git a/service/request/request_test.go b/service/request/request_test.go index 76f9eee..b8b481d 100644 --- a/service/request/request_test.go +++ b/service/request/request_test.go @@ -28,7 +28,12 @@ func TestDecode(t *testing.T) { inputs[modelInput.Name] = modelInput } - numInputs := len(modelInputs) + nonAuxCount := 0 + for _, input := range modelInputs { + if !input.Auxiliary { + nonAuxCount++ + } + } testCases := []struct { desc string @@ -41,7 +46,7 @@ func TestDecode(t *testing.T) { requestEnc: `{ "batch_size": 1, "a2": ["a2_0"], - "a1": ["a1_0"], + "a1": ["a1_0"], "a3": ["a3_0"], "cache_key": ["ck1"], }`, @@ -61,7 +66,7 @@ func TestDecode(t *testing.T) { desc: "invalid", requestEnc: `{ "batch_size": 1, - "a1": ["a1_0"], + "a1": ["a1_0"], "a3": ["a3_0"], "cache_key": ["ck1"], }`, @@ -71,9 +76,9 @@ func TestDecode(t *testing.T) { desc: "duplicate_aux", requestEnc: `{ "batch_size": 1, - "a1": ["a1_0"], - "a2": ["a1_0"], - "a3": ["a3_0"], + "a1": ["a1_0"], + "a2": ["a1_0"], + "a3": ["a3_0"], "a3": ["a3_1"], "cache_key": ["ck1"], }`, @@ -83,9 +88,9 @@ func TestDecode(t *testing.T) { desc: "duplicate_input", requestEnc: `{ "batch_size": 1, - "a1": ["a1_0"], - "a2": ["a2_0"], - "a2": ["a2_1"], + "a1": ["a1_0"], + "a2": ["a2_0"], + "a2": ["a2_1"], "a3": ["a3_0"], "cache_key": ["ck1"], }`, @@ -101,8 +106,8 @@ func TestDecode(t 
*testing.T) { desc: "bad_batch_expansion", requestEnc: `{ "batch_size": 2, - "a1": ["a1_0"], - "a2": ["a2_0", "a2_1"], + "a1": ["a1_0"], + "a2": ["a2_0", "a2_1"], "a3": ["a3_0", "a3_1"], "cache_key": ["ck1", "ck2"], }`, @@ -132,7 +137,7 @@ func TestDecode(t *testing.T) { for _, tc := range testCases { r := &Request{ inputs: inputs, - Feeds: make([]interface{}, numInputs, numInputs), + Feeds: make([]interface{}, nonAuxCount), } err := gojay.Unmarshal([]byte(tc.requestEnc), r) From 8e7233010753c6b1d45d756a08335a0e2f364734 Mon Sep 17 00:00:00 2001 From: David Choi Date: Tue, 25 Nov 2025 17:50:30 -0800 Subject: [PATCH 09/50] Remove unused code, minor refactor. --- service/service.go | 16 +++++++++------- service/storable.go | 18 ------------------ shared/common/type.go | 5 +---- 3 files changed, 10 insertions(+), 29 deletions(-) delete mode 100644 service/storable.go diff --git a/service/service.go b/service/service.go index a21ff1c..0080882 100644 --- a/service/service.go +++ b/service/service.go @@ -64,7 +64,6 @@ type Service struct { // outputs transformer domain.Transformer - newStorable func() common.Storable // serviceMetric measures validate + model + transformer serviceMetric *gmetric.Operation @@ -93,6 +92,7 @@ func (s *Service) Config() *config.Model { return s.config } +// Signature is invoked after at least 1 successful ReloadIfNeeded(). 
func (s *Service) Signature() *domain.Signature { return s.evaluator.Signature() } @@ -307,13 +307,18 @@ func (s *Service) initializeService(ctx context.Context, cfg *config.Model, fs a atomic.StoreInt32(&s.ReloadOK, 1) + signature := s.Signature() + if signature == nil { + return fmt.Errorf("signature could not be determined") + } + s.transformer, err = transform.Get(cfg.Transformer) if err != nil { return err } if err = s.initDatastore(cfg, datastores); err != nil { - return err + return fmt.Errorf("failed to initialize datastore: %w", err) } if cfg.Stream != nil { @@ -407,10 +412,11 @@ func (s *Service) initDatastore(cfg *config.Model, datastores map[string]*datast signature := s.Signature() if signature == nil { - return fmt.Errorf("signature was emtpy") + return fmt.Errorf("signature was not provided") } if len(cfg.KeyFields) == 0 { + // add all inputs from model signature as a key field for _, input := range signature.Inputs { cfg.KeyFields = append(cfg.KeyFields, input.Name) } @@ -433,10 +439,6 @@ func (s *Service) initDatastore(cfg *config.Model, datastores map[string]*datast _ = datastoreConfig.FieldsDescriptor(fields) } - if s.newStorable == nil { - s.newStorable = getStorable(datastoreConfig) - } - return nil } diff --git a/service/storable.go b/service/storable.go deleted file mode 100644 index 0d8f3c9..0000000 --- a/service/storable.go +++ /dev/null @@ -1,18 +0,0 @@ -package service - -import ( - "github.com/viant/mly/shared/common" - "github.com/viant/mly/shared/common/storable" - "github.com/viant/mly/shared/config" -) - -func getStorable(cfg *config.Datastore) func() common.Storable { - result, err := storable.Singleton().Lookup(cfg.Storable) - if err == nil && result != nil { - return result - } //otherwise return default storable - - return func() common.Storable { - return storable.New(cfg.Fields) - } -} diff --git a/shared/common/type.go b/shared/common/type.go index f0cf627..d3ff745 100644 --- a/shared/common/type.go +++ b/shared/common/type.go 
@@ -10,7 +10,7 @@ import ( // dataType == "" is treated as a string type. func DataType(dataType string) (reflect.Type, error) { switch strings.ToLower(dataType) { - case "string": + case "string", "": return reflect.TypeOf(""), nil case "float64": return reflect.TypeOf(float64(0)), nil @@ -35,9 +35,6 @@ func DataType(dataType string) (reflect.Type, error) { case "[]float64": return reflect.TypeOf([]float64{}), nil default: - if dataType == "" { - return reflect.TypeOf(""), nil - } return nil, fmt.Errorf("unsupported data type: %v", dataType) } } From 75850e3e2ae4db80a3aeab782c21542bf3692a93 Mon Sep 17 00:00:00 2001 From: David Choi Date: Wed, 26 Nov 2025 10:04:28 -0800 Subject: [PATCH 10/50] Reduce copied code between request and requets tests. --- service/request/request.go | 2 ++ service/request/request_test.go | 15 +++------------ shared/field.go | 2 ++ 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/service/request/request.go b/service/request/request.go index 02122c5..ed1f3cf 100644 --- a/service/request/request.go +++ b/service/request/request.go @@ -31,6 +31,8 @@ type Request struct { // type metadata from service/tfservice.Service.inputs // see service/tfmodel.(*Service).reconcileIOFromSignature + // The key is expected to be the name of the input. 
+ // Uses fields Index, Auxiliary, Type, and Name (for debugging) inputs map[string]*domain.Input } diff --git a/service/request/request_test.go b/service/request/request_test.go index b8b481d..055097a 100644 --- a/service/request/request_test.go +++ b/service/request/request_test.go @@ -21,20 +21,14 @@ func TestDecode(t *testing.T) { }, } - inputs := make(map[string]*domain.Input, len(modelInputs)) + numInputs := len(modelInputs) + inputs := make(map[string]*domain.Input, numInputs) for i, modelInput := range modelInputs { modelInput.Index = i inputs[modelInput.Name] = modelInput } - nonAuxCount := 0 - for _, input := range modelInputs { - if !input.Auxiliary { - nonAuxCount++ - } - } - testCases := []struct { desc string requestEnc string @@ -135,10 +129,7 @@ func TestDecode(t *testing.T) { } for _, tc := range testCases { - r := &Request{ - inputs: inputs, - Feeds: make([]interface{}, nonAuxCount), - } + r := NewRequest(numInputs, inputs) err := gojay.Unmarshal([]byte(tc.requestEnc), r) diff --git a/shared/field.go b/shared/field.go index da935f2..bbcc4c2 100644 --- a/shared/field.go +++ b/shared/field.go @@ -119,6 +119,8 @@ func (m *MetaInput) OutputByName() map[string]*Field { return outputByName } +// TODO look into history of this method then document its purpose. +// Is "key" key as in "important" or as in "cache key"? func (d *MetaInput) KeysLen() int { return len(d.Inputs) } From 3a1f9554b10541aba82f9f4bef95ee3949f88042 Mon Sep 17 00:00:00 2001 From: David Choi Date: Tue, 2 Dec 2025 13:58:45 -0800 Subject: [PATCH 11/50] Router now can detect and check signatures. 
--- service/config/model.go | 9 + service/platform/evaluator.go | 10 +- service/platform/factory/factory.go | 6 +- service/platform/router/router.go | 698 ++++++++++++++----------- service/platform/router/router_test.go | 374 ++++++------- service/triton/evaluator.go | 27 +- 6 files changed, 612 insertions(+), 512 deletions(-) diff --git a/service/config/model.go b/service/config/model.go index 445aaa8..04fbf32 100644 --- a/service/config/model.go +++ b/service/config/model.go @@ -203,6 +203,11 @@ func (m *Model) Validate() error { return fmt.Errorf("router model %s requires Router configuration", m.ID) } + if m.Triton == nil { + // TODO support TensorFlow + return fmt.Errorf("router model %s requires Triton configuration", m.ID) + } + if err := m.Router.Validate(); err != nil { return fmt.Errorf("router model %s config invalid: %w", m.ID, err) } @@ -270,6 +275,10 @@ func (t *TritonConfig) Validate(isRouter bool, urlPresent bool) error { return fmt.Errorf("triton ModelName is required") } + if isRouter && t.ServerID == "" { + return fmt.Errorf("triton ServerID is required for router mode") + } + if t.ServerID == "" && !urlPresent { return fmt.Errorf("triton ServerID or Model.URL is required") } diff --git a/service/platform/evaluator.go b/service/platform/evaluator.go index 470e7ba..d00d5b6 100644 --- a/service/platform/evaluator.go +++ b/service/platform/evaluator.go @@ -19,22 +19,24 @@ const ( type PlatformEvaluator interface { Predictor - // Signature returns underlying model's signature + // Signature returns underlying model's signature. + // This is expected to return non-nil after ReloadIfNeeded() succeeds. Signature() *domain.Signature - // Dictionary returns vocabulary if available + // Dictionary returns vocabulary if available. Dictionary() *common.Dictionary - // Inputs returns the model input definitions for request validation + // Inputs returns the model input definitions for request validation. 
+ // This will be invoked after at least 1 ReloadIfNeeded() succeeds. Inputs() map[string]*domain.Input // Stats returns platform-specific live metrics, for debugging Stats(stats map[string]interface{}) - // Close releases resources Close() error // ReloadIfNeeded will update models as needed, check their health, and consolidate signatures, if implemented. + // This can also be named EnsurePredictionPossible() or EnsureReady() or the like. // For in-process models (TensorFlow), this will check if the underlying models need to be updated. // For external models (Triton), this will use the Model Control API to load, unload, and check the health of Triton models. ReloadIfNeeded(ctx context.Context) error diff --git a/service/platform/factory/factory.go b/service/platform/factory/factory.go index b9e8c37..f899127 100644 --- a/service/platform/factory/factory.go +++ b/service/platform/factory/factory.go @@ -35,7 +35,11 @@ func CreateEvaluator( case "triton": if isRouter { - return router.NewRouter(cfg, fs, tritonClients) + makeEvaluator := func(modelName string) (platform.PlatformEvaluator, error) { + return triton.NewRoutedTritonEvaluator(modelName, cfg, tritonClients) + } + + return router.NewRouter(cfg, fs, tritonClients, makeEvaluator) } return triton.NewTritonEvaluator(cfg, tritonClients) diff --git a/service/platform/router/router.go b/service/platform/router/router.go index a4fc2c5..7fad23c 100644 --- a/service/platform/router/router.go +++ b/service/platform/router/router.go @@ -24,207 +24,139 @@ import ( "gopkg.in/yaml.v2" ) -// Router implements the PlatformEvaluator interface for router mode. 
-type Router struct { - configURL string - fs afs.Service - configLock sync.RWMutex - configModified *config.Modified - routerConfig *router.RouterConfig - - routingTableLock sync.RWMutex - routingMap map[int]string - routingTable map[string]platform.PlatformEvaluator - - workCh chan *workRequest - - globalModel platform.PlatformEvaluator - fixedEvaluator platform.Predictor - modelOutputName string - - modelConfig *config.Model - routerName string - tritonClient tricli.TritonClient - - signature *domain.Signature - indexToName map[int]string - inputs map[string]*domain.Input - - debug bool +type IOState struct { + inputs map[string]*domain.Input + signature *domain.Signature // router input offset is the index of the router input in the inputs array routerInputOffset int } -func NewRouter(cfg *config.Model, fs afs.Service, tritonClients map[string]tricli.TritonClient) (*Router, error) { - if cfg.Router == nil { - return nil, fmt.Errorf("router configuration is required") - } - - if err := cfg.Router.Validate(); err != nil { - return nil, fmt.Errorf("router configuration is invalid: %w", err) - - } - - tritonClient, ok := tritonClients[cfg.Triton.ServerID] - if !ok { - return nil, fmt.Errorf("triton client not found for server ID: %s", cfg.Triton.ServerID) - } +type ModelUnloader interface { + ModelUnload(ctx context.Context, modelName string) error +} - r := &Router{ - configURL: cfg.Router.ConfigURL, - fs: fs, - routerName: cfg.ID, - modelConfig: cfg, - tritonClient: tritonClient, - debug: cfg.Debug, - } +// Router implements the PlatformEvaluator interface for router mode. 
+type Router struct { + configURL string + fs afs.Service - if err := r.handleIO(cfg); err != nil { - return nil, fmt.Errorf("failed to handle IO: %w", err) - } + // config lock only protects the configModified field + configLock sync.RWMutex + configModified *config.Modified - r.workCh = make(chan *workRequest, cfg.Router.MaxQueueSize) - for i := 0; i < cfg.Router.Workers; i++ { - go handleWorkRequests(r.workCh, routerWorkerChannelQueuedSummary.WithLabelValues(r.routerName)) - } + // routingTableLock protects: + // - routerConfig + // - routingMap + // - routingTable + // - globalModel + // - ioState + routingTableLock sync.RWMutex - return r, nil -} + // routerConfig contains the last loaded router configuration + routerConfig *router.RouterConfig -type preparedReplacement struct { - typ string - value interface{} -} + hasGlobalModel bool + makeRoutedEvaluator func(modelName string) (platform.PlatformEvaluator, error) -func (t *Router) handleIO(cfg *config.Model) error { - io := &cfg.MetaInput + routerInputFieldName string - if len(io.Inputs) == 0 { - return fmt.Errorf("input configuration is required for a router") - } + routingMap map[int]string + routingTable map[string]platform.PlatformEvaluator - if len(io.Outputs) == 0 { - return fmt.Errorf("output configuration is required for a router") - } + // TODO see if this can be removed, may just need to map via model name + globalModel platform.PlatformEvaluator - var inputs []domain.Input + // fixedEvaluator is non-nil IFF there is no global model configured + fixedEvaluator platform.Predictor - // for declaring the router's inputs - mappedInputs := make(map[string]*domain.Input) + // fixedEvaluatorFields is for checking all outputs in the signature are replaced + fixedEvaluatorFields map[string]struct{} + outputConfig config.OutputConfig - // generate backend input - indexToName := make(map[int]string) + workCh chan *workRequest - i := 0 - for _, input := range io.Inputs { - if !input.Auxiliary && input.Name != 
cfg.Router.InputName { - inputs = append(inputs, domain.Input{ - Name: input.Name, - Index: input.Index, - }) + modelOutputName string - indexToName[i] = input.Name - i++ - } + routerName string + debug bool + unloader ModelUnloader - inputType := reflect.TypeOf("") - if input.DataType != "" { - switch input.DataType { - case "string": - inputType = reflect.TypeOf("") - case "int": - inputType = reflect.TypeOf(0) - case "int32": - inputType = reflect.TypeOf(int32(0)) - case "int64": - inputType = reflect.TypeOf(int64(0)) - case "float32", "float": - inputType = reflect.TypeOf(float32(0)) - case "float64": - inputType = reflect.TypeOf(float64(0)) - } - } + ioState *IOState +} - mappedInputs[input.Name] = &domain.Input{ - Name: input.Name, - Index: len(inputs), - Type: inputType, - Vocab: false, - Auxiliary: input.Auxiliary, - } +// NewRouter creates a new Router instance. +// cfg is expected to be Init()'d and Validate()'d before calling this function. +func NewRouter(cfg *config.Model, fs afs.Service, tritonClients map[string]tricli.TritonClient, makeEvaluator func(modelName string) (platform.PlatformEvaluator, error)) (*Router, error) { + unloaders := make(map[string]ModelUnloader) + for serverID, tritonClient := range tritonClients { + unloaders[serverID] = tritonClient } - var outputs []domain.Output - outputByName := make(map[string]domain.Output) - - for i, output := range io.Outputs { - outputs = append(outputs, domain.Output{ - Name: output.Name, - Index: i, - DataType: output.DataType, - }) + return newRouter(cfg, fs, unloaders, makeEvaluator) +} - outputByName[output.Name] = outputs[i] +// newRouter uses a map[string]ModelUnloader, where ModelUnloader is-a triton.TritonClient, for testing. 
+func newRouter(cfg *config.Model, fs afs.Service, unloaders map[string]ModelUnloader, makeEvaluator func(modelName string) (platform.PlatformEvaluator, error)) (*Router, error) { + if cfg.Router == nil { + return nil, fmt.Errorf("router configuration is required") } - modelOutputName := cfg.Router.Output.FieldName - hasModelOutputName := modelOutputName != "" + unloader, ok := unloaders[cfg.Triton.ServerID] + if !ok { + return nil, fmt.Errorf("triton client not found for server ID: %s", cfg.Triton.ServerID) + } + var fixedEvaluator *fixedEvaluator + var fixedEvaluatorFields map[string]struct{} if !cfg.Router.Global.Exists { replacementsByName := make(map[string]config.PredictionReplacement) for _, repl := range cfg.Router.Global.PredictionReplacements { replacementsByName[repl.Name] = repl } - replacementOutputs := make([]config.PredictionReplacement, 0, len(outputs)) - for _, output := range outputs { - if hasModelOutputName && output.Name == modelOutputName { - // model-used output field name is handled in a different way - continue - } - - if _, ok := replacementsByName[output.Name]; !ok { - return fmt.Errorf("replacement for output %s not found", output.Name) - } - - replacementOutputs = append(replacementOutputs, replacementsByName[output.Name]) - } - - fixedEvaluator, err := newFixedEvaluator(replacementOutputs) + var err error + fixedEvaluator, err = newFixedEvaluator(cfg.Router.Global.PredictionReplacements) if err != nil { - return fmt.Errorf("failed to create fixed evaluator: %w", err) + return nil, fmt.Errorf("failed to create fixed evaluator: %w", err) } - t.fixedEvaluator = fixedEvaluator + fixedEvaluatorFields = make(map[string]struct{}, len(replacementsByName)) + for name := range replacementsByName { + fixedEvaluatorFields[name] = struct{}{} + } } - var modelOutputInOutputs bool = !hasModelOutputName - if hasModelOutputName { - _, modelOutputInOutputs = outputByName[modelOutputName] - } + r := &Router{ + debug: cfg.Debug, + routerName: cfg.ID, - if 
!modelOutputInOutputs { - outputs = append(outputs, domain.Output{ - Name: modelOutputName, - DataType: "string", - Index: len(outputs), - }) - } + configURL: cfg.Router.ConfigURL, + fs: fs, + makeRoutedEvaluator: makeEvaluator, - t.modelOutputName = modelOutputName + unloader: unloader, + outputConfig: cfg.Router.Output, + hasGlobalModel: cfg.Router.Global.Exists, - t.indexToName = indexToName + modelOutputName: cfg.Router.Output.FieldName, - t.signature = &domain.Signature{ - Inputs: inputs, - Outputs: outputs, - Output: outputs[0], + fixedEvaluator: fixedEvaluator, + fixedEvaluatorFields: fixedEvaluatorFields, } - t.inputs = mappedInputs + // spawn worker routines + r.workCh = make(chan *workRequest, cfg.Router.MaxQueueSize) + for i := 0; i < cfg.Router.Workers; i++ { + go handleWorkRequests(r.workCh, routerWorkerChannelQueuedSummary.WithLabelValues(r.routerName)) + } - return nil + return r, nil +} + +type preparedReplacement struct { + typ string + value interface{} } // Predict performs model inference with the given parameters @@ -252,120 +184,138 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface return nil, err } - numInputs := len(params) - - r.routingTableLock.RLock() - defer r.routingTableLock.RUnlock() - - globalExists := r.modelConfig.Router.Global.Exists - reportedGlobalModelName := r.modelConfig.Router.Output.GlobalModelOverride - noModelName := r.modelConfig.Router.Output.NoModelID + errCh := make(chan error, expectedBatchSize) + resultsCh := make(chan offsetResults, expectedBatchSize) predictWaitGroup := sync.WaitGroup{} predictWaitGroup.Add(expectedBatchSize) - errCh := make(chan error, expectedBatchSize) - resultsCh := make(chan offsetResults, expectedBatchSize) + var signature *domain.Signature - for batchOffset := range expectedBatchSize { - // 1 input is reserved for the router input - request := make([]interface{}, numInputs-1) + err = func() error { + r.routingTableLock.RLock() + defer 
r.routingTableLock.RUnlock() - var routingValueBatched interface{} + if r.ioState == nil { + return fmt.Errorf("ioState was not initialized") + } - for inputOffset := range numInputs { - debatched, err := shape.Debatch(params[inputOffset], batchOffset) - if err != nil { - return nil, fmt.Errorf("failed to debatch for row %d and input %d: %w", batchOffset, inputOffset, err) - } + // this assignment isn't strictly required to be atomic as it should never change after initialization + signature = r.ioState.signature + routerInputOffset := r.ioState.routerInputOffset - if inputOffset < r.routerInputOffset { - request[inputOffset] = debatched - } else if inputOffset == r.routerInputOffset { - routingValueBatched = debatched - } else { - request[inputOffset-1] = debatched - } - } + globalExists := r.fixedEvaluator != nil - routingValue, err := shape.SqueezeBatch(routingValueBatched) - if err != nil { - return nil, fmt.Errorf("failed to extract from batch for row %d: %w", batchOffset, err) - } + reportedGlobalModelName := r.outputConfig.GlobalModelOverride + noModelName := r.outputConfig.NoModelID - var ok bool = true - var routingValueInt int - switch routingValue := routingValue.(type) { - case int: - routingValueInt = routingValue - case int32: - routingValueInt = int(routingValue) - case int64: - routingValueInt = int(routingValue) - default: - ok = false - } + numInputs := len(params) - if !ok { - return nil, fmt.Errorf("routing value is not an int: %v, is %T, for row %d", routingValue, routingValue, batchOffset) - } + for batchOffset := range expectedBatchSize { + // 1 input is reserved for the router input + request := make([]interface{}, numInputs-1) - routingValueString, ok := r.routingMap[routingValueInt] + var routingValueBatched interface{} - var evaluator platform.Predictor - if !ok { - if globalExists { - metricFixedOnly = false - // fallback to global model - evaluator = r.globalModel + // TODO support different input ordering per evaluator - see 
applyRouterConfig() regarding signatures + for inputOffset := range numInputs { + debatched, err := shape.Debatch(params[inputOffset], batchOffset) + if err != nil { + return fmt.Errorf("failed to debatch for row %d and input %d: %w", batchOffset, inputOffset, err) + } - // override model name - if reportedGlobalModelName != "" { - routingValueString = reportedGlobalModelName + if inputOffset < routerInputOffset { + request[inputOffset] = debatched + } else if inputOffset == routerInputOffset { + routingValueBatched = debatched + } else { + request[inputOffset-1] = debatched } - } else { - routingValueString = noModelName - evaluator = r.fixedEvaluator } - } else { - metricFixedOnly = false - var ok bool - evaluator, ok = r.routingTable[routingValueString] - if !ok { - return nil, fmt.Errorf("no evaluator found for routing value: %v", routingValue) + routingValue, err := shape.SqueezeBatch(routingValueBatched) + if err != nil { + return fmt.Errorf("failed to extract from batch for row %d: %w", batchOffset, err) } - } - select { - case r.workCh <- &workRequest{ - wg: &predictWaitGroup, + var ok bool = true + var routingValueInt int + switch routingValue := routingValue.(type) { + case int: + routingValueInt = routingValue + case int32: + routingValueInt = int(routingValue) + case int64: + routingValueInt = int(routingValue) + default: + ok = false + } + + if !ok { + return fmt.Errorf("routing value is not an int: %v, is %T, for row %d", routingValue, routingValue, batchOffset) + } - predictor: evaluator, - ctx: ctx, - request: request, + routingValueString, ok := r.routingMap[routingValueInt] - queuedTime: time.Now(), - offset: batchOffset, - modelOutputEnabled: r.modelOutputName != "", - routingValueString: routingValueString, + var evaluator platform.Predictor + if !ok { + if globalExists { + metricFixedOnly = false + // fallback to global model + evaluator = r.globalModel + + // override model name + if reportedGlobalModelName != "" { + routingValueString = 
reportedGlobalModelName + } + } else { + routingValueString = noModelName + evaluator = r.fixedEvaluator + } + } else { + metricFixedOnly = false - responseCh: resultsCh, - errCh: errCh, - }: + var ok bool + evaluator, ok = r.routingTable[routingValueString] + if !ok { + return fmt.Errorf("no evaluator found for routing value: %v", routingValue) + } + } - // continue - default: - routerPredictDroppedCounter.WithLabelValues(r.routerName).Inc() - return nil, fmt.Errorf("work channel is full") + select { + case r.workCh <- &workRequest{ + wg: &predictWaitGroup, + + predictor: evaluator, + ctx: ctx, + request: request, + + queuedTime: time.Now(), + offset: batchOffset, + modelOutputEnabled: r.modelOutputName != "", + routingValueString: routingValueString, + + responseCh: resultsCh, + errCh: errCh, + }: + // continue + default: + routerPredictDroppedCounter.WithLabelValues(r.routerName).Inc() + return fmt.Errorf("work channel is full") + } } - } - predictWaitGroup.Wait() + return nil + }() + predictWaitGroup.Wait() close(errCh) close(resultsCh) + if err != nil { + return nil, err + } + for err := range errCh { return nil, err } @@ -375,7 +325,7 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface allResults[results.offset] = results.results } - endResults := make([]interface{}, len(r.signature.Outputs)) + endResults := make([]interface{}, len(signature.Outputs)) for i, results := range allResults { endResults, err = shape.ConcatAxis0(endResults, results) if err != nil { @@ -387,7 +337,9 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface } func (r *Router) Signature() *domain.Signature { - return r.signature + r.routingTableLock.RLock() + defer r.routingTableLock.RUnlock() + return r.ioState.signature } func (r *Router) Dictionary() *common.Dictionary { @@ -395,17 +347,31 @@ func (r *Router) Dictionary() *common.Dictionary { } func (r *Router) Inputs() map[string]*domain.Input { - return r.inputs + 
r.routingTableLock.RLock() + defer r.routingTableLock.RUnlock() + return r.ioState.inputs } func (r *Router) Stats(stats map[string]interface{}) { - + // do nothing } func (r *Router) Close() error { return nil } +func (r *Router) debugLogf(format string, args ...interface{}) { + if r.debug { + prefix := "[%s Router] " + log.Printf(prefix+format, append([]interface{}{r.routerName}, args...)...) + } +} + +type modelSignature struct { + name string + signature *domain.Signature +} + // TODO refactor with service/tfmodel/service.isModified()? func (r *Router) isModified(snapshot *config.Modified) bool { if r.routerConfig == nil || r.configModified == nil { @@ -485,7 +451,7 @@ func (r *Router) ReloadIfNeeded(ctx context.Context) error { errStrings = append(errStrings, err.Error()) } - err = fmt.Errorf("one or more model reloading errors: %s", strings.Join(errStrings, "; ")) + err = fmt.Errorf("reloading errors: %s", strings.Join(errStrings, "; ")) } r.configLock.RUnlock() @@ -538,12 +504,25 @@ func (r *Router) ReloadIfNeeded(ctx context.Context) error { return nil } +// applyRouterConfig will both update evaluators to new configuration state and verify and build the signature func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.RouterConfig) error { modelsToUnload := make(map[string]struct{}) reuseEvaluators := make(map[string]platform.PlatformEvaluator) var reuseGlobal platform.PlatformEvaluator - oldConfig := r.routerConfig + var finalSignature *domain.Signature + var oldConfig *router.RouterConfig + func() { + r.routingTableLock.RLock() + defer r.routingTableLock.RUnlock() + if r.ioState != nil { + finalSignature = r.ioState.signature + } + + reuseGlobal = r.globalModel + oldConfig = r.routerConfig + }() + if oldConfig != nil { for _, entity := range oldConfig.EntityMapping { modelsToUnload[entity.ModelName] = struct{}{} @@ -554,26 +533,24 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router if 
oldConfig.GlobalModelName != "" { modelsToUnload[oldConfig.GlobalModelName] = struct{}{} - reuseGlobal = r.globalModel } } newModelMapping := make(map[int]string) for _, entity := range newConfig.EntityMapping { - if r.debug { - log.Printf("router: add mapping: %d -> %s", entity.EntityID, entity.ModelName) - } + r.debugLogf("add mapping: %d -> %s", entity.EntityID, entity.ModelName) newModelMapping[entity.EntityID] = entity.ModelName delete(modelsToUnload, entity.ModelName) } globalModelName := newConfig.GlobalModelName - if globalModelName != "" { - if r.debug { - log.Printf("router: global model: %s", globalModelName) - } + if globalModelName == "" && r.hasGlobalModel { + return fmt.Errorf("global model name is missing") + } + if globalModelName != "" { + r.debugLogf("global model: %s", globalModelName) delete(modelsToUnload, globalModelName) } @@ -589,15 +566,10 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router continue } - evaluator, err := tricli.NewRoutedTritonEvaluator( - model, - r.tritonClient, - r.modelConfig.Triton.Timeout, - r.indexToName, - ) + evaluator, err := r.makeRoutedEvaluator(model) if err != nil { - return fmt.Errorf("failed to create Triton evaluator for model %s: %w", model, err) + return fmt.Errorf("failed to create Routed Evaluator for model %s: %w", model, err) } newRoutingTable[model] = evaluator @@ -613,21 +585,23 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router globalEvaluator = evaluator } else { var err error - globalEvaluator, err = tricli.NewRoutedTritonEvaluator( - globalModelName, - r.tritonClient, - r.modelConfig.Triton.Timeout, - r.indexToName, - ) - + globalEvaluator, err = r.makeRoutedEvaluator(globalModelName) if err != nil { - return fmt.Errorf("failed to create Triton evaluator for global model %s: %w", globalModelName, err) + return fmt.Errorf("failed to create Routed Evaluator for global model %s: %w", globalModelName, err) } } } wg := sync.WaitGroup{} 
- errCh := make(chan error, len(newRoutingTable)+1) + + numWorkers := len(newRoutingTable) + if globalEvaluator != nil { + numWorkers++ + } + + errCh := make(chan error, numWorkers) + signatureCh := make(chan modelSignature, numWorkers) + if globalEvaluator != nil { wg.Add(1) go func() { @@ -643,26 +617,33 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router go func(model string) { defer wg.Done() - if r.debug { - log.Printf("router: reload model: %s", model) + r.debugLogf("reload model: %s", model) + + modelEvaluator := newRoutingTable[model] + if err := modelEvaluator.ReloadIfNeeded(ctx); err != nil { + r.debugLogf("failed to reload model: %s: %v", model, err) + errCh <- fmt.Errorf("failed to reload model %s: %w", model, err) } - if err := newRoutingTable[model].ReloadIfNeeded(ctx); err != nil { - if r.debug { - log.Printf("router: failed to reload model: %s: %v", model, err) - } + evalSig := modelEvaluator.Signature() - errCh <- fmt.Errorf("failed to reload model %s: %w", model, err) + if evalSig == nil { + errCh <- fmt.Errorf("model %s signature is nil", model) + return + } + + signatureCh <- modelSignature{ + name: model, + signature: evalSig, } }(model) } - if r.debug { - log.Printf("router: wait for reloads") - } + r.debugLogf("wait for reloads") wg.Wait() close(errCh) + close(signatureCh) if len(errCh) > 0 { var errStrings []string @@ -672,18 +653,153 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router return fmt.Errorf("one or more model reloading errors: %s", strings.Join(errStrings, "; ")) } + sigInputMap := make(map[string]*domain.Input) + sigOutputMap := make(map[string]*domain.Output) + + // we only create ioState on the first reload + var ioState *IOState = new(IOState) + if finalSignature != nil { + for _, input := range finalSignature.Inputs { + sigInputMap[input.Name] = &input + } + + for _, output := range finalSignature.Outputs { + sigOutputMap[output.Name] = &output + } + } + + for 
signature := range signatureCh { + // accept first available signature as the final signature + if finalSignature == nil { + // in the creation of the signature, include the routing input + + finalSignature = signature.signature + + inputOffset := len(finalSignature.Inputs) + + routerInput := domain.Input{ + Name: r.routerInputFieldName, + Index: inputOffset, + Type: reflect.TypeOf(int64(0)), + } + + ioState.routerInputOffset = inputOffset + + finalSignature.Inputs = append(finalSignature.Inputs, routerInput) + + for _, input := range finalSignature.Inputs { + sigInputMap[input.Name] = &input + } + + if r.modelOutputName != "" { + // also, add the selected model output + modelOutput := domain.Output{ + Name: r.modelOutputName, + Index: len(finalSignature.Outputs), + DataType: "string", + } + + finalSignature.Outputs = append(finalSignature.Outputs, modelOutput) + } + + for _, output := range finalSignature.Outputs { + sigOutputMap[output.Name] = &output + } + + continue + } + + thisSignature := signature.signature + // validate signature consistency + thisSignatureOutputMap := make(map[string]*domain.Output) + for _, output := range thisSignature.Outputs { + oldOutput, ok := sigOutputMap[output.Name] + if !ok { + return fmt.Errorf("signature output %s for model %s not found in the previous signature", output.Name, signature.name) + } + + thisSignatureOutputMap[output.Name] = &output + + // TODO permit this + if oldOutput.Index != output.Index { + return fmt.Errorf("signature output %s for model %s has index %d, and the previous signature has index %d", output.Name, signature.name, output.Index, oldOutput.Index) + } + + if oldOutput.DataType != output.DataType { + return fmt.Errorf("signature output %s for model %s has data type %s, and the previous signature has data type %s", output.Name, signature.name, output.DataType, oldOutput.DataType) + } + } + + for expectedOutput := range sigOutputMap { + if _, ok := thisSignatureOutputMap[expectedOutput]; !ok && 
expectedOutput != r.modelOutputName { + return fmt.Errorf("signature output %s for was not found in model %s signature", expectedOutput, signature.name) + } + } + + thisSignatureInputMap := make(map[string]*domain.Input) + for _, input := range thisSignature.Inputs { + oldInput, ok := sigInputMap[input.Name] + if !ok { + return fmt.Errorf("signature input %s for model %s not found in the previous signature", input.Name, signature.name) + } + + thisSignatureInputMap[input.Name] = &input + + // TODO permit this + if oldInput.Index != input.Index { + return fmt.Errorf("signature input %s for model %s has index %d, and the previous signature has index %d", input.Name, signature.name, input.Index, oldInput.Index) + } + + if !oldInput.Type.ConvertibleTo(input.Type) { + return fmt.Errorf("signature input %s for model %s has data type %s, and the previous signature has data type %s", input.Name, signature.name, input.Type.String(), oldInput.Type.String()) + } + } + + for expectedInput := range sigInputMap { + if _, ok := thisSignatureInputMap[expectedInput]; !ok && expectedInput != r.routerInputFieldName { + return fmt.Errorf("signature input %s for was not found in model %s signature", expectedInput, signature.name) + } + } + } + + if r.fixedEvaluatorFields != nil { + // TODO this is actually an acceptable case, but needs to be addressed elsewhere first before it is permitted + for field := range r.fixedEvaluatorFields { + if _, ok := sigOutputMap[field]; !ok { + return fmt.Errorf("fixed evaluator field: %s was not found in the signature outputs", field) + } + } + + for _, field := range sigOutputMap { + if _, ok := r.fixedEvaluatorFields[field.Name]; !ok && field.Name != r.modelOutputName { + return fmt.Errorf("signature output %s is not replaced", field.Name) + } + } + } + + ioState.signature = finalSignature + ioState.inputs = sigInputMap + + if globalEvaluator != nil { + if _, exists := newRoutingTable[globalModelName]; !exists { + newRoutingTable[globalModelName] = 
globalEvaluator + } + } + func() { r.routingTableLock.Lock() defer r.routingTableLock.Unlock() - if globalEvaluator != nil { - if _, exists := newRoutingTable[globalModelName]; !exists { - newRoutingTable[globalModelName] = globalEvaluator - } - } - r.globalModel = globalEvaluator - r.routingMap = newModelMapping + r.routerConfig = newConfig + + r.routingMap = newModelMapping r.routingTable = newRoutingTable + + r.globalModel = globalEvaluator + + if r.ioState == nil { + r.ioState = ioState + } }() for model := range modelsToUnload { @@ -695,7 +811,7 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router ctxTo, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := r.unloadModel(ctxTo, modelName); err != nil { - log.Printf("failed to unload model %s: %v\n", modelName, err) + r.debugLogf("failed to unload model %s: %v\n", modelName, err) } }(model) } @@ -705,7 +821,7 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router func (r *Router) unloadModel(ctx context.Context, modelName string) error { defer routerModelUnloadGauge.WithLabelValues(r.routerName).Dec() - if err := r.tritonClient.ModelUnload(ctx, modelName); err != nil { + if err := r.unloader.ModelUnload(ctx, modelName); err != nil { return fmt.Errorf("failed to unload model %s: %w", modelName, err) } return nil diff --git a/service/platform/router/router_test.go b/service/platform/router/router_test.go index d24da28..50458d0 100644 --- a/service/platform/router/router_test.go +++ b/service/platform/router/router_test.go @@ -3,6 +3,7 @@ package router import ( "context" "errors" + "fmt" "reflect" "strings" "sync" @@ -12,17 +13,32 @@ import ( "github.com/viant/mly/service/config" "github.com/viant/mly/service/domain" "github.com/viant/mly/service/platform" - tricli "github.com/viant/mly/service/triton" - "github.com/viant/mly/shared" "github.com/viant/mly/shared/common" sharedrouter 
"github.com/viant/mly/shared/config/router" ) // --- Router Predict scaffolds --- -type mockPredictOnly struct{} +type mockTritonServer struct { + mu sync.Mutex + + readyState map[string]bool + modelLoadErr map[string]error +} + +type mockEvaluator struct { + tritonServer *mockTritonServer + + modelName string + signature *domain.Signature +} + +func (m *mockEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { + inputs := m.signature.Inputs + if len(inputs) != len(params) { + return nil, fmt.Errorf("expected %d inputs, got %d", len(inputs), len(params)) + } -func (m *mockPredictOnly) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { // params is expected to have a single non-router input in this test: [][]string with shape [1][1] var v string switch typed := params[0].(type) { @@ -37,89 +53,56 @@ func (m *mockPredictOnly) Predict(ctx context.Context, params []interface{}) ([] return []interface{}{out}, nil } -func (m *mockPredictOnly) Signature() *domain.Signature { return nil } -func (m *mockPredictOnly) Dictionary() *common.Dictionary { return nil } -func (m *mockPredictOnly) Inputs() map[string]*domain.Input { return nil } -func (m *mockPredictOnly) Stats(map[string]interface{}) {} -func (m *mockPredictOnly) Close() error { return nil } -func (m *mockPredictOnly) ReloadIfNeeded(ctx context.Context) error { - return nil -} +func (m *mockEvaluator) Signature() *domain.Signature { return m.signature } +func (m *mockEvaluator) Dictionary() *common.Dictionary { return nil } +func (m *mockEvaluator) Inputs() map[string]*domain.Input { return nil } +func (m *mockEvaluator) Stats(map[string]interface{}) {} +func (m *mockEvaluator) Close() error { return nil } -type mockTritonClient struct { - mu sync.Mutex - loadCalls []string - unloadCalls []string - readyCalls []string - readyState map[string]bool - unloadCh chan string - modelLoadErr map[string]error -} +func (m *mockEvaluator) ReloadIfNeeded(ctx 
context.Context) error { + if m.tritonServer == nil { + return nil + } -func (m *mockTritonClient) ServerReady(ctx context.Context) error { return nil } -func (m *mockTritonClient) ModelInfer(ctx context.Context, modelName string, inputs []interface{}, indexToName map[int]string) ([]interface{}, error) { - return nil, nil -} -func (m *mockTritonClient) ModelReady(ctx context.Context, modelName string) (bool, error) { - m.mu.Lock() - defer m.mu.Unlock() - m.readyCalls = append(m.readyCalls, modelName) - if ready, ok := m.readyState[modelName]; ok { - return ready, nil - } - return true, nil -} -func (m *mockTritonClient) ModelLoad(ctx context.Context, modelName string) error { - m.mu.Lock() - defer m.mu.Unlock() - m.loadCalls = append(m.loadCalls, modelName) - if err := m.modelLoadErr[modelName]; err != nil { + m.tritonServer.mu.Lock() + defer m.tritonServer.mu.Unlock() + + if err := m.tritonServer.modelLoadErr[m.modelName]; err != nil { return err } - if m.readyState == nil { - m.readyState = make(map[string]bool) + + if m.tritonServer.readyState == nil { + m.tritonServer.readyState = make(map[string]bool) } - m.readyState[modelName] = true + + m.tritonServer.readyState[m.modelName] = true return nil } -func (m *mockTritonClient) ModelUnload(ctx context.Context, modelName string) error { - m.mu.Lock() - m.unloadCalls = append(m.unloadCalls, modelName) + +type mockUnloader struct { + tritonServer *mockTritonServer + unloadCh chan string +} + +func (m *mockUnloader) ModelUnload(ctx context.Context, modelName string) error { + if m.tritonServer != nil { + m.tritonServer.mu.Lock() + defer m.tritonServer.mu.Unlock() + + if m.tritonServer.readyState == nil { + m.tritonServer.readyState = make(map[string]bool) + } + + m.tritonServer.readyState[modelName] = false + } + ch := m.unloadCh - m.mu.Unlock() + if ch != nil { ch <- modelName } - return nil -} -func (m *mockTritonClient) ModelMetadata(ctx context.Context, modelName string) (*tricli.ModelMetadata, error) { - return 
&tricli.ModelMetadata{ - Inputs: []tricli.MetadataTensor{ - {Name: "input1", Datatype: "INT32"}, - }, - Outputs: []tricli.MetadataTensor{ - {Name: "output1", Datatype: "FP32"}, - }, - }, nil -} -func (m *mockTritonClient) Close() error { return nil } - -func (m *mockTritonClient) snapshotLoadCalls() []string { - m.mu.Lock() - defer m.mu.Unlock() - return append([]string(nil), m.loadCalls...) -} -func (m *mockTritonClient) snapshotUnloadCalls() []string { - m.mu.Lock() - defer m.mu.Unlock() - return append([]string(nil), m.unloadCalls...) -} - -func (m *mockTritonClient) snapshotReadyCalls() []string { - m.mu.Lock() - defer m.mu.Unlock() - return append([]string(nil), m.readyCalls...) + return nil } func waitForCalls(t *testing.T, ch <-chan string, count int) []string { @@ -136,7 +119,7 @@ func waitForCalls(t *testing.T, ch <-chan string, count int) []string { return out } -func TestRouter_Predict_RoutesAndConcats(t *testing.T) { +func TestRouter_Predict(t *testing.T) { ctx := context.Background() tests := []struct { @@ -202,7 +185,7 @@ func TestRouter_Predict_RoutesAndConcats(t *testing.T) { ConfigURL: "memory://router-config", InputName: "router_id", Global: config.GlobalModelConfig{ - Exists: true, // avoid fixed replacements path + Exists: true, }, Output: config.OutputConfig{ FieldName: "model_output", @@ -210,7 +193,7 @@ func TestRouter_Predict_RoutesAndConcats(t *testing.T) { }, verifier: func(t *testing.T, results []interface{}) { if len(results) != 2 { - t.Fatalf("expected 1 output, got %d", len(results)) + t.Fatalf("expected 2 outputs, got %d, %v", len(results), results) } func() { @@ -238,25 +221,26 @@ func TestRouter_Predict_RoutesAndConcats(t *testing.T) { }, } - for _, test := range tests { + downstreamInput := domain.Input{ + Name: "text", + Index: 1, + Type: reflect.TypeOf(""), + } + + downstreamSignature := &domain.Signature{ + Inputs: []domain.Input{downstreamInput}, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, 
+ } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { cfg := &config.Model{ ID: "router_test", Mode: "router", Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - // router input first (default offset 0) - {Name: "router_id", Index: 0, DataType: "int64"}, - // single backend input - {Name: "text", Index: 1, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "score", Index: 0, DataType: "float32"}, - }, - }, - Router: test.routerConfig, + Router: test.routerConfig, Triton: &config.TritonConfig{ ServerID: "test_server", }, @@ -266,8 +250,10 @@ func TestRouter_Predict_RoutesAndConcats(t *testing.T) { cfg.Router.MaxQueueSize = 1000 cfg.Router.Workers = 3 - router, err := NewRouter(cfg, nil, map[string]tricli.TritonClient{ - "test_server": &mockTritonClient{}, + router, err := newRouter(cfg, nil, map[string]ModelUnloader{ + "test_server": &mockUnloader{}, + }, func(modelName string) (platform.PlatformEvaluator, error) { + return &mockEvaluator{signature: downstreamSignature}, nil }) if err != nil { @@ -278,7 +264,34 @@ func TestRouter_Predict_RoutesAndConcats(t *testing.T) { 1: "model1", 2: "model2", } - mockEval := &mockPredictOnly{} + + routerOutputs := []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + } + + if test.routerConfig.Output.FieldName != "" { + routerOutputs = append(routerOutputs, domain.Output{Name: test.routerConfig.Output.FieldName, Index: 1, DataType: "string"}) + } + + routerInputName := cfg.Router.InputName + routerInput := domain.Input{Name: routerInputName, Index: 0, Type: reflect.TypeOf(int64(0))} + + router.ioState = &IOState{ + inputs: map[string]*domain.Input{ + routerInputName: &routerInput, + downstreamInput.Name: &downstreamInput, + }, + + signature: &domain.Signature{ + Inputs: []domain.Input{ + routerInput, + downstreamInput, + }, + Outputs: routerOutputs, + }, + } + + mockEval := &mockEvaluator{signature: downstreamSignature} router.routingTable = 
map[string]platform.PlatformEvaluator{ "model1": mockEval, "model2": mockEval, @@ -302,12 +315,8 @@ func TestRouter_Predict_RoutesAndConcats(t *testing.T) { func TestRouter_applyRouterConfig_LoadsAndSwaps(t *testing.T) { ctx := context.Background() - mockClient := &mockTritonClient{ + mockClient := &mockUnloader{ unloadCh: make(chan string, 2), - readyState: map[string]bool{ - "modelC": false, - "global-new": false, - }, } oldConfig := &sharedrouter.RouterConfig{ @@ -318,27 +327,31 @@ func TestRouter_applyRouterConfig_LoadsAndSwaps(t *testing.T) { GlobalModelName: "global-old", } - router := &Router{ - tritonClient: mockClient, - modelConfig: &config.Model{ - Triton: &config.TritonConfig{ - Timeout: 100, - }, + signature := &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, }, - indexToName: map[int]string{ - 0: "text", + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, }, + } + + router := &Router{ + unloader: mockClient, routerConfig: oldConfig, routingMap: map[int]string{ 1: "modelA", 2: "modelB", }, routingTable: map[string]platform.PlatformEvaluator{ - "modelA": &mockPredictOnly{}, - "modelB": &mockPredictOnly{}, + "modelA": &mockEvaluator{signature: signature}, + "modelB": &mockEvaluator{signature: signature}, }, - globalModel: &mockPredictOnly{}, + globalModel: &mockEvaluator{}, debug: true, + makeRoutedEvaluator: func(modelName string) (platform.PlatformEvaluator, error) { + return &mockEvaluator{signature: signature}, nil + }, } reusedModelB := router.routingTable["modelB"] @@ -355,59 +368,7 @@ func TestRouter_applyRouterConfig_LoadsAndSwaps(t *testing.T) { t.Fatalf("applyRouterConfig returned error: %v", err) } - loadCalls := mockClient.snapshotLoadCalls() - if len(loadCalls) != 2 { - // global-new and modelC - t.Fatalf("expected 2 model loads, got %d (%v), ready=%v", len(loadCalls), loadCalls, mockClient.snapshotReadyCalls()) - } - - expectedLoads := map[string]bool{ - "modelC": 
false, - "global-new": false, - } - for _, call := range loadCalls { - if _, ok := expectedLoads[call]; ok { - expectedLoads[call] = true - } - } - for model, seen := range expectedLoads { - if !seen { - t.Fatalf("expected load for %s was not observed; calls=%v", model, loadCalls) - } - } - - readyCalls := mockClient.snapshotReadyCalls() - expectedReady := map[string]bool{ - "modelC": false, - "global-new": false, - } - for _, call := range readyCalls { - if _, ok := expectedReady[call]; ok { - expectedReady[call] = true - } - } - for model, seen := range expectedReady { - if !seen { - t.Fatalf("expected readiness check for %s was not observed; calls=%v", model, readyCalls) - } - } - waitForCalls(t, mockClient.unloadCh, 2) - unloadCalls := mockClient.snapshotUnloadCalls() - expectedUnloads := map[string]bool{ - "modelA": false, - "global-old": false, - } - for _, call := range unloadCalls { - if _, ok := expectedUnloads[call]; ok { - expectedUnloads[call] = true - } - } - for model, seen := range expectedUnloads { - if !seen { - t.Fatalf("expected unload for %s was not observed; calls=%v", model, unloadCalls) - } - } if router.routerConfig != newConfig { t.Fatalf("routerConfig pointer not updated") @@ -417,6 +378,7 @@ func TestRouter_applyRouterConfig_LoadsAndSwaps(t *testing.T) { 1: "modelB", 3: "modelC", } + if !reflect.DeepEqual(router.routingMap, expectedRouting) { t.Fatalf("routingMap mismatch, got %#v", router.routingMap) } @@ -428,12 +390,15 @@ func TestRouter_applyRouterConfig_LoadsAndSwaps(t *testing.T) { if _, ok := router.routingTable["modelB"]; !ok { t.Fatalf("routingTable missing modelB") } + if router.routingTable["modelB"] != reusedModelB { t.Fatalf("modelB evaluator was not reused") } + if _, ok := router.routingTable["modelC"]; !ok { t.Fatalf("routingTable missing modelC") } + if _, ok := router.routingTable["modelA"]; ok { t.Fatalf("routingTable still contains modelA") } @@ -441,34 +406,48 @@ func TestRouter_applyRouterConfig_LoadsAndSwaps(t 
*testing.T) { func TestRouter_applyRouterConfig_LoadError(t *testing.T) { ctx := context.Background() + loadErr := errors.New("load failure") - mockClient := &mockTritonClient{ - readyState: map[string]bool{ - "modelX": false, - }, - modelLoadErr: map[string]error{ - "modelX": loadErr, - }, + + tritonServer := new(mockTritonServer) + tritonServer.modelLoadErr = map[string]error{ + "modelX": loadErr, } + mockClient := &mockUnloader{} + oldConfig := &sharedrouter.RouterConfig{ EntityMapping: []sharedrouter.EntityKV{ {EntityID: 1, ModelName: "modelA"}, }, } - router := &Router{ - tritonClient: mockClient, - modelConfig: &config.Model{ - Triton: &config.TritonConfig{ - Timeout: 50, - }, + signature := &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, }, - indexToName: map[int]string{}, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, + } + + router := &Router{ + debug: true, + routerName: "load_error", + + unloader: mockClient, routerConfig: oldConfig, routingMap: map[int]string{ 1: "modelA", }, + + makeRoutedEvaluator: func(modelName string) (platform.PlatformEvaluator, error) { + return &mockEvaluator{ + tritonServer: tritonServer, + modelName: modelName, + signature: signature, + }, nil + }, } newConfig := &sharedrouter.RouterConfig{ @@ -481,6 +460,7 @@ func TestRouter_applyRouterConfig_LoadError(t *testing.T) { if err == nil { t.Fatalf("expected error but got nil") } + if !strings.Contains(err.Error(), "modelX") { t.Fatalf("expected error mentioning modelX, got %v", err) } @@ -496,20 +476,11 @@ func TestRouter_applyRouterConfig_LoadError(t *testing.T) { if router.routingTable != nil { t.Fatalf("routingTable should not be replaced on error") } - - loadCalls := mockClient.snapshotLoadCalls() - if len(loadCalls) != 1 || loadCalls[0] != "modelX" { - t.Fatalf("expected single load attempt for modelX, got %v", loadCalls) - } } func TestRouter_applyRouterConfig_SkipsLoadWhenReady(t 
*testing.T) { ctx := context.Background() - mockClient := &mockTritonClient{ - readyState: map[string]bool{ - "modelC": true, - }, - } + mockClient := &mockUnloader{} oldConfig := &sharedrouter.RouterConfig{ EntityMapping: []sharedrouter.EntityKV{ @@ -517,22 +488,26 @@ func TestRouter_applyRouterConfig_SkipsLoadWhenReady(t *testing.T) { }, } - router := &Router{ - tritonClient: mockClient, - modelConfig: &config.Model{ - Triton: &config.TritonConfig{ - Timeout: 10, - }, + signature := &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, }, - indexToName: map[int]string{ - 0: "text", + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, }, + } + + router := &Router{ + unloader: mockClient, routerConfig: oldConfig, routingMap: map[int]string{ 1: "modelA", }, routingTable: map[string]platform.PlatformEvaluator{ - "modelA": &mockPredictOnly{}, + "modelA": &mockEvaluator{signature: signature}, + }, + makeRoutedEvaluator: func(modelName string) (platform.PlatformEvaluator, error) { + return &mockEvaluator{signature: signature}, nil }, } @@ -547,15 +522,6 @@ func TestRouter_applyRouterConfig_SkipsLoadWhenReady(t *testing.T) { t.Fatalf("applyRouterConfig returned error: %v", err) } - if loads := mockClient.snapshotLoadCalls(); len(loads) != 0 { - t.Fatalf("expected no loads when model is ready, got %v", loads) - } - - readyCalls := mockClient.snapshotReadyCalls() - if len(readyCalls) != 1 || readyCalls[0] != "modelC" { - t.Fatalf("expected readiness check for modelC, got %v", readyCalls) - } - if _, ok := router.routingTable["modelC"]; !ok { t.Fatalf("routingTable missing modelC after reload") } diff --git a/service/triton/evaluator.go b/service/triton/evaluator.go index 010fbd0..b724b27 100644 --- a/service/triton/evaluator.go +++ b/service/triton/evaluator.go @@ -70,24 +70,27 @@ func NewTritonEvaluator(config *config.Model, tritonClients map[string]TritonCli // clients defined in TritonServers are 
assumed to be in EXPLICIT mode repositoryExplicit: !isPrivateClient || config.Triton.RepositoryExplicit, - } - evaluator.configuredInputs = config.MetaInput.Inputs + configuredInputs: config.MetaInput.Inputs, - evaluator.modelID = config.ID - evaluator.debug = config.Debug + modelID: config.ID, + debug: config.Debug, + } return evaluator, nil } -func NewRoutedTritonEvaluator(modelName string, client TritonClient, timeoutMs int, indexToName map[int]string) (*TritonEvaluator, error) { - return &TritonEvaluator{ - client: client, - modelName: modelName, - timeout: time.Duration(timeoutMs) * time.Millisecond, - repositoryExplicit: true, - indexToName: indexToName, - }, nil +// Upward dependency, but provides Evaluators as needed for the service/platform/router module. +func NewRoutedTritonEvaluator(modelName string, config *config.Model, tritonClients map[string]TritonClient) (*TritonEvaluator, error) { + evaluator, err := NewTritonEvaluator(config, tritonClients) + if err != nil { + return nil, fmt.Errorf("failed to create Triton Routed evaluator: %w", err) + } + + evaluator.modelName = modelName + evaluator.configuredInputs = nil // routed evaluators must not have any additional inputs + + return evaluator, nil } // Predict performs inference via Triton Inference Server From 306eb25a3129a1fadddc663f4c356103229db1fa Mon Sep 17 00:00:00 2001 From: David Choi Date: Wed, 3 Dec 2025 13:26:12 -0800 Subject: [PATCH 12/50] Add support for multiple routers using the same Triton Inference service with model management. 
--- service/config/model.go | 20 +- service/config/router.go | 2 +- service/endpoint/model.go | 4 +- service/endpoint/service.go | 6 +- service/platform/factory/factory.go | 10 +- service/platform/router/router.go | 76 ++++-- service/platform/router/router_test.go | 356 ++++++++++++++++++++----- service/service.go | 4 +- service/triton/client.go | 7 +- service/triton/evaluator.go | 57 ++-- service/triton/evaluator_test.go | 4 +- service/triton/repository.go | 48 ++++ service/triton/service.go | 46 ++++ shared/config/router/router.go | 2 +- shared/config/router/router_test.go | 10 +- 15 files changed, 517 insertions(+), 135 deletions(-) create mode 100644 service/triton/repository.go create mode 100644 service/triton/service.go diff --git a/service/config/model.go b/service/config/model.go index 04fbf32..282c3a0 100644 --- a/service/config/model.go +++ b/service/config/model.go @@ -160,6 +160,10 @@ func (m *Model) Init(globalBatchConfig *batchconfig.BatcherConfig) { if m.Router != nil { m.Router.Init() } + + if m.Triton != nil { + m.Triton.Init(m.IsRouter()) + } } func (m *Model) Validate() error { @@ -183,7 +187,7 @@ func (m *Model) Validate() error { return fmt.Errorf("tensorflow model %s requires URL", m.ID) } - if m.Mode == "router" { + if m.IsRouter() { return fmt.Errorf("tensorflow model %s is not supported in router mode", m.ID) } case "triton": @@ -191,14 +195,14 @@ func (m *Model) Validate() error { return fmt.Errorf("triton model %s requires Triton configuration", m.ID) } - if err := m.Triton.Validate(m.Mode == "router", m.URL != ""); err != nil { + if err := m.Triton.Validate(m.IsRouter(), m.URL != ""); err != nil { return fmt.Errorf("triton model %s config invalid: %w", m.ID, err) } default: return fmt.Errorf("unsupported platform '%s' for model %s (supported: tensorflow, triton)", platform, m.ID) } - if m.Mode == "router" { + if m.IsRouter() { if m.Router == nil { return fmt.Errorf("router model %s requires Router configuration", m.ID) } @@ -216,6 
+220,10 @@ func (m *Model) Validate() error { return nil } +func (m *Model) IsRouter() bool { + return m.Mode == "router" +} + // ConfigCheck is a path to validate relationships with other config entities. func (m *Model) ConfigCheck(validDatastoreIDs map[string]struct{}, validTritonServerIDs map[string]struct{}) error { if m.DataStore != "" { @@ -264,10 +272,14 @@ type TritonConfig struct { Timeout int `json:",omitempty" yaml:",omitempty"` } -func (t *TritonConfig) Init() { +func (t *TritonConfig) Init(isRouter bool) { if t.Timeout == 0 { t.Timeout = 100 } + + if isRouter { + t.ModelName = "" + } } func (t *TritonConfig) Validate(isRouter bool, urlPresent bool) error { diff --git a/service/config/router.go b/service/config/router.go index d7cdd16..9b2cdb2 100644 --- a/service/config/router.go +++ b/service/config/router.go @@ -6,7 +6,7 @@ type RouterConfig struct { // Required if Model.Mode is "router". ConfigURL string - // Required + // Required name of the input that will route the request to the backend. InputName string `json:",omitempty" yaml:",omitempty"` // Unimplemented. diff --git a/service/endpoint/model.go b/service/endpoint/model.go index 6afe7c3..cddb771 100644 --- a/service/endpoint/model.go +++ b/service/endpoint/model.go @@ -52,7 +52,7 @@ func Build( mux *http.ServeMux, config *Config, datastores map[string]*datastore.Service, - tritonClients map[string]triton.TritonClient, + tritonServices map[string]*triton.Service, hooks []Hook, metrics *gmetric.Service, promReg prometheus.Registerer, @@ -129,7 +129,7 @@ func Build( var modelSrv *service.Service var err error - modelSrv, err = service.New(context.Background(), model, fs, metrics, datastores, tritonClients, sema, cfge.MaxEvaluatorWait, serviceOpts...) + modelSrv, err = service.New(context.Background(), model, fs, metrics, datastores, tritonServices, sema, cfge.MaxEvaluatorWait, serviceOpts...) 
if err != nil { return fmt.Errorf("failed to create service for model:%v, err:%w", model.ID, err) diff --git a/service/endpoint/service.go b/service/endpoint/service.go index f191c9e..67eb051 100644 --- a/service/endpoint/service.go +++ b/service/endpoint/service.go @@ -207,7 +207,7 @@ func New(cfg *Config) (*Service, error) { return nil, fmt.Errorf("failed to create datastores: %w", err) } - tritonClients := make(map[string]triton.TritonClient) + tritonServices := make(map[string]*triton.Service) for _, server := range cfg.TritonServers { tritonClient, err := triton.NewClient(server) if err != nil { @@ -223,14 +223,14 @@ func New(cfg *Config) (*Service, error) { return nil, fmt.Errorf("failed to check triton server %s health: %w", server.ID, err) } - tritonClients[server.ID] = tritonClient + tritonServices[server.ID] = triton.NewService(tritonClient) } hooks := []Hook{ healthHandler, } - err = Build(mux, cfg, datastores, tritonClients, hooks, metrics, promReg) + err = Build(mux, cfg, datastores, tritonServices, hooks, metrics, promReg) if err != nil { return nil, err } diff --git a/service/platform/factory/factory.go b/service/platform/factory/factory.go index f899127..d176739 100644 --- a/service/platform/factory/factory.go +++ b/service/platform/factory/factory.go @@ -22,10 +22,10 @@ func CreateEvaluator( metrics *gmetric.Service, sema *semaphore.Weighted, maxEvaluatorWait time.Duration, - tritonClients map[string]triton.TritonClient, + tritonServices map[string]*triton.Service, ) (platform.PlatformEvaluator, error) { p := cfg.GetPlatform() - isRouter := cfg.Mode == "router" + isRouter := cfg.IsRouter() switch p { case "tensorflow": @@ -36,13 +36,13 @@ func CreateEvaluator( case "triton": if isRouter { makeEvaluator := func(modelName string) (platform.PlatformEvaluator, error) { - return triton.NewRoutedTritonEvaluator(modelName, cfg, tritonClients) + return triton.NewRoutedTritonEvaluator(modelName, cfg, tritonServices) } - return router.NewRouter(cfg, fs, 
tritonClients, makeEvaluator) + return router.NewRouter(cfg, fs, tritonServices, makeEvaluator) } - return triton.NewTritonEvaluator(cfg, tritonClients) + return triton.NewTritonEvaluator(cfg, tritonServices) default: return nil, fmt.Errorf("unsupported platform: %s for model %s", p, cfg.ID) } diff --git a/service/platform/router/router.go b/service/platform/router/router.go index 7fad23c..d585d0f 100644 --- a/service/platform/router/router.go +++ b/service/platform/router/router.go @@ -19,6 +19,7 @@ import ( "github.com/viant/mly/service/platform" "github.com/viant/mly/service/request/shape" tricli "github.com/viant/mly/service/triton" + "github.com/viant/mly/shared" "github.com/viant/mly/shared/common" "github.com/viant/mly/shared/config/router" "gopkg.in/yaml.v2" @@ -32,8 +33,8 @@ type IOState struct { routerInputOffset int } -type ModelUnloader interface { - ModelUnload(ctx context.Context, modelName string) error +type UnloadService interface { + UnloadModel(ctx context.Context, mlyModelID string, tritonModelName string) error } // Router implements the PlatformEvaluator interface for router mode. @@ -53,8 +54,8 @@ type Router struct { // - ioState routingTableLock sync.RWMutex - // routerConfig contains the last loaded router configuration - routerConfig *router.RouterConfig + // routingConfig contains the last loaded routing configuration + routingConfig *router.RoutingConfig hasGlobalModel bool makeRoutedEvaluator func(modelName string) (platform.PlatformEvaluator, error) @@ -80,24 +81,26 @@ type Router struct { routerName string debug bool - unloader ModelUnloader + unloader UnloadService - ioState *IOState + configuredInputs []*shared.Field + ioState *IOState } // NewRouter creates a new Router instance. // cfg is expected to be Init()'d and Validate()'d before calling this function. 
-func NewRouter(cfg *config.Model, fs afs.Service, tritonClients map[string]tricli.TritonClient, makeEvaluator func(modelName string) (platform.PlatformEvaluator, error)) (*Router, error) { - unloaders := make(map[string]ModelUnloader) - for serverID, tritonClient := range tritonClients { - unloaders[serverID] = tritonClient +// makeEvaluator is expected to register usage for every created Evaluator. +func NewRouter(cfg *config.Model, fs afs.Service, tritonServices map[string]*tricli.Service, makeEvaluator func(modelName string) (platform.PlatformEvaluator, error)) (*Router, error) { + unloaders := make(map[string]UnloadService) + for serverID, tritonService := range tritonServices { + unloaders[serverID] = tritonService } return newRouter(cfg, fs, unloaders, makeEvaluator) } // newRouter uses a map[string]ModelUnloader, where ModelUnloader is-a triton.TritonClient, for testing. -func newRouter(cfg *config.Model, fs afs.Service, unloaders map[string]ModelUnloader, makeEvaluator func(modelName string) (platform.PlatformEvaluator, error)) (*Router, error) { +func newRouter(cfg *config.Model, fs afs.Service, unloaders map[string]UnloadService, makeEvaluator func(modelName string) (platform.PlatformEvaluator, error)) (*Router, error) { if cfg.Router == nil { return nil, fmt.Errorf("router configuration is required") } @@ -116,6 +119,7 @@ func newRouter(cfg *config.Model, fs afs.Service, unloaders map[string]ModelUnlo } var err error + fixedEvaluator, err = newFixedEvaluator(cfg.Router.Global.PredictionReplacements) if err != nil { return nil, fmt.Errorf("failed to create fixed evaluator: %w", err) @@ -143,6 +147,9 @@ func newRouter(cfg *config.Model, fs afs.Service, unloaders map[string]ModelUnlo fixedEvaluator: fixedEvaluator, fixedEvaluatorFields: fixedEvaluatorFields, + + configuredInputs: cfg.Inputs, + routerInputFieldName: cfg.Router.InputName, } // spawn worker routines @@ -374,7 +381,7 @@ type modelSignature struct { // TODO refactor with 
service/tfmodel/service.isModified()? func (r *Router) isModified(snapshot *config.Modified) bool { - if r.routerConfig == nil || r.configModified == nil { + if r.routingConfig == nil || r.configModified == nil { return true } @@ -481,7 +488,7 @@ func (r *Router) ReloadIfNeeded(ctx context.Context) error { } } - newConfig := new(router.RouterConfig) + newConfig := new(router.RoutingConfig) // TODO move this check earlier if strings.Contains(r.configURL, ".yaml") { @@ -505,13 +512,13 @@ func (r *Router) ReloadIfNeeded(ctx context.Context) error { } // applyRouterConfig will both update evaluators to new configuration state and verify and build the signature -func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.RouterConfig) error { +func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.RoutingConfig) error { modelsToUnload := make(map[string]struct{}) reuseEvaluators := make(map[string]platform.PlatformEvaluator) var reuseGlobal platform.PlatformEvaluator var finalSignature *domain.Signature - var oldConfig *router.RouterConfig + var oldConfig *router.RoutingConfig func() { r.routingTableLock.RLock() defer r.routingTableLock.RUnlock() @@ -520,7 +527,7 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router } reuseGlobal = r.globalModel - oldConfig = r.routerConfig + oldConfig = r.routingConfig }() if oldConfig != nil { @@ -672,15 +679,15 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router // accept first available signature as the final signature if finalSignature == nil { // in the creation of the signature, include the routing input - + // DANGER: this uses the pointer to the signature, so since the signature is modified, the original signature will be modified! + // This doesn't happen in practice, but can cause issues in tests. 
finalSignature = signature.signature inputOffset := len(finalSignature.Inputs) routerInput := domain.Input{ - Name: r.routerInputFieldName, - Index: inputOffset, - Type: reflect.TypeOf(int64(0)), + Name: r.routerInputFieldName, + Type: reflect.TypeOf(int64(0)), } ioState.routerInputOffset = inputOffset @@ -691,6 +698,25 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router sigInputMap[input.Name] = &input } + for _, input := range r.configuredInputs { + _, ok := sigInputMap[input.Name] + + if ok { + // the input is configured and already in the self-reported signature + continue + } + + if !input.Auxiliary { + return fmt.Errorf("non-auxiliary input %s for model %s was not in model inputs", input.Name, signature.name) + } + + sigInputMap[input.Name] = &domain.Input{ + Name: input.Name, + Type: input.RawType(), + Auxiliary: input.Auxiliary, + } + } + if r.modelOutputName != "" { // also, add the selected model output modelOutput := domain.Output{ @@ -745,6 +771,10 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router thisSignatureInputMap[input.Name] = &input + if oldInput.Auxiliary { + continue + } + // TODO permit this if oldInput.Index != input.Index { return fmt.Errorf("signature input %s for model %s has index %d, and the previous signature has index %d", input.Name, signature.name, input.Index, oldInput.Index) @@ -790,7 +820,7 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router r.routingTableLock.Lock() defer r.routingTableLock.Unlock() - r.routerConfig = newConfig + r.routingConfig = newConfig r.routingMap = newModelMapping r.routingTable = newRoutingTable @@ -810,6 +840,8 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router ctxTo, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + r.debugLogf("request to unload model: %s", modelName) + if err := r.unloadModel(ctxTo, modelName); err != nil { 
r.debugLogf("failed to unload model %s: %v\n", modelName, err) } @@ -821,7 +853,7 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Router func (r *Router) unloadModel(ctx context.Context, modelName string) error { defer routerModelUnloadGauge.WithLabelValues(r.routerName).Dec() - if err := r.unloader.ModelUnload(ctx, modelName); err != nil { + if err := r.unloader.UnloadModel(ctx, r.routerName, modelName); err != nil { return fmt.Errorf("failed to unload model %s: %w", modelName, err) } return nil diff --git a/service/platform/router/router_test.go b/service/platform/router/router_test.go index 50458d0..035195c 100644 --- a/service/platform/router/router_test.go +++ b/service/platform/router/router_test.go @@ -10,9 +10,11 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/viant/mly/service/config" "github.com/viant/mly/service/domain" "github.com/viant/mly/service/platform" + "github.com/viant/mly/service/triton" "github.com/viant/mly/shared/common" sharedrouter "github.com/viant/mly/shared/config/router" ) @@ -26,17 +28,34 @@ type mockTritonServer struct { modelLoadErr map[string]error } +func (m *mockTritonServer) ModelLoad(ctx context.Context, modelName string) error { + m.mu.Lock() + defer m.mu.Unlock() + + if err := m.modelLoadErr[modelName]; err != nil { + return err + } + + if m.readyState == nil { + m.readyState = make(map[string]bool) + } + + m.readyState[modelName] = true + + return nil +} + type mockEvaluator struct { tritonServer *mockTritonServer modelName string - signature *domain.Signature + signature func() *domain.Signature } func (m *mockEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { - inputs := m.signature.Inputs + inputs := m.signature().Inputs if len(inputs) != len(params) { - return nil, fmt.Errorf("expected %d inputs, got %d", len(inputs), len(params)) + return nil, fmt.Errorf("mock error: expected %d inputs, got %d", len(inputs), len(params)) } 
// params is expected to have a single non-router input in this test: [][]string with shape [1][1] @@ -46,14 +65,14 @@ func (m *mockEvaluator) Predict(ctx context.Context, params []interface{}) ([]in v = typed[0][0] default: tval := reflect.TypeOf(params[0]) - panic("unexpected input type in mockPredictOnly: " + tval.String()) + panic("unexpected input type in mock Predict(): " + tval.String()) } // simple function: length of string as float32 out := [][]float32{{float32(len(v))}} return []interface{}{out}, nil } -func (m *mockEvaluator) Signature() *domain.Signature { return m.signature } +func (m *mockEvaluator) Signature() *domain.Signature { return m.signature() } func (m *mockEvaluator) Dictionary() *common.Dictionary { return nil } func (m *mockEvaluator) Inputs() map[string]*domain.Input { return nil } func (m *mockEvaluator) Stats(map[string]interface{}) {} @@ -64,19 +83,7 @@ func (m *mockEvaluator) ReloadIfNeeded(ctx context.Context) error { return nil } - m.tritonServer.mu.Lock() - defer m.tritonServer.mu.Unlock() - - if err := m.tritonServer.modelLoadErr[m.modelName]; err != nil { - return err - } - - if m.tritonServer.readyState == nil { - m.tritonServer.readyState = make(map[string]bool) - } - - m.tritonServer.readyState[m.modelName] = true - return nil + return m.tritonServer.ModelLoad(ctx, m.modelName) } type mockUnloader struct { @@ -84,7 +91,7 @@ type mockUnloader struct { unloadCh chan string } -func (m *mockUnloader) ModelUnload(ctx context.Context, modelName string) error { +func (m *mockUnloader) ModelUnload(ctx context.Context, tritonModelName string) error { if m.tritonServer != nil { m.tritonServer.mu.Lock() defer m.tritonServer.mu.Unlock() @@ -93,13 +100,13 @@ func (m *mockUnloader) ModelUnload(ctx context.Context, modelName string) error m.tritonServer.readyState = make(map[string]bool) } - m.tritonServer.readyState[modelName] = false + m.tritonServer.readyState[tritonModelName] = false } ch := m.unloadCh if ch != nil { - ch <- modelName + 
ch <- tritonModelName } return nil @@ -250,10 +257,10 @@ func TestRouter_Predict(t *testing.T) { cfg.Router.MaxQueueSize = 1000 cfg.Router.Workers = 3 - router, err := newRouter(cfg, nil, map[string]ModelUnloader{ - "test_server": &mockUnloader{}, + router, err := newRouter(cfg, nil, map[string]UnloadService{ + "test_server": &triton.Service{}, }, func(modelName string) (platform.PlatformEvaluator, error) { - return &mockEvaluator{signature: downstreamSignature}, nil + return &mockEvaluator{signature: func() *domain.Signature { return downstreamSignature }}, nil }) if err != nil { @@ -291,7 +298,7 @@ func TestRouter_Predict(t *testing.T) { }, } - mockEval := &mockEvaluator{signature: downstreamSignature} + mockEval := &mockEvaluator{signature: func() *domain.Signature { return downstreamSignature }} router.routingTable = map[string]platform.PlatformEvaluator{ "model1": mockEval, "model2": mockEval, @@ -319,7 +326,7 @@ func TestRouter_applyRouterConfig_LoadsAndSwaps(t *testing.T) { unloadCh: make(chan string, 2), } - oldConfig := &sharedrouter.RouterConfig{ + oldConfig := &sharedrouter.RoutingConfig{ EntityMapping: []sharedrouter.EntityKV{ {EntityID: 1, ModelName: "modelA"}, {EntityID: 2, ModelName: "modelB"}, @@ -327,36 +334,38 @@ func TestRouter_applyRouterConfig_LoadsAndSwaps(t *testing.T) { GlobalModelName: "global-old", } - signature := &domain.Signature{ - Inputs: []domain.Input{ - {Name: "text", Index: 0, Type: reflect.TypeOf("")}, - }, - Outputs: []domain.Output{ - {Name: "score", Index: 0, DataType: "float32"}, - }, + makeSig := func() *domain.Signature { + return &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, + } } router := &Router{ - unloader: mockClient, - routerConfig: oldConfig, + unloader: &triton.Service{Unloader: mockClient, Repository: triton.NewRepository()}, + routingConfig: oldConfig, routingMap: 
map[int]string{ 1: "modelA", 2: "modelB", }, routingTable: map[string]platform.PlatformEvaluator{ - "modelA": &mockEvaluator{signature: signature}, - "modelB": &mockEvaluator{signature: signature}, + "modelA": &mockEvaluator{signature: makeSig}, + "modelB": &mockEvaluator{signature: makeSig}, }, globalModel: &mockEvaluator{}, debug: true, makeRoutedEvaluator: func(modelName string) (platform.PlatformEvaluator, error) { - return &mockEvaluator{signature: signature}, nil + return &mockEvaluator{signature: makeSig}, nil }, } reusedModelB := router.routingTable["modelB"] - newConfig := &sharedrouter.RouterConfig{ + newConfig := &sharedrouter.RoutingConfig{ EntityMapping: []sharedrouter.EntityKV{ {EntityID: 1, ModelName: "modelB"}, {EntityID: 3, ModelName: "modelC"}, @@ -370,7 +379,7 @@ func TestRouter_applyRouterConfig_LoadsAndSwaps(t *testing.T) { waitForCalls(t, mockClient.unloadCh, 2) - if router.routerConfig != newConfig { + if router.routingConfig != newConfig { t.Fatalf("routerConfig pointer not updated") } @@ -416,7 +425,7 @@ func TestRouter_applyRouterConfig_LoadError(t *testing.T) { mockClient := &mockUnloader{} - oldConfig := &sharedrouter.RouterConfig{ + oldConfig := &sharedrouter.RoutingConfig{ EntityMapping: []sharedrouter.EntityKV{ {EntityID: 1, ModelName: "modelA"}, }, @@ -435,8 +444,8 @@ func TestRouter_applyRouterConfig_LoadError(t *testing.T) { debug: true, routerName: "load_error", - unloader: mockClient, - routerConfig: oldConfig, + unloader: &triton.Service{Unloader: mockClient}, + routingConfig: oldConfig, routingMap: map[int]string{ 1: "modelA", }, @@ -445,12 +454,12 @@ func TestRouter_applyRouterConfig_LoadError(t *testing.T) { return &mockEvaluator{ tritonServer: tritonServer, modelName: modelName, - signature: signature, + signature: func() *domain.Signature { return signature }, }, nil }, } - newConfig := &sharedrouter.RouterConfig{ + newConfig := &sharedrouter.RoutingConfig{ EntityMapping: []sharedrouter.EntityKV{ {EntityID: 2, ModelName: 
"modelX"}, }, @@ -465,7 +474,7 @@ func TestRouter_applyRouterConfig_LoadError(t *testing.T) { t.Fatalf("expected error mentioning modelX, got %v", err) } - if router.routerConfig != oldConfig { + if router.routingConfig != oldConfig { t.Fatalf("routerConfig should remain unchanged on error") } @@ -478,51 +487,258 @@ func TestRouter_applyRouterConfig_LoadError(t *testing.T) { } } -func TestRouter_applyRouterConfig_SkipsLoadWhenReady(t *testing.T) { +func TestRouter_applyRouterConfig_signature(t *testing.T) { ctx := context.Background() - mockClient := &mockUnloader{} + mockClient := &mockUnloader{ + unloadCh: make(chan string, 1), + tritonServer: &mockTritonServer{}, + } + + makeSig := func() *domain.Signature { + return &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, + } + } - oldConfig := &sharedrouter.RouterConfig{ + cfg := &config.Model{ + ID: "test_signature", + Debug: true, + Mode: "router", + Platform: "triton", + Router: &config.RouterConfig{ + ConfigURL: "memory://router-config", + InputName: "router_id", + Global: config.GlobalModelConfig{ + Exists: true, + }, + Output: config.OutputConfig{ + FieldName: "model_id", + }, + }, + Triton: &config.TritonConfig{ + ServerID: "test_server", + }, + } + + cfg.Init(nil) + + var router *Router + var err error + + router, err = newRouter(cfg, nil, map[string]UnloadService{ + "test_server": &triton.Service{Unloader: mockClient}, + }, func(modelName string) (platform.PlatformEvaluator, error) { + return &mockEvaluator{signature: makeSig}, nil + }) + + if err != nil { + t.Fatalf("NewRouter error: %v", err) + } + + newConfig := &sharedrouter.RoutingConfig{ EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 1, ModelName: "modelA"}, + {EntityID: 1, ModelName: "model1"}, + {EntityID: 2, ModelName: "model2"}, }, + GlobalModelName: "global-model", } - signature := &domain.Signature{ - 
Inputs: []domain.Input{ - {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + if err := router.applyRouterConfig(ctx, newConfig); err != nil { + t.Fatalf("applyRouterConfig returned error: %v", err) + } + + detectedInputs := map[string]struct{}{} + for _, input := range router.ioState.signature.Inputs { + detectedInputs[input.Name] = struct{}{} + } + + expectedInputs := map[string]struct{}{ + "text": {}, + "router_id": {}, + } + + assert.Equal(t, expectedInputs, detectedInputs) + assert.Equal(t, 1, router.ioState.routerInputOffset) + + params := []interface{}{ + [][]string{{"a"}, {"abcd"}}, // text + [][]int64{{1}, {2}}, // router_id + } + + results, err := router.Predict(ctx, params) + if err != nil { + t.Fatalf("Predict error: %v", err) + } + + assert.Equal(t, 2, len(results)) +} + +type wrappedUnloader struct { + tritonService *triton.Service + + wg *sync.WaitGroup +} + +func (w *wrappedUnloader) UnloadModel(ctx context.Context, mlyModelID string, tritonModelName string) error { + defer w.wg.Done() + return w.tritonService.UnloadModel(ctx, mlyModelID, tritonModelName) +} + +func TestRouter_applyRouterConfig_sharedTritonServer(t *testing.T) { + tritonServer := &mockTritonServer{} + modelUnloader := &mockUnloader{ + tritonServer: tritonServer, + } + + repository := triton.NewRepository() + tritonService := &triton.Service{ + Unloader: modelUnloader, + Repository: repository, + } + + wrappedService := &wrappedUnloader{ + tritonService: tritonService, + wg: &sync.WaitGroup{}, + } + + unloaders := map[string]UnloadService{ + "test_server": wrappedService, + } + + cfgA := &config.Model{ + ID: "test_shared_a", + Debug: true, + Mode: "router", + Platform: "triton", + Router: &config.RouterConfig{ + ConfigURL: "memory://router-config", + Global: config.GlobalModelConfig{ + PredictionReplacements: []config.PredictionReplacement{ + { + Name: "score", + Type: "float32", + Value: 0.0, + }, + }, + }, }, - Outputs: []domain.Output{ - {Name: "score", Index: 0, DataType: 
"float32"}, + Triton: &config.TritonConfig{ + ServerID: "test_server", }, } - router := &Router{ - unloader: mockClient, - routerConfig: oldConfig, - routingMap: map[int]string{ - 1: "modelA", - }, - routingTable: map[string]platform.PlatformEvaluator{ - "modelA": &mockEvaluator{signature: signature}, + cfgB := &config.Model{ + ID: "test_shared_b", + Debug: true, + Mode: "router", + Platform: "triton", + Router: &config.RouterConfig{ + ConfigURL: "memory://router-config", + Global: config.GlobalModelConfig{ + PredictionReplacements: []config.PredictionReplacement{ + { + Name: "score", + Type: "float32", + Value: 0.0, + }, + }, + }, }, - makeRoutedEvaluator: func(modelName string) (platform.PlatformEvaluator, error) { - return &mockEvaluator{signature: signature}, nil + Triton: &config.TritonConfig{ + ServerID: "test_server", }, } - newConfig := &sharedrouter.RouterConfig{ + cfgA.Init(nil) + cfgB.Init(nil) + + newSig := func() *domain.Signature { + return &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, + } + } + + routerA, err := newRouter(cfgA, nil, unloaders, func(modelName string) (platform.PlatformEvaluator, error) { + tritonService.RegisterUsage(cfgA.ID, modelName) + return &mockEvaluator{modelName: modelName, signature: newSig, tritonServer: tritonServer}, nil + }) + + if err != nil { + t.Fatalf("newRouter returned error: %v", err) + } + + routerB, err := newRouter(cfgB, nil, unloaders, func(modelName string) (platform.PlatformEvaluator, error) { + tritonService.RegisterUsage(cfgB.ID, modelName) + return &mockEvaluator{modelName: modelName, signature: newSig, tritonServer: tritonServer}, nil + }) + + if err != nil { + t.Fatalf("newRouter returned error: %v", err) + } + + ctx := context.Background() + + // establish mappings for A and B using the same models + + if err = routerA.applyRouterConfig(ctx, 
&sharedrouter.RoutingConfig{ EntityMapping: []sharedrouter.EntityKV{ {EntityID: 1, ModelName: "modelA"}, {EntityID: 2, ModelName: "modelC"}, }, + }); err != nil { + t.Fatalf("applyRouterConfig A initial returned error: %v", err) } - if err := router.applyRouterConfig(ctx, newConfig); err != nil { - t.Fatalf("applyRouterConfig returned error: %v", err) + if err = routerB.applyRouterConfig(ctx, &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelA"}, + }, + }); err != nil { + t.Fatalf("applyRouterConfig B inital returned error: %v", err) } - if _, ok := router.routingTable["modelC"]; !ok { - t.Fatalf("routingTable missing modelC after reload") + // we expect model A and model C to be attempted to be unloaded + wrappedService.wg.Add(2) + + // routerA will now get a different mapping + if err = routerA.applyRouterConfig(ctx, &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelB"}, + }, + }); err != nil { + t.Fatalf("applyRouterConfig A reload returned error: %v", err) + } + + wrappedService.wg.Wait() + + assert.True(t, tritonServer.readyState["modelA"], "modelA should still be loaded") + assert.False(t, tritonServer.readyState["modelC"], "modelC should be unloaded") + + // we expect model A to be attempted to be unloaded + wrappedService.wg.Add(1) + + if err = routerB.applyRouterConfig(ctx, &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelB"}, + {EntityID: 2, ModelName: "modelC"}, + }, + }); err != nil { + t.Fatalf("applyRouterConfig B reload returned error: %v", err) } + + wrappedService.wg.Wait() + + assert.False(t, tritonServer.readyState["modelA"], "modelA should be unloaded") + assert.True(t, tritonServer.readyState["modelB"], "modelB should still be loaded") + assert.True(t, tritonServer.readyState["modelC"], "modelC should be loaded") } diff --git a/service/service.go b/service/service.go index 
0080882..949fcb5 100644 --- a/service/service.go +++ b/service/service.go @@ -348,7 +348,7 @@ func New( fs afs.Service, metrics *gmetric.Service, datastores map[string]*datastore.Service, - tritonClients map[string]triton.TritonClient, + tritonServices map[string]*triton.Service, sema *semaphore.Weighted, maxEvaluatorWait time.Duration, options ...Option, @@ -363,7 +363,7 @@ func New( cfg.Init(nil) // Create platform evaluator context - evaluatorContext, err := factory.CreateEvaluator(cfg, fs, metrics, sema, maxEvaluatorWait, tritonClients) + evaluatorContext, err := factory.CreateEvaluator(cfg, fs, metrics, sema, maxEvaluatorWait, tritonServices) if err != nil { return nil, fmt.Errorf("failed to create platform evaluator for model %s: %w", cfg.ID, err) } diff --git a/service/triton/client.go b/service/triton/client.go index ec93f3c..f19502a 100644 --- a/service/triton/client.go +++ b/service/triton/client.go @@ -12,6 +12,7 @@ import ( "google.golang.org/grpc/credentials/insecure" ) +// A TritonClient represents a client to a single Triton server. 
type TritonClient interface { ServerReady(ctx context.Context) error @@ -23,13 +24,17 @@ type TritonClient interface { ModelLoad(ctx context.Context, modelName string) error - ModelUnload(ctx context.Context, modelName string) error + ModelUnloader ModelMetadata(ctx context.Context, modelName string) (*ModelMetadata, error) Close() error } +type ModelUnloader interface { + ModelUnload(ctx context.Context, modelName string) error +} + // https://github.com/kserve/kserve/blob/master/docs/predict-api/v2/required_api.md#model-metadata-response-json-object `$metadata_tensor` type MetadataTensor struct { Name string `json:"name"` diff --git a/service/triton/evaluator.go b/service/triton/evaluator.go index b724b27..feb44ef 100644 --- a/service/triton/evaluator.go +++ b/service/triton/evaluator.go @@ -13,10 +13,9 @@ import ( "github.com/viant/mly/shared/common" ) -// TritonEvaluator implements PlatformEvaluator for Triton Inference Server via gRPC +// TritonEvaluator implements service/platform.PlatformEvaluator. 
type TritonEvaluator struct { - client TritonClient - + service *Service modelName string // if true, this client is used only for this instance @@ -30,7 +29,7 @@ type TritonEvaluator struct { signature *domain.Signature - // maps feeds index to input name + // maps Feeds index to input name indexToName map[int]string configuredInputs []*shared.Field @@ -39,30 +38,35 @@ type TritonEvaluator struct { } // NewTritonEvaluator creates a new Triton evaluator -func NewTritonEvaluator(config *config.Model, tritonClients map[string]TritonClient) (*TritonEvaluator, error) { - var client TritonClient +func NewTritonEvaluator(config *config.Model, tritonClients map[string]*Service) (*TritonEvaluator, error) { + var service *Service isPrivateClient := config.URL != "" timeout := time.Duration(config.Triton.Timeout) * time.Millisecond if isPrivateClient { // "Private" URL configuration will only support HTTP - client = &HTTPClient{ + client := &HTTPClient{ httpClient: &http.Client{ Timeout: timeout, }, serverURL: config.URL, debug: config.Debug, } + + service = &Service{ + Client: client, + } } else { - client = tritonClients[config.Triton.ServerID] - if client == nil { + service = tritonClients[config.Triton.ServerID] + if service == nil { return nil, fmt.Errorf("client not found for Triton, server ID: %s", config.Triton.ServerID) } } evaluator := &TritonEvaluator{ - client: client, + service: service, + modelName: config.Triton.ModelName, timeout: timeout, @@ -77,11 +81,16 @@ func NewTritonEvaluator(config *config.Model, tritonClients map[string]TritonCli debug: config.Debug, } + err := evaluator.registerUsage() + if err != nil { + return nil, fmt.Errorf("failed to register usage for Triton evaluator: %w", err) + } + return evaluator, nil } // Upward dependency, but provides Evaluators as needed for the service/platform/router module. 
-func NewRoutedTritonEvaluator(modelName string, config *config.Model, tritonClients map[string]TritonClient) (*TritonEvaluator, error) { +func NewRoutedTritonEvaluator(modelName string, config *config.Model, tritonClients map[string]*Service) (*TritonEvaluator, error) { evaluator, err := NewTritonEvaluator(config, tritonClients) if err != nil { return nil, fmt.Errorf("failed to create Triton Routed evaluator: %w", err) @@ -90,9 +99,23 @@ func NewRoutedTritonEvaluator(modelName string, config *config.Model, tritonClie evaluator.modelName = modelName evaluator.configuredInputs = nil // routed evaluators must not have any additional inputs + err = evaluator.registerUsage() + if err != nil { + return nil, fmt.Errorf("failed to register usage for Triton Routed evaluator: %w", err) + } + return evaluator, nil } +func (t *TritonEvaluator) registerUsage() error { + if t.modelName == "" { + return fmt.Errorf("model name is required for registering usage") + } + + t.service.RegisterUsage(t.modelID, t.modelName) + return nil +} + // Predict performs inference via Triton Inference Server func (t *TritonEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { if len(params) == 0 { @@ -106,7 +129,7 @@ func (t *TritonEvaluator) Predict(ctx context.Context, params []interface{}) ([] defer cancel() } - return t.client.ModelInfer(requestCtx, t.modelName, params, t.indexToName) + return t.service.Client.ModelInfer(requestCtx, t.modelName, params, t.indexToName) } func (t *TritonEvaluator) Signature() *domain.Signature { @@ -129,7 +152,7 @@ func (t *TritonEvaluator) Inputs() map[string]*domain.Input { // Close releases Triton client resources and stops health monitoring func (t *TritonEvaluator) Close() error { if t.isPrivateClient { - return t.client.Close() + return t.service.Client.Close() } return nil @@ -137,7 +160,7 @@ func (t *TritonEvaluator) Close() error { // For independent Triton server models, reloading is not supported. 
func (t *TritonEvaluator) ReloadIfNeeded(ctx context.Context) error { - ready, err := t.client.ModelReady(ctx, t.modelName) + ready, err := t.service.Client.ModelReady(ctx, t.modelName) if err != nil { return fmt.Errorf("failed to check Triton model %s health: %w", t.modelName, err) } @@ -152,12 +175,12 @@ func (t *TritonEvaluator) ReloadIfNeeded(ctx context.Context) error { return fmt.Errorf("model %s not ready and Triton is not in EXPLICIT Model Control Mode: %w", t.modelName, err) } - err = t.client.ModelLoad(ctx, t.modelName) + err = t.service.Client.ModelLoad(ctx, t.modelName) if err != nil { return fmt.Errorf("failed to load Triton model %s: %w", t.modelName, err) } - ready, err = t.client.ModelReady(ctx, t.modelName) + ready, err = t.service.Client.ModelReady(ctx, t.modelName) if err != nil { return fmt.Errorf("failed to check Triton model %s health after loading: %w", t.modelName, err) } @@ -168,7 +191,7 @@ func (t *TritonEvaluator) ReloadIfNeeded(ctx context.Context) error { } // we need to get the model metadata and consolidate the signature - metadata, err := t.client.ModelMetadata(ctx, t.modelName) + metadata, err := t.service.Client.ModelMetadata(ctx, t.modelName) if err != nil || metadata == nil { return fmt.Errorf("failed to get Triton model %s metadata: %w", t.modelName, err) } diff --git a/service/triton/evaluator_test.go b/service/triton/evaluator_test.go index eecc394..b138556 100644 --- a/service/triton/evaluator_test.go +++ b/service/triton/evaluator_test.go @@ -83,13 +83,13 @@ func (m *mockTritonClient) ModelMetadata(ctx context.Context, modelName string) func (m *mockTritonClient) Close() error { return nil } func newTritonEvaluator(cfg *config.Model, mockClient *mockTritonClient) *TritonEvaluator { - cfg.Triton.Init() + cfg.Triton.Init(cfg.IsRouter()) evaluator := &TritonEvaluator{ modelName: cfg.Triton.ModelName, isPrivateClient: true, repositoryExplicit: false, - client: mockClient, + service: &Service{Client: mockClient}, 
configuredInputs: cfg.MetaInput.Inputs, } diff --git a/service/triton/repository.go b/service/triton/repository.go new file mode 100644 index 0000000..a28e975 --- /dev/null +++ b/service/triton/repository.go @@ -0,0 +1,48 @@ +package triton + +import ( + "sync" +) + +type mlyModelID string +type tritonModelName string + +type Repository struct { + mu sync.Mutex + usage map[tritonModelName]map[mlyModelID]struct{} +} + +func (r *Repository) RegisterUsage(mlyID mlyModelID, tritonName tritonModelName) { + r.mu.Lock() + defer r.mu.Unlock() + mlyUsages, ok := r.usage[tritonName] + if !ok { + mlyUsages = make(map[mlyModelID]struct{}) + r.usage[tritonName] = mlyUsages + } + + mlyUsages[mlyID] = struct{}{} +} + +// UnregisterUsage returns true if all usages of a model have been unregistered. +// The TritonClient should then actual unload the model on the server. +func (r *Repository) UnregisterUsage(mlyID mlyModelID, tritonName tritonModelName) bool { + r.mu.Lock() + defer r.mu.Unlock() + + mlyUsages, ok := r.usage[tritonName] + if !ok { + // this was never registered, so this is considered having been unregistered. + return true + } + + delete(mlyUsages, mlyID) + return len(mlyUsages) == 0 +} + +func NewRepository() *Repository { + return &Repository{ + usage: make(map[tritonModelName]map[mlyModelID]struct{}), + mu: sync.Mutex{}, + } +} diff --git a/service/triton/service.go b/service/triton/service.go new file mode 100644 index 0000000..b5c3f16 --- /dev/null +++ b/service/triton/service.go @@ -0,0 +1,46 @@ +package triton + +import ( + "context" +) + +// Service is a container for a client and a representation of model repository management. 
+type Service struct { + Client TritonClient + + Unloader ModelUnloader + Repository *Repository +} + +func (s *Service) RegisterUsage(mlyID string, tritonName string) { + if s.Repository == nil { + return + } + + s.Repository.RegisterUsage(mlyModelID(mlyID), tritonModelName(tritonName)) +} + +func NewService(client TritonClient) *Service { + return &Service{ + Client: client, + Unloader: client, + Repository: NewRepository(), + } +} + +func (s *Service) UnloadModel(ctx context.Context, mlyID string, tritonName string) error { + if s.Repository == nil { + return nil + } + + if s.Unloader == nil { + return nil + } + + shouldUnload := s.Repository.UnregisterUsage(mlyModelID(mlyID), tritonModelName(tritonName)) + if shouldUnload { + return s.Unloader.ModelUnload(ctx, tritonName) + } + + return nil +} diff --git a/shared/config/router/router.go b/shared/config/router/router.go index 58bd880..4a49743 100644 --- a/shared/config/router/router.go +++ b/shared/config/router/router.go @@ -1,6 +1,6 @@ package router -type RouterConfig struct { +type RoutingConfig struct { EntityMapping []EntityKV `json:"entityMapping" yaml:"entityMapping"` GlobalModelName string `json:"globalModelName" yaml:"globalModelName"` diff --git a/shared/config/router/router_test.go b/shared/config/router/router_test.go index e77c58b..c5aeb3c 100644 --- a/shared/config/router/router_test.go +++ b/shared/config/router/router_test.go @@ -8,7 +8,7 @@ import ( ) func TestJSON_EncodeDecode_WithGlobal(t *testing.T) { - cfg := &RouterConfig{ + cfg := &RoutingConfig{ EntityMapping: []EntityKV{ {EntityID: 12345, ModelName: "roas_model_12345_202511121116"}, {EntityID: 12347, ModelName: "roas_model_12347_202511111116"}, @@ -22,7 +22,7 @@ func TestJSON_EncodeDecode_WithGlobal(t *testing.T) { expected := `{"entityMapping":[{"entityID":12345,"modelName":"roas_model_12345_202511121116"},{"entityID":12347,"modelName":"roas_model_12347_202511111116"}],"globalModelName":"roas_global_202511111116"}` require.Equal(t, 
expected, string(data)) - var decoded RouterConfig + var decoded RoutingConfig require.NoError(t, json.Unmarshal(data, &decoded)) require.Equal(t, cfg.GlobalModelName, decoded.GlobalModelName) @@ -35,7 +35,7 @@ func TestJSON_EncodeDecode_WithGlobal(t *testing.T) { func TestJSON_Decode_NoGlobal(t *testing.T) { data := []byte(`{"entityMapping":[{"entityID":1,"modelName":"m1"}]}`) - var cfg RouterConfig + var cfg RoutingConfig require.NoError(t, json.Unmarshal(data, &cfg)) require.Empty(t, cfg.GlobalModelName) require.Len(t, cfg.EntityMapping, 1) @@ -45,7 +45,7 @@ func TestJSON_Decode_NoGlobal(t *testing.T) { func TestJSON_Decode_EmptyArray(t *testing.T) { data := []byte(`{"entityMapping":[]}`) - var cfg RouterConfig + var cfg RoutingConfig require.NoError(t, json.Unmarshal(data, &cfg)) require.NotNil(t, cfg.EntityMapping) require.Len(t, cfg.EntityMapping, 0) @@ -53,6 +53,6 @@ func TestJSON_Decode_EmptyArray(t *testing.T) { func TestJSON_Decode_InvalidEntityID(t *testing.T) { data := []byte(`{"entityMapping":[{"entityID":"oops","modelName":"x"}]}`) - var cfg RouterConfig + var cfg RoutingConfig require.Error(t, json.Unmarshal(data, &cfg)) } From d922d1dd497eef05f0722264911045a543354d83 Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 8 Dec 2025 11:13:01 -0800 Subject: [PATCH 13/50] Retroactively remove support for stating float as an alias for float32 --- service/platform/router/fixed.go | 4 +--- shared/common/type.go | 2 +- shared/field.go | 4 ---- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/service/platform/router/fixed.go b/service/platform/router/fixed.go index 7fdc0fd..216d739 100644 --- a/service/platform/router/fixed.go +++ b/service/platform/router/fixed.go @@ -71,7 +71,7 @@ func newFixedEvaluator(repls []config.PredictionReplacement) (*fixedEvaluator, e default: return fmt.Errorf("router replacement %q: value %T not coercible to int64", r.Name, r.Value) } - case "float", "float32": + case "float32": switch n := r.Value.(type) { case int: 
pr = preparedReplacement{typ: "float32", value: float32(n)} @@ -182,8 +182,6 @@ func (f *fixedEvaluator) Predict(ctx context.Context, params []interface{}) ([]i results[i] = makeInt32(repl.value.(int32)) case "int64": results[i] = makeInt64(repl.value.(int64)) - case "float": - results[i] = makeFloat32(repl.value.(float32)) case "float32": results[i] = makeFloat32(repl.value.(float32)) case "float64": diff --git a/shared/common/type.go b/shared/common/type.go index d3ff745..697f77c 100644 --- a/shared/common/type.go +++ b/shared/common/type.go @@ -14,7 +14,7 @@ func DataType(dataType string) (reflect.Type, error) { return reflect.TypeOf(""), nil case "float64": return reflect.TypeOf(float64(0)), nil - case "float32", "float": + case "float32": return reflect.TypeOf(float32(0)), nil case "int": return reflect.TypeOf(int(0)), nil diff --git a/shared/field.go b/shared/field.go index bbcc4c2..66c9e43 100644 --- a/shared/field.go +++ b/shared/field.go @@ -15,7 +15,6 @@ type ( Index int // The type of the field. - // Supports "float" which maps to float32. // Otherwise, refer to reflect.Type.Name(). DataType string `json:",omitempty" yaml:",omitempty"` @@ -78,9 +77,6 @@ func (f *Field) DataTypeToRawType() { // fieldDataTypeToRawType is a subset of reverse Name() to reflect.Type func fieldDataTypeToRawType(dataType string) reflect.Type { switch dataType { - case "float": - // provided as a convenience - return reflect.TypeOf(float32(0)) case "": // this case is treated as string in common.DataType(), but here it's not OK. panic(fmt.Sprintf("unsupported data type: %s", dataType)) From 34bb186add3e91ffabbcb019412452ad3c15ea79 Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 26 Jan 2026 11:42:56 -0800 Subject: [PATCH 14/50] Add Prometheus metrics to client. 
--- shared/client/marshal.go | 1 + shared/client/prometheus.go | 311 +++++++++++++++++++++++++++++ shared/client/service.go | 75 ++++++- shared/common/error.go | 16 +- shared/datastore/client/service.go | 51 ++++- shared/datastore/service.go | 3 +- shared/datastore/stores.go | 2 +- shared/stat/buckets/prometheus.go | 13 +- shared/stat/promc/error.go | 25 +++ 9 files changed, 474 insertions(+), 23 deletions(-) create mode 100644 shared/client/prometheus.go create mode 100644 shared/stat/promc/error.go diff --git a/shared/client/marshal.go b/shared/client/marshal.go index ca91fa5..9a144bc 100644 --- a/shared/client/marshal.go +++ b/shared/client/marshal.go @@ -7,6 +7,7 @@ import ( "github.com/francoispqt/gojay" ) +// Deprecated. No need to be exported. func Marshal(data interface{}, id string) ([]byte, error) { if data == nil { return nil, fmt.Errorf("data was nil") diff --git a/shared/client/prometheus.go b/shared/client/prometheus.go new file mode 100644 index 0000000..f5da0a7 --- /dev/null +++ b/shared/client/prometheus.go @@ -0,0 +1,311 @@ +package client + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/viant/mly/shared/stat/buckets" + "github.com/viant/mly/shared/stat/promc" +) + +var ( + // end-to-end duration + + runDurationSummaryMicros = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "run_duration_summary_us", + Help: "Duration of client Run calls.", + Objectives: buckets.CommonSummaryObjectives, + }, + []string{"model"}, + ) + + runDurationHistogramMicros = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "run_duration_histogram_us", + Help: "Duration of client Run calls.", + Buckets: buckets.MicrosecondBuckets, + }, + []string{"model"}, + ) + + httpDurationSummaryMicros = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "client", + Name: 
"http_duration_summary_us", + Help: "Duration of client HTTP calls including retries.", + Objectives: buckets.CommonSummaryObjectives, + }, + []string{"model"}, + ) + + httpDurationHistogramMicros = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "http_duration_histogram_us", + Help: "Duration of client HTTP calls including retries.", + Buckets: buckets.MicrosecondBuckets, + }, + []string{"model"}, + ) + + httpClientDurationSummaryMicros = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "http_client_duration_summary_us", + Help: "Duration of client HTTP client calls.", + Objectives: buckets.CommonSummaryObjectives, + }, + []string{"model"}, + ) + + httpClientDurationHistogramMicros = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "http_client_duration_histogram_us", + Help: "Duration of client HTTP client calls.", + Buckets: buckets.MicrosecondBuckets, + }, + []string{"model"}, + ) + + // EarlyCtxError + // loadFromCache error - this can only be a type error from Response.DataItemType(), (*Service).readFromCache() + + runErrorCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "prediction_error_counter", + Help: "Number of client and kind of prediction errors.", + }, + []string{"model", "error"}, + ) + + httpErrorCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "http_error_counter", + Help: "Number of client HTTP errors.", + }, + []string{"model", "error"}, + ) + + httpClientErrorCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "http_client_error_counter", + Help: "Number of client HTTP client errors.", + }, + []string{"model", "error"}, + ) + + // batch size + + batchSizeSummary = prometheus.NewSummaryVec( + 
prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "batch_size_summary", + Help: "Size of client batches.", + Objectives: buckets.CommonSummaryObjectives, + }, + []string{"model"}, + ) + + batchSizeHistogram = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "batch_size_histogram", + Help: "Size of client batches.", + Buckets: buckets.MicrosecondBuckets, + }, + []string{"model"}, + ) +) + +type prometheusMetrics struct { + runDurationHistogram prometheus.Observer + batchSizeHistogram prometheus.Observer + httpDurationHistogram prometheus.Observer + httpClientDurationHistogram prometheus.Observer + + // The Summary metrics may be nil if noPrometheusSummaries is true. + + runDurationSummary prometheus.Observer + batchSizeSummary prometheus.Observer + httpDurationSummary prometheus.Observer + httpClientDurationSummary prometheus.Observer + + runErrorEarlyCtxCounter prometheus.Counter + runBaseErrorCounters promc.BaseErrorCounters + + httpDownCounter prometheus.Counter + httpBaseErrorCounters promc.BaseErrorCounters + + httpClientBaseErrorCounters promc.BaseErrorCounters +} + +func (m prometheusMetrics) observeRunDuration(duration float64) { + m.runDurationHistogram.Observe(duration) + if m.runDurationSummary != nil { + m.runDurationSummary.Observe(duration) + } +} + +func (m prometheusMetrics) observeBatchSize(batchSize float64) { + m.batchSizeHistogram.Observe(batchSize) + if m.batchSizeSummary != nil { + m.batchSizeSummary.Observe(batchSize) + } +} + +func (m prometheusMetrics) observeHttpDuration(duration float64) { + m.httpDurationHistogram.Observe(duration) + if m.httpDurationSummary != nil { + m.httpDurationSummary.Observe(duration) + } +} + +func (m prometheusMetrics) observeHttpClientDuration(duration float64) { + m.httpClientDurationHistogram.Observe(duration) + if m.httpClientDurationSummary != nil { + m.httpClientDurationSummary.Observe(duration) + } +} + +// Used strictly to 
test for error type. +var are prometheus.AlreadyRegisteredError + +func isPrometheusAlreadyRegisteredError(err error) bool { + if err == nil { + return false + } + + return errors.As(err, &are) +} + +func (m *prometheusMetrics) registerPrometheusMetrics(registerer prometheus.Registerer, model string, noPrometheusSummaries bool) error { + // convenience function + register := func(metric prometheus.Collector) error { + err := registerer.Register(metric) + if err != nil && !isPrometheusAlreadyRegisteredError(err) { + + dc := make(chan *prometheus.Desc) + go func() { + metric.Describe(dc) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + var metricName string + select { + case desc := <-dc: + metricName = desc.String() + case <-ctx.Done(): + metricName = "unknown" + } + + return fmt.Errorf("failed to register %s: %T, %w", metricName, err, err) + } + + return nil + } + + var err error + if !noPrometheusSummaries { + err = register(runDurationSummaryMicros) + if err != nil { + return err + } + m.runDurationSummary = runDurationSummaryMicros.WithLabelValues(model) + + err = register(batchSizeSummary) + if err != nil { + return err + } + m.batchSizeSummary = batchSizeSummary.WithLabelValues(model) + + err = register(httpDurationSummaryMicros) + if err != nil { + return err + } + m.httpDurationSummary = httpDurationSummaryMicros.WithLabelValues(model) + + err = register(httpClientDurationSummaryMicros) + if err != nil { + return err + } + m.httpClientDurationSummary = httpClientDurationSummaryMicros.WithLabelValues(model) + } + + err = register(runDurationHistogramMicros) + if err != nil { + return err + } + m.runDurationHistogram = runDurationHistogramMicros.WithLabelValues(model) + + err = register(batchSizeHistogram) + if err != nil { + return err + } + m.batchSizeHistogram = batchSizeHistogram.WithLabelValues(model) + + err = register(httpDurationHistogramMicros) + if err != nil { + return err + } + 
m.httpDurationHistogram = httpDurationHistogramMicros.WithLabelValues(model) + + err = register(httpClientDurationHistogramMicros) + if err != nil { + return err + } + m.httpClientDurationHistogram = httpClientDurationHistogramMicros.WithLabelValues(model) + + err = register(runErrorCounter) + if err != nil { + return err + } + + m.runErrorEarlyCtxCounter = runErrorCounter.WithLabelValues(model, "earlyCtx") + + // mkBECs binds counter's per-model "error", "deadlineExceeded" and "canceled" children to bec. + mkBECs := func(bec *promc.BaseErrorCounters, counter *prometheus.CounterVec) { + bec.OtherErrorCounter = counter.WithLabelValues(model, "error") + bec.DeadlineExceededCounter = counter.WithLabelValues(model, "deadlineExceeded") + bec.CanceledCounter = counter.WithLabelValues(model, "canceled") + } + + mkBECs(&m.runBaseErrorCounters, runErrorCounter) + + err = register(httpErrorCounter) + if err != nil { + return err + } + mkBECs(&m.httpBaseErrorCounters, httpErrorCounter) + + m.httpDownCounter = httpErrorCounter.WithLabelValues(model, "down") + + err = register(httpClientErrorCounter) + if err != nil { + return err + } + mkBECs(&m.httpClientBaseErrorCounters, httpClientErrorCounter) + + return nil +} diff --git a/shared/client/service.go b/shared/client/service.go index 1b72425..80ec461 100644 --- a/shared/client/service.go +++ b/shared/client/service.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "encoding/json" + "errors" "fmt" "io" "log" @@ -18,6 +19,7 @@ import ( "time" "github.com/francoispqt/gojay" + "github.com/prometheus/client_golang/prometheus" "github.com/viant/gmetric" "github.com/viant/mly/shared/client/config" "github.com/viant/mly/shared/common" @@ -66,6 +68,17 @@ type Service struct { httpCliCounter *gmetric.Operation dictCounter *gmetric.Operation + // PrometheusRegisterer is used to register Prometheus metrics. + // If not provided, the default Prometheus registry will be used. + PrometheusRegisterer prometheus.Registerer + + // noPrometheusSummaries is used to disable Prometheus summaries.
+ // If true, only histograms will be registered and used. + // See https://prometheus.io/docs/practices/histograms for guidance. + noPrometheusSummaries bool + + prometheusMetrics prometheusMetrics + ErrorHistory tracker.Tracker } @@ -80,15 +93,21 @@ func (s *Service) NewMessage() *Message { // input can vary in types, but if it is an instance of Cachable, then the configured // caching system will be used. func (s *Service) Run(ctx context.Context, input interface{}, response *Response) error { - onDone := s.counter.Begin(time.Now()) + startTime := time.Now() + onDone := s.counter.Begin(startTime) stats := stat.NewValues() + defer func() { onDone(time.Now(), *stats...) + + duration := time.Since(startTime).Microseconds() + s.prometheusMetrics.observeRunDuration(float64(duration)) s.releaseMessage(input) }() if ctx.Err() != nil { stats.Append(stat.EarlyCtxError) + s.prometheusMetrics.runErrorEarlyCtxCounter.Inc() } if response.Data == nil { @@ -107,6 +126,8 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response cachedCount, err = s.loadFromCache(ctx, &cached, batchSize, response, cachable) if err != nil { stats.AppendError(err) + s.prometheusMetrics.runBaseErrorCounters.Observe(err) + if ctx.Err() == nil && s.ErrorHistory != nil { go s.ErrorHistory.AddBytes([]byte(err.Error())) } @@ -123,6 +144,8 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response s.reportBatch(cachedCount, cached) } + s.prometheusMetrics.observeBatchSize(float64(batchSize)) + if (batchSize > 0 && cachedCount == batchSize) || (batchSize == 0 && cachedCount > 0) { response.Status = common.StatusCached return s.handleResponse(ctx, response.Data, cached, cachable) @@ -131,6 +154,7 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response data, err := Marshal(input, modelName) if err != nil { stats.AppendError(err) + s.prometheusMetrics.runBaseErrorCounters.Observe(err) return err } @@ -140,13 +164,18 @@ func (s 
*Service) Run(ctx context.Context, input interface{}, response *Response } body, err := func() ([]byte, error) { - httpOnDone := s.httpCounter.Begin(time.Now()) + startTime := time.Now() + httpOnDone := s.httpCounter.Begin(startTime) httpStats := stat.NewValues() - od := metric.EnterThenExit(s.httpCounter, time.Now(), stat.Enter, stat.Exit) + od := metric.EnterThenExit(s.httpCounter, startTime, stat.Enter, stat.Exit) defer func() { httpOnDone(time.Now(), httpStats.Values()...) + + duration := time.Since(startTime).Microseconds() + s.prometheusMetrics.observeHttpDuration(float64(duration)) + od() }() @@ -158,6 +187,7 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response if err != nil { httpStats.AppendError(err) + s.prometheusMetrics.httpBaseErrorCounters.Observe(err) } return body, err @@ -165,6 +195,7 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response if err != nil { stats.AppendError(err) + s.prometheusMetrics.runBaseErrorCounters.Observe(err) if ctx.Err() == nil && s.ErrorHistory != nil { go s.ErrorHistory.AddBytes([]byte(err.Error())) } @@ -175,6 +206,7 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response err = gojay.Unmarshal(body, response) if err != nil { stats.AppendError(err) + s.prometheusMetrics.runBaseErrorCounters.Observe(err) return fmt.Errorf("failed to unmarshal: '%s'; due to %w", body, err) } @@ -188,6 +220,7 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response if err = s.handleResponse(ctx, response.Data, cached, cachable); err != nil { stats.AppendError(err) + s.prometheusMetrics.runBaseErrorCounters.Observe(err) return fmt.Errorf("failed to handle resp: %w", err) } @@ -224,6 +257,7 @@ func (s *Service) loadFromCache(ctx context.Context, cached *[]interface{}, batc response.Status = common.StatusCached response.DictHash = dictHash } + return cachedCount, nil } @@ -257,6 +291,7 @@ func (s *Service) readFromCacheInBatch(ctx 
context.Context, batchSize int, dataT return cachedCount, err } +// readFromCache will return an error if target is not a pointer. func (s *Service) readFromCache(ctx context.Context, key string, target interface{}) (bool, int, error) { if s.datastore == nil || !s.datastore.Enabled() { return false, 0, nil @@ -264,7 +299,7 @@ func (s *Service) readFromCache(ctx context.Context, key string, target interfac dataType := reflect.TypeOf(target) if dataType.Kind() != reflect.Ptr { - return false, 0, fmt.Errorf("invalid response data type: expeted ptr but had: %T", target) + return false, 0, fmt.Errorf("invalid response data type: expected reflect.Ptr but had: %T", target) } storeKey := s.datastore.Key(key) @@ -292,6 +327,15 @@ func (s *Service) dictionary() *Dictionary { return dict } +func (s *Service) registerPrometheusMetrics() error { + pr := prometheus.DefaultRegisterer + if s.PrometheusRegisterer != nil { + pr = s.PrometheusRegisterer + } + + return s.prometheusMetrics.registerPrometheusMetrics(pr, s.Model, s.noPrometheusSummaries) +} + func (s *Service) init() error { if s.gmetrics == nil { s.gmetrics = gmetric.New() @@ -303,6 +347,11 @@ func (s *Service) init() error { s.httpCliCounter = s.gmetrics.MultiOperationCounter(location, s.Model+"ClientHTTPCli", s.Model+" client HTTP client performance", time.Microsecond, time.Minute, 2, stat.NewCtxErrOnly()) s.dictCounter = s.gmetrics.MultiOperationCounter(location, s.Model+"ClientDict", s.Model+" client dictionary performance", time.Microsecond, time.Minute, 1, stat.ErrorOnly()) + err := s.registerPrometheusMetrics() + if err != nil { + return fmt.Errorf("failed to register Prometheus metrics: %w", err) + } + if s.ErrorHistory == nil { s.ErrorHistory = mg.NewK(20) } @@ -311,7 +360,7 @@ func (s *Service) init() error { s.Config.MaxRetry = 3 } - err := s.initHTTPClient() + err = s.initHTTPClient() if err != nil { return err } @@ -604,6 +653,11 @@ func (s *Service) postRequest(ctx context.Context, data []byte, mvt 
*stat.Values // TODO per-host counters host, err := s.getHost() if err != nil { + if errors.Is(err, common.ErrNodeDown) { + mvt.Append(stat.Down) + s.prometheusMetrics.httpDownCounter.Inc() + } + return nil, err } @@ -614,7 +668,10 @@ func (s *Service) postRequest(ctx context.Context, data []byte, mvt *stat.Values if s.Config.Debug { log.Printf("[%s postRequest] connection error:%s", s.Config.Model, err) } + mvt.Append(stat.Down) + s.prometheusMetrics.httpDownCounter.Inc() + host.FlagDown() } @@ -627,16 +684,21 @@ func (s *Service) httpPost(ctx context.Context, data []byte, host *Host) ([]byte var postErr error for i := 0; i < s.MaxRetry; i++ { data, err := func() ([]byte, error) { - onDone := s.httpCliCounter.Begin(time.Now()) + startTime := time.Now() + onDone := s.httpCliCounter.Begin(startTime) stats := stat.NewValues() defer func() { onDone(time.Now(), stats.Values()...) + + duration := time.Since(startTime).Microseconds() + s.prometheusMetrics.observeHttpClientDuration(float64(duration)) }() request, err := http.NewRequestWithContext(ctx, http.MethodPost, evalUrl, bytes.NewReader(data)) if err != nil { stats.AppendError(err) + s.prometheusMetrics.httpClientBaseErrorCounters.Observe(err) return nil, err } @@ -647,6 +709,7 @@ func (s *Service) httpPost(ctx context.Context, data []byte, host *Host) ([]byte if err != nil { stats.AppendError(err) + s.prometheusMetrics.httpClientBaseErrorCounters.Observe(err) return nil, err } diff --git a/shared/common/error.go b/shared/common/error.go index ba62b13..0039897 100644 --- a/shared/common/error.go +++ b/shared/common/error.go @@ -1,9 +1,10 @@ package common import ( + "strings" + "github.com/aerospike/aerospike-client-go/types" "github.com/pkg/errors" - "strings" ) const ( @@ -11,10 +12,10 @@ const ( connRefusedError = "refused" ) -//ErrNodeDown node down error +// ErrNodeDown node down error var ErrNodeDown = errors.New("node is down") -//IsKeyNotFound returns true if key not found error +// IsKeyNotFound returns 
true if key not found error func IsKeyNotFound(err error) bool { if err == nil { return false @@ -33,7 +34,7 @@ func IsKeyNotFound(err error) bool { return aeroError.ResultCode() == types.KEY_NOT_FOUND_ERROR } -//IsTimeout returns true if timeout error +// IsTimeout returns true if timeout error func IsTimeout(err error) bool { if err == nil { return false @@ -52,12 +53,13 @@ func IsTimeout(err error) bool { return aeroError.ResultCode() == types.TIMEOUT } -//IsTransientError returns if transient error +// IsTransientError returns true if transient error +// NOTE: This has an inverted dependency on Aerospike; the downstream implementation detail is not abstracted out. func IsTransientError(err error) bool { return IsKeyNotFound(err) || IsInvalidNode(err) || IsTimeout(err) || IsInvalidNode(err) || IsConnectionError(err) } -//IsInvalidNode returns true is node/cluster is down +// IsInvalidNode returns true if node/cluster is down func IsInvalidNode(err error) bool { if err == nil { return false @@ -79,7 +81,7 @@ func IsInvalidNode(err error) bool { return aeroError.ResultCode() == types.INVALID_NODE_ERROR } -//IsConnectionError returns true if error is connection errpr +// IsConnectionError returns true if error is connection error func IsConnectionError(err error) bool { if err == nil { return false diff --git a/shared/datastore/client/service.go b/shared/datastore/client/service.go index d8052b0..eecfd70 100644 --- a/shared/datastore/client/service.go +++ b/shared/datastore/client/service.go @@ -3,13 +3,16 @@ package client import ( "context" "fmt" + "reflect" "strings" "time" aero "github.com/aerospike/aerospike-client-go" + "github.com/viant/gmetric" "github.com/viant/mly/shared/circut" "github.com/viant/mly/shared/common" "github.com/viant/mly/shared/config/datastore" + "github.com/viant/mly/shared/stat" "golang.org/x/sync/singleflight" ) @@ -28,6 +31,10 @@ const ( type Service struct { Client Aero + gmOpGet *gmetric.Operation + gmOpPutReq *gmetric.Operation +
gmOpPutExec *gmetric.Operation + config *datastore.Connection // bypassConfiguredTimeout is used to bypass the configured timeout if using WithClientPolicy or WithBasePolicy. @@ -54,6 +61,12 @@ func (s *Service) Get(ctx context.Context, key *aero.Key, binNames ...string) (r return nil, common.ErrNodeDown } + onDone := s.gmOpGet.Begin(time.Now()) + stats := stat.NewValues() + defer func() { + onDone(time.Now(), stats.Values()...) + }() + defer func() { if r := recover(); r != nil { connection := s.config.ID @@ -63,6 +76,8 @@ func (s *Service) Get(ctx context.Context, key *aero.Key, binNames ...string) (r record, err = s.Client.Get(s.basePolicy, key, binNames...) s.checkConnectionError(err) + stats.AppendError(err) + return record, err } @@ -79,6 +94,12 @@ func (s *Service) Put(writePolicy *aero.WritePolicy, key *aero.Key, value aero.B keyStr := keyString(key) + onDone := s.gmOpPutReq.Begin(time.Now()) + stats := stat.NewValues() + defer func() { + onDone(time.Now(), stats.Values()...) + }() + defer func() { if r := recover(); r != nil { connection := s.config.ID @@ -97,8 +118,16 @@ func (s *Service) Put(writePolicy *aero.WritePolicy, key *aero.Key, value aero.B defer cancel() ch := s.group.DoChan(keyStr, func() (interface{}, error) { + onDone := s.gmOpPutExec.Begin(time.Now()) + stats := stat.NewValues() + defer func() { + onDone(time.Now(), stats.Values()...) + }() + err := s.Client.Put(writePolicy, key, value) s.checkConnectionError(err) + stats.AppendError(err) + return nil, err }) @@ -110,7 +139,7 @@ func (s *Service) Put(writePolicy *aero.WritePolicy, key *aero.Key, value aero.B err = fmt.Errorf("put aerospike[%s] key: %s shared: %v error: %w", s.config.ID, keyStr, res.Shared, res.Err) } } - + stats.AppendError(err) return err } @@ -189,9 +218,25 @@ func New(config *datastore.Connection) (*Service, error) { } func NewWithOptions(config *datastore.Connection, options ...Option) (*Service, error) { + return NewWithOptionsV2(config, nil, options...) 
+} + +func NewWithOptionsV2(config *datastore.Connection, gmetrics *gmetric.Service, options ...Option) (*Service, error) { + if gmetrics == nil { + gmetrics = gmetric.New() + } + + location := reflect.TypeOf(Service{}).PkgPath() + gmOpGet := gmetrics.MultiOperationCounter(location, config.ID+"AerospikeGet", config.ID+" get performance", time.Microsecond, time.Minute, 2, stat.NewCtxErrOnly()) + gmOpPutReq := gmetrics.MultiOperationCounter(location, config.ID+"AerospikePutRequested", config.ID+" put performance including singleflight", time.Microsecond, time.Minute, 2, stat.NewCtxErrOnly()) + gmOpPutExec := gmetrics.MultiOperationCounter(location, config.ID+"AerospikePutExecuted", config.ID+" put performance", time.Microsecond, time.Minute, 2, stat.NewCtxErrOnly()) + srv := &Service{ - config: config, - group: new(singleflight.Group), + config: config, + group: new(singleflight.Group), + gmOpGet: gmOpGet, + gmOpPutReq: gmOpPutReq, + gmOpPutExec: gmOpPutExec, } srv.init(options...) diff --git a/shared/datastore/service.go b/shared/datastore/service.go index 5ac6c1f..cfcb66c 100644 --- a/shared/datastore/service.go +++ b/shared/datastore/service.go @@ -263,7 +263,8 @@ func (s *Service) updateCache(keyString string, entryData EntryData, dictHash in return nil } -// reads from local +// readFromCache reads from local. +// This can return an error if the cache cannot be unmarshalled. func (s *Service) readFromCache(keyString string, value Value, stats *stat.Values) (CacheStatus, int, error) { data, _ := s.cache.Get(keyString) if len(data) == 0 { diff --git a/shared/datastore/stores.go b/shared/datastore/stores.go index 29f943b..18c8149 100644 --- a/shared/datastore/stores.go +++ b/shared/datastore/stores.go @@ -53,7 +53,7 @@ func NewStoresV4(cfg *config.DatastoreList, gmetrics *gmetric.Service, verbose b continue } - aero, err := client.NewWithOptions(connection, clientOptions...) + aero, err := client.NewWithOptionsV2(connection, gmetrics, clientOptions...) 
if err != nil { return nil, fmt.Errorf("failed to create client for %v, due to %w", connID, err) } diff --git a/shared/stat/buckets/prometheus.go b/shared/stat/buckets/prometheus.go index 24f2b64..f11e4f5 100644 --- a/shared/stat/buckets/prometheus.go +++ b/shared/stat/buckets/prometheus.go @@ -6,10 +6,12 @@ package buckets var MicrosecondBuckets []float64 = []float64{ 100, 500, + // 1 millisecond 1000, 2000, 3000, 5000, 7500, 10000, 20000, 30000, 50000, 75000, 100000, 200000, 400000, 800000, - 1000000, + // 1 second + 1000000, 2000000, } var MillisecondBuckets []float64 = []float64{ @@ -28,8 +30,9 @@ var SecondBuckets []float64 = []float64{ } var CommonSummaryObjectives = map[float64]float64{ - 0.5: 0.05, - 0.9: 0.01, - 0.95: 0.005, - 0.99: 0.001, + 0.5: 0.05, + 0.9: 0.01, + 0.95: 0.005, + 0.99: 0.001, + 0.999: 0.001, } diff --git a/shared/stat/promc/error.go b/shared/stat/promc/error.go new file mode 100644 index 0000000..7f1d81f --- /dev/null +++ b/shared/stat/promc/error.go @@ -0,0 +1,25 @@ +package promc + +import ( + "context" + "errors" + + "github.com/prometheus/client_golang/prometheus" +) + +type BaseErrorCounters struct { + DeadlineExceededCounter prometheus.Counter + CanceledCounter prometheus.Counter + + OtherErrorCounter prometheus.Counter +} + +func (c BaseErrorCounters) Observe(err error) { + if c.DeadlineExceededCounter != nil && errors.Is(err, context.DeadlineExceeded) { + c.DeadlineExceededCounter.Inc() + } else if c.CanceledCounter != nil && errors.Is(err, context.Canceled) { + c.CanceledCounter.Inc() + } else if c.OtherErrorCounter != nil { + c.OtherErrorCounter.Inc() + } +} From 3363823fa6bc9e166932065422ad8ae76973ddaf Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 26 Jan 2026 14:33:44 -0800 Subject: [PATCH 15/50] Allow CLI to have non-batch, empty string payloads. 
--- example/client/payload.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/example/client/payload.go b/example/client/payload.go index 27426a3..b3e504e 100644 --- a/example/client/payload.go +++ b/example/client/payload.go @@ -70,11 +70,19 @@ func Parse(p string, cp *CliPayload) error { for _, chunk := range chunks { def := strings.Split(chunk, ":") if len(def) != 2 { - return fmt.Errorf("chunk \"%s\" missing or has more than one \":\"", chunk) + return fmt.Errorf("chunk \"%s\" must contain exactly one \":\"", chunk) + } + + valStr := def[1] + var vals []string + var err error + if valStr == "" { + vals = []string{""} + } else { + vals, err = csv.NewReader(strings.NewReader(valStr)).Read() } field := def[0] - vals, err := csv.NewReader(strings.NewReader(def[1])).Read() if err != nil { return fmt.Errorf("csv error for field %s: %v", field, err) } From 034b7b908927614ee30c0cd4fd7aa5c3e79d9ff2 Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 26 Jan 2026 16:20:40 -0800 Subject: [PATCH 16/50] Clean up client Prometheus setup, better buckets for batch sizes, add export Prometheus options to CLI. --- example/client/option.go | 7 +- example/client/runner.go | 17 ++++ shared/client/prometheus.go | 185 ++++++++++++++++++------------------ 3 files changed, 115 insertions(+), 94 deletions(-) diff --git a/example/client/option.go b/example/client/option.go index a28d251..3adea75 100644 --- a/example/client/option.go +++ b/example/client/option.go @@ -37,8 +37,11 @@ type Options struct { SkipError bool `long:"skiperrs"` - NoOutput bool `long:"noout"` - Metrics bool `long:"metrics"` + // NoOutput suppresses model outputs. + NoOutput bool `long:"noout"` + + Metrics bool `long:"metrics" description:"print gmetric metrics"` + Prometheus bool `long:"prometheus" description:"print prometheus metrics"` ErrorHistory bool `long:"errhist"` // Report forces NoOutput and SkipError true, Metrics and ErrorHistory false.
diff --git a/example/client/runner.go b/example/client/runner.go index 463ff33..662f46b 100644 --- a/example/client/runner.go +++ b/example/client/runner.go @@ -4,9 +4,12 @@ import ( "context" "fmt" "log" + "os" "sync" "time" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/common/expfmt" "github.com/viant/gmetric" "github.com/viant/mly/service/endpoint/checker" "github.com/viant/mly/shared/client" @@ -190,6 +193,20 @@ func RunWithOptions(runOpts *Options) error { } } + if runOpts.Prometheus { + mfs, err := prometheus.DefaultGatherer.Gather() + if err != nil { + return fmt.Errorf("failed to gather prometheus metrics: %w", err) + } + + encoder := expfmt.NewEncoder(os.Stdout, expfmt.FmtText) + for _, mf := range mfs { + if err := encoder.Encode(mf); err != nil { + return fmt.Errorf("failed to encode metric family %s: %w", mf.GetName(), err) + } + } + } + if runOpts.Report { toolbox.Dump(report) } diff --git a/shared/client/prometheus.go b/shared/client/prometheus.go index f5da0a7..8048ece 100644 --- a/shared/client/prometheus.go +++ b/shared/client/prometheus.go @@ -11,75 +11,14 @@ import ( "github.com/viant/mly/shared/stat/promc" ) -var ( - // end-to-end duration - - runDurationSummaryMicros = prometheus.NewSummaryVec( - prometheus.SummaryOpts{ - Namespace: "mly", - Subsystem: "client", - Name: "run_duration_summary_us", - Help: "Duration of client Run calls.", - Objectives: buckets.CommonSummaryObjectives, - }, - []string{"model"}, - ) - - runDurationHistogramMicros = prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Namespace: "mly", - Subsystem: "client", - Name: "run_duration_histogram_us", - Help: "Duration of client Run calls.", - Buckets: buckets.MicrosecondBuckets, - }, - []string{"model"}, - ) - - httpDurationSummaryMicros = prometheus.NewSummaryVec( - prometheus.SummaryOpts{ - Namespace: "mly", - Subsystem: "client", - Name: "http_duration_summary_us", - Help: "Duration of client HTTP calls including retries.", - 
Objectives: buckets.CommonSummaryObjectives, - }, - []string{"model"}, - ) - - httpDurationHistogramMicros = prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Namespace: "mly", - Subsystem: "client", - Name: "http_duration_histogram_us", - Help: "Duration of client HTTP calls including retries.", - Buckets: buckets.MicrosecondBuckets, - }, - []string{"model"}, - ) - - httpClientDurationSummaryMicros = prometheus.NewSummaryVec( - prometheus.SummaryOpts{ - Namespace: "mly", - Subsystem: "client", - Name: "http_client_duration_summary_us", - Help: "Duration of client HTTP client calls.", - Objectives: buckets.CommonSummaryObjectives, - }, - []string{"model"}, - ) - - httpClientDurationHistogramMicros = prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Namespace: "mly", - Subsystem: "client", - Name: "http_client_duration_histogram_us", - Help: "Duration of client HTTP client calls.", - Buckets: buckets.MicrosecondBuckets, - }, - []string{"model"}, - ) +const ( + promDescRunDuration = "Duration of client Run calls." + promDescHTTPDuration = "Duration of client HTTP calls, including retries." + promDescHTTPClientDuration = "Duration of client HTTP client calls." + promDescBatchSize = "Size of client batches." 
+) +var ( // EarlyCtxError // loadFromCache error - this can only be a type error from Response.DataItemType(), (*Service).readFromCache() @@ -112,30 +51,6 @@ var ( }, []string{"model", "error"}, ) - - // batch size - - batchSizeSummary = prometheus.NewSummaryVec( - prometheus.SummaryOpts{ - Namespace: "mly", - Subsystem: "client", - Name: "batch_size_summary", - Help: "Size of client batches.", - Objectives: buckets.CommonSummaryObjectives, - }, - []string{"model"}, - ) - - batchSizeHistogram = prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Namespace: "mly", - Subsystem: "client", - Name: "batch_size_histogram", - Help: "Size of client batches.", - Buckets: buckets.MicrosecondBuckets, - }, - []string{"model"}, - ) ) type prometheusMetrics struct { @@ -228,24 +143,68 @@ func (m *prometheusMetrics) registerPrometheusMetrics(registerer prometheus.Regi var err error if !noPrometheusSummaries { + runDurationSummaryMicros := prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "run_duration_summary_us", + Help: promDescRunDuration, + Objectives: buckets.CommonSummaryObjectives, + }, + []string{"model"}, + ) + err = register(runDurationSummaryMicros) if err != nil { return err } m.runDurationSummary = runDurationSummaryMicros.WithLabelValues(model) + batchSizeSummary := prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "batch_size_summary", + Help: promDescBatchSize, + Objectives: buckets.CommonSummaryObjectives, + }, + []string{"model"}, + ) + err = register(batchSizeSummary) if err != nil { return err } m.batchSizeSummary = batchSizeSummary.WithLabelValues(model) + httpDurationSummaryMicros := prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "http_duration_summary_us", + Help: promDescHTTPDuration, + Objectives: buckets.CommonSummaryObjectives, + }, + []string{"model"}, + ) + err = 
register(httpDurationSummaryMicros) if err != nil { return err } m.httpDurationSummary = httpDurationSummaryMicros.WithLabelValues(model) + httpClientDurationSummaryMicros := prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "http_client_duration_summary_us", + Help: promDescHTTPClientDuration, + Objectives: buckets.CommonSummaryObjectives, + }, + []string{"model"}, + ) + err = register(httpClientDurationSummaryMicros) if err != nil { return err @@ -253,24 +212,66 @@ func (m *prometheusMetrics) registerPrometheusMetrics(registerer prometheus.Regi m.httpClientDurationSummary = httpClientDurationSummaryMicros.WithLabelValues(model) } + runDurationHistogramMicros := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "run_duration_histogram_us", + Help: promDescRunDuration, + Buckets: buckets.MicrosecondBuckets, + }, + []string{"model"}, + ) + err = register(runDurationHistogramMicros) if err != nil { return err } m.runDurationHistogram = runDurationHistogramMicros.WithLabelValues(model) + batchSizeHistogram := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "batch_size_histogram", + Help: promDescBatchSize, + Buckets: []float64{1, 2, 3, 4, 5, 7, 10, 12, 15, 20, 25, 30, 40, 50, 60, 70, 80, 90, 100}, + }, + []string{"model"}, + ) err = register(batchSizeHistogram) if err != nil { return err } m.batchSizeHistogram = batchSizeHistogram.WithLabelValues(model) + httpDurationHistogramMicros := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "http_duration_histogram_us", + Help: promDescHTTPDuration, + Buckets: buckets.MicrosecondBuckets, + }, + []string{"model"}, + ) err = register(httpDurationHistogramMicros) if err != nil { return err } m.httpDurationHistogram = httpDurationHistogramMicros.WithLabelValues(model) + httpClientDurationHistogramMicros := 
prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "http_client_duration_histogram_us", + Help: promDescHTTPClientDuration, + Buckets: buckets.MicrosecondBuckets, + }, + []string{"model"}, + ) + err = register(httpClientDurationHistogramMicros) if err != nil { return err From 540d5b2765827d80bd36d8141f0b7499e7c0921c Mon Sep 17 00:00:00 2001 From: David Choi Date: Wed, 28 Jan 2026 12:43:08 -0800 Subject: [PATCH 17/50] Add support for batching in router. --- service/config/router.go | 12 +- service/platform/evaluator.go | 2 +- .../platform/router/batchkey_bench_test.go | 137 ++++ service/platform/router/prometheus.go | 12 - service/platform/router/reload.go | 513 ++++++++++++ service/platform/router/router.go | 728 ++++-------------- service/platform/router/router_test.go | 337 ++++++++ service/platform/router/worker.go | 63 -- service/request/shape/batch.go | 96 ++- service/request/shape/batch_test.go | 129 ++++ 10 files changed, 1366 insertions(+), 663 deletions(-) create mode 100644 service/platform/router/batchkey_bench_test.go create mode 100644 service/platform/router/reload.go delete mode 100644 service/platform/router/worker.go diff --git a/service/config/router.go b/service/config/router.go index 9b2cdb2..062d2ad 100644 --- a/service/config/router.go +++ b/service/config/router.go @@ -9,15 +9,17 @@ type RouterConfig struct { // Required name of the input that will route the request to the backend. InputName string `json:",omitempty" yaml:",omitempty"` - // Unimplemented. - // If true, the router will batch the requests to the backend. - BatchBackend bool `json:",omitempty" yaml:",omitempty"` + // ForceBatchSize1 controls whether the router sends individual samples or batches by model. + // When false (default), requests within a single Predict() call that route to the + // same model evaluator are grouped into a single batched prediction call. 
+ // When true, each sample is sent as an individual prediction request with batch size 1. + ForceBatchSize1 bool `json:",omitempty" yaml:",omitempty"` - // The maximum number of concurrent requests to the backend. + // The maximum number of concurrent batches dispatched to model evaluators. // Defaults to 50. Workers int `json:",omitempty" yaml:",omitempty"` - // The maximum number of requests to queue. + // The maximum number of batches to queue before rejecting. // Defaults to 1000. MaxQueueSize int `json:",omitempty" yaml:",omitempty"` diff --git a/service/platform/evaluator.go b/service/platform/evaluator.go index d00d5b6..0c88ca7 100644 --- a/service/platform/evaluator.go +++ b/service/platform/evaluator.go @@ -30,7 +30,7 @@ type PlatformEvaluator interface { // This will be invoked after at least 1 ReloadIfNeeded() succeeds. Inputs() map[string]*domain.Input - // Stats returns platform-specific live metrics, for debugging + // Deprecated: Do not use or implement. Stats returns platform-specific live metrics, for debugging purposes Stats(stats map[string]interface{}) Close() error diff --git a/service/platform/router/batchkey_bench_test.go b/service/platform/router/batchkey_bench_test.go new file mode 100644 index 0000000..41407d7 --- /dev/null +++ b/service/platform/router/batchkey_bench_test.go @@ -0,0 +1,137 @@ +package router + +import ( + "fmt" + "strconv" + "testing" +) + +// Benchmark different approaches to generating batch keys +// These benchmarks compare key generation strategies for the ForceBatchSize1 path + +var ( + sampleModelName = "model_name_example" + sampleOffset = 12345 + result string // prevent compiler optimization +) + +func BenchmarkBatchKey_Sprintf(b *testing.B) { + var r string + for i := 0; i < b.N; i++ { + r = fmt.Sprintf("%s#%d", sampleModelName, sampleOffset) + } + result = r +} + +func BenchmarkBatchKey_StrconcatStrconv(b *testing.B) { + var r string + for i := 0; i < b.N; i++ { + r = sampleModelName + "#" + 
strconv.Itoa(sampleOffset) + } + result = r +} + +func BenchmarkBatchKey_StrconcatFormatInt(b *testing.B) { + var r string + for i := 0; i < b.N; i++ { + r = sampleModelName + "#" + strconv.FormatInt(int64(sampleOffset), 10) + } + result = r +} + +// Benchmark with varying model name lengths +func BenchmarkBatchKey_ShortName_Sprintf(b *testing.B) { + name := "m1" + var r string + for i := 0; i < b.N; i++ { + r = fmt.Sprintf("%s#%d", name, i%1000) + } + result = r +} + +func BenchmarkBatchKey_ShortName_Strconcat(b *testing.B) { + name := "m1" + var r string + for i := 0; i < b.N; i++ { + r = name + "#" + strconv.Itoa(i%1000) + } + result = r +} + +func BenchmarkBatchKey_LongName_Sprintf(b *testing.B) { + name := "very_long_model_name_with_many_characters_for_testing" + var r string + for i := 0; i < b.N; i++ { + r = fmt.Sprintf("%s#%d", name, i%1000) + } + result = r +} + +func BenchmarkBatchKey_LongName_Strconcat(b *testing.B) { + name := "very_long_model_name_with_many_characters_for_testing" + var r string + for i := 0; i < b.N; i++ { + r = name + "#" + strconv.Itoa(i%1000) + } + result = r +} + +// Benchmark the map lookup with generated keys (more realistic scenario) +func BenchmarkBatchKey_MapLookup_Sprintf(b *testing.B) { + m := make(map[string]int) + names := []string{"model1", "model2", "model3", "model4", "model5"} + + // Pre-populate map + for _, name := range names { + for j := 0; j < 100; j++ { + m[fmt.Sprintf("%s#%d", name, j)] = j + } + } + + b.ResetTimer() + var sum int + for i := 0; i < b.N; i++ { + key := fmt.Sprintf("%s#%d", names[i%5], i%100) + sum += m[key] + } + _ = sum +} + +func BenchmarkBatchKey_MapLookup_Strconcat(b *testing.B) { + m := make(map[string]int) + names := []string{"model1", "model2", "model3", "model4", "model5"} + + // Pre-populate map + for _, name := range names { + for j := 0; j < 100; j++ { + m[name+"#"+strconv.Itoa(j)] = j + } + } + + b.ResetTimer() + var sum int + for i := 0; i < b.N; i++ { + key := names[i%5] + "#" + 
strconv.Itoa(i%100) + sum += m[key] + } + _ = sum +} + +// Benchmark allocations +func BenchmarkBatchKey_Allocs_Sprintf(b *testing.B) { + b.ReportAllocs() + var r string + for i := 0; i < b.N; i++ { + r = fmt.Sprintf("%s#%d", sampleModelName, i%1000) + } + result = r +} + +func BenchmarkBatchKey_Allocs_Strconcat(b *testing.B) { + b.ReportAllocs() + var r string + for i := 0; i < b.N; i++ { + r = sampleModelName + "#" + strconv.Itoa(i%1000) + } + result = r +} diff --git a/service/platform/router/prometheus.go b/service/platform/router/prometheus.go index 9809a5d..8fa9538 100644 --- a/service/platform/router/prometheus.go +++ b/service/platform/router/prometheus.go @@ -28,17 +28,6 @@ var ( []string{"router", "fixed_only"}, ) - routerWorkerChannelQueuedSummary = prometheus.NewSummaryVec( - prometheus.SummaryOpts{ - Namespace: "mly", - Subsystem: "router", - Name: "worker_channel_queued_summary", - Help: "Number of router predictions queued in the worker channel.", - Objectives: buckets.CommonSummaryObjectives, - }, - []string{"router"}, - ) - routerPredictDroppedCounter = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: "mly", @@ -65,5 +54,4 @@ func init() { prometheus.MustRegister(routerReloadDurationMicrosSummary) prometheus.MustRegister(routerModelUnloadGauge) prometheus.MustRegister(routerPredictDroppedCounter) - prometheus.MustRegister(routerWorkerChannelQueuedSummary) } diff --git a/service/platform/router/reload.go b/service/platform/router/reload.go new file mode 100644 index 0000000..2d2ba4a --- /dev/null +++ b/service/platform/router/reload.go @@ -0,0 +1,513 @@ +package router + +import ( + "compress/gzip" + "context" + "encoding/json" + "fmt" + "io" + "log" + "reflect" + "strings" + "sync" + "time" + + "github.com/viant/mly/service/config" + "github.com/viant/mly/service/domain" + "github.com/viant/mly/service/files" + "github.com/viant/mly/service/platform" + "github.com/viant/mly/shared/config/router" + "gopkg.in/yaml.v2" +) + +type 
modelSignature struct { + name string + signature *domain.Signature +} + +func (r *Router) ReloadIfNeeded(ctx context.Context) error { + start := time.Now() + isFullReload := false + defer func() { + var mode string + if isFullReload { + mode = "full" + } else { + mode = "checks" + } + routerReloadDurationMicrosSummary.WithLabelValues(r.routerName, mode).Observe(float64(time.Since(start).Microseconds())) + }() + + // fetch and check router configuration file + snapshot, err := files.ModifiedSnapshot(ctx, r.fs, r.configURL, nil) + if err != nil { + return fmt.Errorf("failed to check router configuration file: %w", err) + } + + if !r.isModified(snapshot) { + // check health of all underlying models + var wg sync.WaitGroup + + r.configLock.RLock() + errChannels := len(r.routingTable) + if r.globalModel != nil { + errChannels++ + } + + errCh := make(chan error, errChannels) + + if r.globalModel != nil { + wg.Add(1) + go func() { + defer wg.Done() + err := r.globalModel.ReloadIfNeeded(ctx) + if err != nil { + errCh <- fmt.Errorf("failed to reload global model: %w", err) + } + }() + } + + for m, p := range r.routingTable { + wg.Add(1) + go func(m string, p platform.PlatformEvaluator) { + defer wg.Done() + err := p.ReloadIfNeeded(ctx) + if err != nil { + errCh <- fmt.Errorf("failed to reload model %s: %w", m, err) + } + }(m, p) + } + + wg.Wait() + close(errCh) + + if len(errCh) > 0 { + var errStrings []string + for err := range errCh { + errStrings = append(errStrings, err.Error()) + } + + err = fmt.Errorf("reloading errors: %s", strings.Join(errStrings, "; ")) + } + + r.configLock.RUnlock() + return err + } + + isFullReload = true + + // otherwise just abandon the routing table status checks + + r.configLock.Lock() + defer r.configLock.Unlock() + + r.configModified = snapshot + + // load router configuration file + rawReader, err := r.fs.OpenURL(ctx, r.configURL) + if err != nil { + return fmt.Errorf("failed to open router configuration file: %w", err) + } + + defer 
rawReader.Close() + var reader io.Reader = rawReader + if strings.HasSuffix(r.configURL, ".gz") { + if reader, err = gzip.NewReader(rawReader); err != nil { + return fmt.Errorf("failed to create gzip reader for router configuration file: %w", err) + } + } + + newConfig := new(router.RoutingConfig) + + // TODO move this check earlier + if strings.Contains(r.configURL, ".yaml") { + decoder := yaml.NewDecoder(reader) + err = decoder.Decode(newConfig) + } else if strings.Contains(r.configURL, ".json") { + err = json.NewDecoder(reader).Decode(newConfig) + } else { + return fmt.Errorf("unsupported router configuration file type: %s", r.configURL) + } + + if err != nil { + return fmt.Errorf("failed to decode router configuration file: %w", err) + } + + if err := r.applyRouterConfig(ctx, newConfig); err != nil { + return err + } + + return nil +} + +// applyRouterConfig will both update evaluators to new configuration state and verify and build the signature +func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.RoutingConfig) error { + modelsToUnload := make(map[string]struct{}) + reuseEvaluators := make(map[string]platform.PlatformEvaluator) + var reuseGlobal platform.PlatformEvaluator + + var finalSignature *domain.Signature + var oldConfig *router.RoutingConfig + func() { + r.routingTableLock.RLock() + defer r.routingTableLock.RUnlock() + if r.ioState != nil { + finalSignature = r.ioState.signature + } + + reuseGlobal = r.globalModel + oldConfig = r.routingConfig + }() + + if oldConfig != nil { + for _, entity := range oldConfig.EntityMapping { + modelsToUnload[entity.ModelName] = struct{}{} + if evaluator, ok := r.routingTable[entity.ModelName]; ok { + reuseEvaluators[entity.ModelName] = evaluator + } + } + + if oldConfig.GlobalModelName != "" { + modelsToUnload[oldConfig.GlobalModelName] = struct{}{} + } + } + + newModelMapping := make(map[int]string) + for _, entity := range newConfig.EntityMapping { + r.debugLogf("add mapping: %d -> %s", 
entity.EntityID, entity.ModelName) + + newModelMapping[entity.EntityID] = entity.ModelName + delete(modelsToUnload, entity.ModelName) + } + + globalModelName := newConfig.GlobalModelName + if globalModelName == "" && r.hasGlobalModel { + return fmt.Errorf("global model name is missing") + } + + if globalModelName != "" { + r.debugLogf("global model: %s", globalModelName) + delete(modelsToUnload, globalModelName) + } + + newRoutingTable := make(map[string]platform.PlatformEvaluator) + for _, entity := range newConfig.EntityMapping { + model := entity.ModelName + if _, ok := newRoutingTable[model]; ok { + continue + } + + if evaluator, ok := reuseEvaluators[model]; ok { + newRoutingTable[model] = evaluator + continue + } + + evaluator, err := r.makeRoutedEvaluator(model) + + if err != nil { + return fmt.Errorf("failed to create Routed Evaluator for model %s: %w", model, err) + } + + newRoutingTable[model] = evaluator + } + + var globalEvaluator platform.PlatformEvaluator + if globalModelName != "" { + if oldConfig != nil && globalModelName == oldConfig.GlobalModelName && reuseGlobal != nil { + globalEvaluator = reuseGlobal + } else if evaluator, ok := newRoutingTable[globalModelName]; ok { + globalEvaluator = evaluator + } else if evaluator, ok := reuseEvaluators[globalModelName]; ok { + globalEvaluator = evaluator + } else { + var err error + globalEvaluator, err = r.makeRoutedEvaluator(globalModelName) + if err != nil { + return fmt.Errorf("failed to create Routed Evaluator for global model %s: %w", globalModelName, err) + } + } + } + + wg := sync.WaitGroup{} + + numWorkers := len(newRoutingTable) + if globalEvaluator != nil { + numWorkers++ + } + + errCh := make(chan error, numWorkers) + signatureCh := make(chan modelSignature, numWorkers) + + if globalEvaluator != nil { + wg.Add(1) + go func() { + defer wg.Done() + if err := globalEvaluator.ReloadIfNeeded(ctx); err != nil { + errCh <- fmt.Errorf("failed to reload global model %s: %w", globalModelName, err) + } + 
}() + } + + for model := range newRoutingTable { + wg.Add(1) + go func(model string) { + defer wg.Done() + + r.debugLogf("reload model: %s", model) + + modelEvaluator := newRoutingTable[model] + if err := modelEvaluator.ReloadIfNeeded(ctx); err != nil { + r.debugLogf("failed to reload model: %s: %v", model, err) + errCh <- fmt.Errorf("failed to reload model %s: %w", model, err) + } + + evalSig := modelEvaluator.Signature() + + if evalSig == nil { + errCh <- fmt.Errorf("model %s signature is nil", model) + return + } + + signatureCh <- modelSignature{ + name: model, + signature: evalSig, + } + }(model) + } + + r.debugLogf("wait for reloads") + + wg.Wait() + close(errCh) + close(signatureCh) + + if len(errCh) > 0 { + var errStrings []string + for err := range errCh { + errStrings = append(errStrings, err.Error()) + } + return fmt.Errorf("one or more model reloading errors: %s", strings.Join(errStrings, "; ")) + } + + sigInputMap := make(map[string]*domain.Input) + sigOutputMap := make(map[string]*domain.Output) + + // we only create ioState on the first reload + var ioState *IOState = new(IOState) + if finalSignature != nil { + for _, input := range finalSignature.Inputs { + sigInputMap[input.Name] = &input + } + + for _, output := range finalSignature.Outputs { + sigOutputMap[output.Name] = &output + } + } + + for signature := range signatureCh { + // accept first available signature as the final signature + if finalSignature == nil { + // in the creation of the signature, include the routing input + // DANGER: this uses the pointer to the signature, so since the signature is modified, the original signature will be modified! + // This doesn't happen in practice, but can cause issues in tests. 
+ finalSignature = signature.signature + + inputOffset := len(finalSignature.Inputs) + + routerInput := domain.Input{ + Name: r.routerInputFieldName, + Type: reflect.TypeOf(int64(0)), + } + + ioState.routerInputOffset = inputOffset + + finalSignature.Inputs = append(finalSignature.Inputs, routerInput) + + for _, input := range finalSignature.Inputs { + sigInputMap[input.Name] = &input + } + + for _, input := range r.configuredInputs { + _, ok := sigInputMap[input.Name] + + if ok { + // the input is configured and already in the self-reported signature + continue + } + + if !input.Auxiliary { + return fmt.Errorf("non-auxiliary input %s for model %s was not in model inputs", input.Name, signature.name) + } + + sigInputMap[input.Name] = &domain.Input{ + Name: input.Name, + Type: input.RawType(), + Auxiliary: input.Auxiliary, + } + } + + if r.modelOutputName != "" { + // also, add the selected model output + modelOutput := domain.Output{ + Name: r.modelOutputName, + Index: len(finalSignature.Outputs), + DataType: "string", + } + + finalSignature.Outputs = append(finalSignature.Outputs, modelOutput) + } + + for _, output := range finalSignature.Outputs { + sigOutputMap[output.Name] = &output + } + + continue + } + + thisSignature := signature.signature + // validate signature consistency + thisSignatureOutputMap := make(map[string]*domain.Output) + for _, output := range thisSignature.Outputs { + oldOutput, ok := sigOutputMap[output.Name] + if !ok { + return fmt.Errorf("signature output %s for model %s not found in the previous signature", output.Name, signature.name) + } + + thisSignatureOutputMap[output.Name] = &output + + // TODO permit this + if oldOutput.Index != output.Index { + return fmt.Errorf("signature output %s for model %s has index %d, and the previous signature has index %d", output.Name, signature.name, output.Index, oldOutput.Index) + } + + if oldOutput.DataType != output.DataType { + return fmt.Errorf("signature output %s for model %s has data type %s, 
and the previous signature has data type %s", output.Name, signature.name, output.DataType, oldOutput.DataType) + } + } + + for expectedOutput := range sigOutputMap { + if _, ok := thisSignatureOutputMap[expectedOutput]; !ok && expectedOutput != r.modelOutputName { + return fmt.Errorf("signature output %s for was not found in model %s signature", expectedOutput, signature.name) + } + } + + thisSignatureInputMap := make(map[string]*domain.Input) + for _, input := range thisSignature.Inputs { + oldInput, ok := sigInputMap[input.Name] + if !ok { + return fmt.Errorf("signature input %s for model %s not found in the previous signature", input.Name, signature.name) + } + + thisSignatureInputMap[input.Name] = &input + + if oldInput.Auxiliary { + continue + } + + // TODO permit this + if oldInput.Index != input.Index { + return fmt.Errorf("signature input %s for model %s has index %d, and the previous signature has index %d", input.Name, signature.name, input.Index, oldInput.Index) + } + + if !oldInput.Type.ConvertibleTo(input.Type) { + return fmt.Errorf("signature input %s for model %s has data type %s, and the previous signature has data type %s", input.Name, signature.name, input.Type.String(), oldInput.Type.String()) + } + } + + for expectedInput := range sigInputMap { + if _, ok := thisSignatureInputMap[expectedInput]; !ok && expectedInput != r.routerInputFieldName { + return fmt.Errorf("signature input %s for was not found in model %s signature", expectedInput, signature.name) + } + } + } + + if r.fixedEvaluatorFields != nil { + // TODO this is actually an acceptable case, but needs to be addressed elsewhere first before it is permitted + for field := range r.fixedEvaluatorFields { + if _, ok := sigOutputMap[field]; !ok { + return fmt.Errorf("fixed evaluator field: %s was not found in the signature outputs", field) + } + } + + for _, field := range sigOutputMap { + if _, ok := r.fixedEvaluatorFields[field.Name]; !ok && field.Name != r.modelOutputName { + return 
fmt.Errorf("signature output %s is not replaced", field.Name) + } + } + } + + ioState.signature = finalSignature + ioState.inputs = sigInputMap + + if globalEvaluator != nil { + if _, exists := newRoutingTable[globalModelName]; !exists { + newRoutingTable[globalModelName] = globalEvaluator + } + } + + func() { + r.routingTableLock.Lock() + defer r.routingTableLock.Unlock() + + r.routingConfig = newConfig + + r.routingMap = newModelMapping + r.routingTable = newRoutingTable + + r.globalModel = globalEvaluator + + if r.ioState == nil { + r.ioState = ioState + } + }() + + for model := range modelsToUnload { + routerModelUnloadGauge.WithLabelValues(r.routerName).Inc() + + go func(modelName string) { + defer routerModelUnloadGauge.WithLabelValues(r.routerName).Dec() + + ctxTo, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + r.debugLogf("request to unload model: %s", modelName) + + if err := r.unloadModel(ctxTo, modelName); err != nil { + r.debugLogf("failed to unload model %s: %v\n", modelName, err) + } + }(model) + } + + return nil +} + +func (r *Router) debugLogf(format string, args ...interface{}) { + if r.debug { + prefix := "[%s Router] " + log.Printf(prefix+format, append([]interface{}{r.routerName}, args...)...) + } +} + +// TODO refactor with service/tfmodel/service.isModified()? 
+func (r *Router) isModified(snapshot *config.Modified) bool { + if r.routingConfig == nil || r.configModified == nil { + return true + } + + if snapshot.Max.IsZero() { + return false + } + + r.configLock.RLock() + modified := r.configModified + r.configLock.RUnlock() + + return !(modified.Max.Equal(snapshot.Max) && modified.Min.Equal(snapshot.Min)) +} + +func (r *Router) unloadModel(ctx context.Context, modelName string) error { + defer routerModelUnloadGauge.WithLabelValues(r.routerName).Dec() + if err := r.unloader.UnloadModel(ctx, r.routerName, modelName); err != nil { + return fmt.Errorf("failed to unload model %s: %w", modelName, err) + } + return nil +} diff --git a/service/platform/router/router.go b/service/platform/router/router.go index d585d0f..bcb9a6f 100644 --- a/service/platform/router/router.go +++ b/service/platform/router/router.go @@ -1,28 +1,21 @@ package router import ( - "compress/gzip" "context" - "encoding/json" "fmt" - "io" - "log" - "reflect" - "strings" + "strconv" "sync" "time" "github.com/viant/afs" "github.com/viant/mly/service/config" "github.com/viant/mly/service/domain" - "github.com/viant/mly/service/files" "github.com/viant/mly/service/platform" "github.com/viant/mly/service/request/shape" tricli "github.com/viant/mly/service/triton" "github.com/viant/mly/shared" "github.com/viant/mly/shared/common" "github.com/viant/mly/shared/config/router" - "gopkg.in/yaml.v2" ) type IOState struct { @@ -75,8 +68,6 @@ type Router struct { fixedEvaluatorFields map[string]struct{} outputConfig config.OutputConfig - workCh chan *workRequest - modelOutputName string routerName string @@ -85,6 +76,13 @@ type Router struct { configuredInputs []*shared.Field ioState *IOState + + // forceBatchSize1 when true uses legacy per-sample dispatch; when false (default) uses batched dispatch + forceBatchSize1 bool + // workers limits concurrent model evaluations (used as semaphore capacity in batch mode) + workers int + // maxQueueSize limits queued batches 
before rejection + maxQueueSize int } // NewRouter creates a new Router instance. @@ -150,12 +148,10 @@ func newRouter(cfg *config.Model, fs afs.Service, unloaders map[string]UnloadSer configuredInputs: cfg.Inputs, routerInputFieldName: cfg.Router.InputName, - } - // spawn worker routines - r.workCh = make(chan *workRequest, cfg.Router.MaxQueueSize) - for i := 0; i < cfg.Router.Workers; i++ { - go handleWorkRequests(r.workCh, routerWorkerChannelQueuedSummary.WithLabelValues(r.routerName)) + forceBatchSize1: cfg.Router.ForceBatchSize1, + workers: cfg.Router.Workers, + maxQueueSize: cfg.Router.MaxQueueSize, } return r, nil @@ -166,8 +162,28 @@ type preparedReplacement struct { value interface{} } -// Predict performs model inference with the given parameters -// params is expected to be [numInputs]([batchSize][1]T) (see service/request.Request.Feeds) +// modelBatch holds accumulated rows destined for a single model evaluator +type modelBatch struct { + evaluator platform.Predictor + modelName string + inputs []interface{} // [numInputs][]interface{} - accumulated batched inputs + rowOffsets []int // original positions in the incoming batch +} + +// batchResult holds the result from a batched model prediction +type batchResult struct { + modelName string + results []interface{} + offsets []int + err error +} + +// Predict performs model inference with the given parameters. +// params is expected to be [numInputs]([batchSize][1]T) (see service/request.Request.Feeds). +// +// Rows are grouped into batches based on their target model evaluator. +// When ForceBatchSize1 is true, each row forms its own batch (batch size 1). +// When ForceBatchSize1 is false (default), rows destined for the same model are grouped together. 
func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { if len(params) == 0 { return nil, fmt.Errorf("no input parameters provided") @@ -182,7 +198,6 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface } else { fos = "false" } - routerPredictDurationMicrosSummary.WithLabelValues(r.routerName, fos).Observe(float64(time.Since(start).Microseconds())) }() @@ -191,74 +206,52 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface return nil, err } - errCh := make(chan error, expectedBatchSize) - resultsCh := make(chan offsetResults, expectedBatchSize) - - predictWaitGroup := sync.WaitGroup{} - predictWaitGroup.Add(expectedBatchSize) - var signature *domain.Signature + batches := make(map[string]*modelBatch) - err = func() error { - r.routingTableLock.RLock() - defer r.routingTableLock.RUnlock() + // Hold read lock to ensure evaluator references remain valid during prediction. + r.routingTableLock.RLock() + defer r.routingTableLock.RUnlock() + // Phase 1: Group rows into batches + // When forceBatchSize1 is true, each row gets a unique batch key (row offset as string) + // When false, rows are grouped by model name + err = func() error { if r.ioState == nil { return fmt.Errorf("ioState was not initialized") } - // this assignment isn't strictly required to be atomic as it should never change after initialization signature = r.ioState.signature routerInputOffset := r.ioState.routerInputOffset globalExists := r.fixedEvaluator != nil - reportedGlobalModelName := r.outputConfig.GlobalModelOverride noModelName := r.outputConfig.NoModelID numInputs := len(params) + numModelInputs := numInputs - 1 // exclude router input for batchOffset := range expectedBatchSize { - // 1 input is reserved for the router input - request := make([]interface{}, numInputs-1) - - var routingValueBatched interface{} - - // TODO support different input ordering per evaluator - see applyRouterConfig() 
regarding signatures - for inputOffset := range numInputs { - debatched, err := shape.Debatch(params[inputOffset], batchOffset) - if err != nil { - return fmt.Errorf("failed to debatch for row %d and input %d: %w", batchOffset, inputOffset, err) - } - - if inputOffset < routerInputOffset { - request[inputOffset] = debatched - } else if inputOffset == routerInputOffset { - routingValueBatched = debatched - } else { - request[inputOffset-1] = debatched - } + // Extract routing value for this row + routingValueBatched, err := shape.Debatch(params[routerInputOffset], batchOffset) + if err != nil { + return fmt.Errorf("failed to debatch routing value for row %d: %w", batchOffset, err) } routingValue, err := shape.SqueezeBatch(routingValueBatched) if err != nil { - return fmt.Errorf("failed to extract from batch for row %d: %w", batchOffset, err) + return fmt.Errorf("failed to extract routing value for row %d: %w", batchOffset, err) } - var ok bool = true var routingValueInt int - switch routingValue := routingValue.(type) { + switch rv := routingValue.(type) { case int: - routingValueInt = routingValue + routingValueInt = rv case int32: - routingValueInt = int(routingValue) + routingValueInt = int(rv) case int64: - routingValueInt = int(routingValue) + routingValueInt = int(rv) default: - ok = false - } - - if !ok { return fmt.Errorf("routing value is not an int: %v, is %T, for row %d", routingValue, routingValue, batchOffset) } @@ -268,10 +261,7 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface if !ok { if globalExists { metricFixedOnly = false - // fallback to global model evaluator = r.globalModel - - // override model name if reportedGlobalModelName != "" { routingValueString = reportedGlobalModelName } @@ -281,57 +271,127 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface } } else { metricFixedOnly = false - - var ok bool evaluator, ok = r.routingTable[routingValueString] if !ok { return 
fmt.Errorf("no evaluator found for routing value: %v", routingValue) } } - select { - case r.workCh <- &workRequest{ - wg: &predictWaitGroup, + // Determine batch key: unique per row when forceBatchSize1, otherwise by model name + batchKey := routingValueString + if r.forceBatchSize1 { + batchKey = routingValueString + "#" + strconv.Itoa(batchOffset) + } - predictor: evaluator, - ctx: ctx, - request: request, + // Get or create batch + batch, exists := batches[batchKey] + if !exists { + batch = &modelBatch{ + evaluator: evaluator, + modelName: routingValueString, + inputs: make([]interface{}, numModelInputs), + rowOffsets: make([]int, 0, 1), + } + batches[batchKey] = batch + } - queuedTime: time.Now(), - offset: batchOffset, - modelOutputEnabled: r.modelOutputName != "", - routingValueString: routingValueString, + // Append this row's inputs to the batch (excluding router input) + inputIdx := 0 + for paramOffset := range numInputs { + if paramOffset == routerInputOffset { + continue + } - responseCh: resultsCh, - errCh: errCh, - }: - // continue - default: - routerPredictDroppedCounter.WithLabelValues(r.routerName).Inc() - return fmt.Errorf("work channel is full") + debatched, err := shape.Debatch(params[paramOffset], batchOffset) + if err != nil { + return fmt.Errorf("failed to debatch for row %d, input %d: %w", batchOffset, paramOffset, err) + } + + batch.inputs[inputIdx], err = shape.AppendRowToBatch(batch.inputs[inputIdx], debatched) + if err != nil { + return fmt.Errorf("failed to append row %d to batch for input %d: %w", batchOffset, paramOffset, err) + } + inputIdx++ } + + batch.rowOffsets = append(batch.rowOffsets, batchOffset) } return nil }() - predictWaitGroup.Wait() - close(errCh) - close(resultsCh) - if err != nil { return nil, err } - for err := range errCh { - return nil, err + // Check queue size limit + if len(batches) > r.maxQueueSize { + routerPredictDroppedCounter.WithLabelValues(r.routerName).Inc() + return nil, fmt.Errorf("too many batches 
(%d) exceeds max queue size (%d)", len(batches), r.maxQueueSize) + } + + // Phase 2: Execute predictions in parallel with bounded concurrency + resultCh := make(chan batchResult, len(batches)) + semaphore := make(chan struct{}, r.workers) + var wg sync.WaitGroup + + for _, batch := range batches { + wg.Add(1) + go func(b *modelBatch) { + defer wg.Done() + + // Acquire semaphore slot + semaphore <- struct{}{} + defer func() { <-semaphore }() + + // Rely on downstream for timeouts + results, err := b.evaluator.Predict(ctx, b.inputs) + + // Append model name to results if configured + if r.modelOutputName != "" && err == nil { + modelNames := make([][]string, len(b.rowOffsets)) + for i := range modelNames { + modelNames[i] = []string{b.modelName} + } + results = append(results, modelNames) + } + + resultCh <- batchResult{ + modelName: b.modelName, + results: results, + offsets: b.rowOffsets, + err: err, + } + }(batch) } + wg.Wait() + close(resultCh) + + // Phase 3: Reassemble results in original order allResults := make([][]interface{}, expectedBatchSize) - for results := range resultsCh { - allResults[results.offset] = results.results + + for res := range resultCh { + if res.err != nil { + return nil, fmt.Errorf("prediction failed for model %s: %w", res.modelName, res.err) + } + + // Extract individual rows from the batched result and place at original offsets + for localIdx, originalOffset := range res.offsets { + rowResult := make([]interface{}, len(res.results)) + for outputIdx, outputBatch := range res.results { + extracted, err := shape.ExtractRowFromBatch(outputBatch, localIdx) + if err != nil { + return nil, fmt.Errorf("failed to extract row %d from model %s output %d: %w", + localIdx, res.modelName, outputIdx, err) + } + rowResult[outputIdx] = extracted + } + allResults[originalOffset] = rowResult + } } + // Concatenate all rows into final output endResults := make([]interface{}, len(signature.Outputs)) for i, results := range allResults { endResults, err = 
shape.ConcatAxis0(endResults, results) @@ -366,495 +426,3 @@ func (r *Router) Stats(stats map[string]interface{}) { func (r *Router) Close() error { return nil } - -func (r *Router) debugLogf(format string, args ...interface{}) { - if r.debug { - prefix := "[%s Router] " - log.Printf(prefix+format, append([]interface{}{r.routerName}, args...)...) - } -} - -type modelSignature struct { - name string - signature *domain.Signature -} - -// TODO refactor with service/tfmodel/service.isModified()? -func (r *Router) isModified(snapshot *config.Modified) bool { - if r.routingConfig == nil || r.configModified == nil { - return true - } - - if snapshot.Max.IsZero() { - return false - } - - r.configLock.RLock() - modified := r.configModified - r.configLock.RUnlock() - - return !(modified.Max.Equal(snapshot.Max) && modified.Min.Equal(snapshot.Min)) -} - -func (r *Router) ReloadIfNeeded(ctx context.Context) error { - start := time.Now() - isFullReload := false - defer func() { - var mode string - if isFullReload { - mode = "full" - } else { - mode = "checks" - } - routerReloadDurationMicrosSummary.WithLabelValues(r.routerName, mode).Observe(float64(time.Since(start).Microseconds())) - }() - - // fetch and check router configuration file - snapshot, err := files.ModifiedSnapshot(ctx, r.fs, r.configURL, nil) - if err != nil { - return fmt.Errorf("failed to check router configuration file: %w", err) - } - - if !r.isModified(snapshot) { - // check health of all underlying models - var wg sync.WaitGroup - - r.configLock.RLock() - errChannels := len(r.routingTable) - if r.globalModel != nil { - errChannels++ - } - - errCh := make(chan error, errChannels) - - if r.globalModel != nil { - wg.Add(1) - go func() { - defer wg.Done() - err := r.globalModel.ReloadIfNeeded(ctx) - if err != nil { - errCh <- fmt.Errorf("failed to reload global model: %w", err) - } - }() - } - - for m, p := range r.routingTable { - wg.Add(1) - go func(m string, p platform.PlatformEvaluator) { - defer wg.Done() 
- err := p.ReloadIfNeeded(ctx) - if err != nil { - errCh <- fmt.Errorf("failed to reload model %s: %w", m, err) - } - }(m, p) - } - - wg.Wait() - close(errCh) - - if len(errCh) > 0 { - var errStrings []string - for err := range errCh { - errStrings = append(errStrings, err.Error()) - } - - err = fmt.Errorf("reloading errors: %s", strings.Join(errStrings, "; ")) - } - - r.configLock.RUnlock() - return err - } - - isFullReload = true - - // otherwise just abandon the routing table status checks - - r.configLock.Lock() - defer r.configLock.Unlock() - - r.configModified = snapshot - - // load router configuration file - rawReader, err := r.fs.OpenURL(ctx, r.configURL) - if err != nil { - return fmt.Errorf("failed to open router configuration file: %w", err) - } - - defer rawReader.Close() - var reader io.Reader = rawReader - if strings.HasSuffix(r.configURL, ".gz") { - if reader, err = gzip.NewReader(rawReader); err != nil { - return fmt.Errorf("failed to create gzip reader for router configuration file: %w", err) - } - } - - newConfig := new(router.RoutingConfig) - - // TODO move this check earlier - if strings.Contains(r.configURL, ".yaml") { - decoder := yaml.NewDecoder(reader) - err = decoder.Decode(newConfig) - } else if strings.Contains(r.configURL, ".json") { - err = json.NewDecoder(reader).Decode(newConfig) - } else { - return fmt.Errorf("unsupported router configuration file type: %s", r.configURL) - } - - if err != nil { - return fmt.Errorf("failed to decode router configuration file: %w", err) - } - - if err := r.applyRouterConfig(ctx, newConfig); err != nil { - return err - } - - return nil -} - -// applyRouterConfig will both update evaluators to new configuration state and verify and build the signature -func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.RoutingConfig) error { - modelsToUnload := make(map[string]struct{}) - reuseEvaluators := make(map[string]platform.PlatformEvaluator) - var reuseGlobal platform.PlatformEvaluator - 
- var finalSignature *domain.Signature - var oldConfig *router.RoutingConfig - func() { - r.routingTableLock.RLock() - defer r.routingTableLock.RUnlock() - if r.ioState != nil { - finalSignature = r.ioState.signature - } - - reuseGlobal = r.globalModel - oldConfig = r.routingConfig - }() - - if oldConfig != nil { - for _, entity := range oldConfig.EntityMapping { - modelsToUnload[entity.ModelName] = struct{}{} - if evaluator, ok := r.routingTable[entity.ModelName]; ok { - reuseEvaluators[entity.ModelName] = evaluator - } - } - - if oldConfig.GlobalModelName != "" { - modelsToUnload[oldConfig.GlobalModelName] = struct{}{} - } - } - - newModelMapping := make(map[int]string) - for _, entity := range newConfig.EntityMapping { - r.debugLogf("add mapping: %d -> %s", entity.EntityID, entity.ModelName) - - newModelMapping[entity.EntityID] = entity.ModelName - delete(modelsToUnload, entity.ModelName) - } - - globalModelName := newConfig.GlobalModelName - if globalModelName == "" && r.hasGlobalModel { - return fmt.Errorf("global model name is missing") - } - - if globalModelName != "" { - r.debugLogf("global model: %s", globalModelName) - delete(modelsToUnload, globalModelName) - } - - newRoutingTable := make(map[string]platform.PlatformEvaluator) - for _, entity := range newConfig.EntityMapping { - model := entity.ModelName - if _, ok := newRoutingTable[model]; ok { - continue - } - - if evaluator, ok := reuseEvaluators[model]; ok { - newRoutingTable[model] = evaluator - continue - } - - evaluator, err := r.makeRoutedEvaluator(model) - - if err != nil { - return fmt.Errorf("failed to create Routed Evaluator for model %s: %w", model, err) - } - - newRoutingTable[model] = evaluator - } - - var globalEvaluator platform.PlatformEvaluator - if globalModelName != "" { - if oldConfig != nil && globalModelName == oldConfig.GlobalModelName && reuseGlobal != nil { - globalEvaluator = reuseGlobal - } else if evaluator, ok := newRoutingTable[globalModelName]; ok { - globalEvaluator = 
evaluator - } else if evaluator, ok := reuseEvaluators[globalModelName]; ok { - globalEvaluator = evaluator - } else { - var err error - globalEvaluator, err = r.makeRoutedEvaluator(globalModelName) - if err != nil { - return fmt.Errorf("failed to create Routed Evaluator for global model %s: %w", globalModelName, err) - } - } - } - - wg := sync.WaitGroup{} - - numWorkers := len(newRoutingTable) - if globalEvaluator != nil { - numWorkers++ - } - - errCh := make(chan error, numWorkers) - signatureCh := make(chan modelSignature, numWorkers) - - if globalEvaluator != nil { - wg.Add(1) - go func() { - defer wg.Done() - if err := globalEvaluator.ReloadIfNeeded(ctx); err != nil { - errCh <- fmt.Errorf("failed to reload global model %s: %w", globalModelName, err) - } - }() - } - - for model := range newRoutingTable { - wg.Add(1) - go func(model string) { - defer wg.Done() - - r.debugLogf("reload model: %s", model) - - modelEvaluator := newRoutingTable[model] - if err := modelEvaluator.ReloadIfNeeded(ctx); err != nil { - r.debugLogf("failed to reload model: %s: %v", model, err) - errCh <- fmt.Errorf("failed to reload model %s: %w", model, err) - } - - evalSig := modelEvaluator.Signature() - - if evalSig == nil { - errCh <- fmt.Errorf("model %s signature is nil", model) - return - } - - signatureCh <- modelSignature{ - name: model, - signature: evalSig, - } - }(model) - } - - r.debugLogf("wait for reloads") - - wg.Wait() - close(errCh) - close(signatureCh) - - if len(errCh) > 0 { - var errStrings []string - for err := range errCh { - errStrings = append(errStrings, err.Error()) - } - return fmt.Errorf("one or more model reloading errors: %s", strings.Join(errStrings, "; ")) - } - - sigInputMap := make(map[string]*domain.Input) - sigOutputMap := make(map[string]*domain.Output) - - // we only create ioState on the first reload - var ioState *IOState = new(IOState) - if finalSignature != nil { - for _, input := range finalSignature.Inputs { - sigInputMap[input.Name] = &input - 
} - - for _, output := range finalSignature.Outputs { - sigOutputMap[output.Name] = &output - } - } - - for signature := range signatureCh { - // accept first available signature as the final signature - if finalSignature == nil { - // in the creation of the signature, include the routing input - // DANGER: this uses the pointer to the signature, so since the signature is modified, the original signature will be modified! - // This doesn't happen in practice, but can cause issues in tests. - finalSignature = signature.signature - - inputOffset := len(finalSignature.Inputs) - - routerInput := domain.Input{ - Name: r.routerInputFieldName, - Type: reflect.TypeOf(int64(0)), - } - - ioState.routerInputOffset = inputOffset - - finalSignature.Inputs = append(finalSignature.Inputs, routerInput) - - for _, input := range finalSignature.Inputs { - sigInputMap[input.Name] = &input - } - - for _, input := range r.configuredInputs { - _, ok := sigInputMap[input.Name] - - if ok { - // the input is configured and already in the self-reported signature - continue - } - - if !input.Auxiliary { - return fmt.Errorf("non-auxiliary input %s for model %s was not in model inputs", input.Name, signature.name) - } - - sigInputMap[input.Name] = &domain.Input{ - Name: input.Name, - Type: input.RawType(), - Auxiliary: input.Auxiliary, - } - } - - if r.modelOutputName != "" { - // also, add the selected model output - modelOutput := domain.Output{ - Name: r.modelOutputName, - Index: len(finalSignature.Outputs), - DataType: "string", - } - - finalSignature.Outputs = append(finalSignature.Outputs, modelOutput) - } - - for _, output := range finalSignature.Outputs { - sigOutputMap[output.Name] = &output - } - - continue - } - - thisSignature := signature.signature - // validate signature consistency - thisSignatureOutputMap := make(map[string]*domain.Output) - for _, output := range thisSignature.Outputs { - oldOutput, ok := sigOutputMap[output.Name] - if !ok { - return fmt.Errorf("signature 
output %s for model %s not found in the previous signature", output.Name, signature.name) - } - - thisSignatureOutputMap[output.Name] = &output - - // TODO permit this - if oldOutput.Index != output.Index { - return fmt.Errorf("signature output %s for model %s has index %d, and the previous signature has index %d", output.Name, signature.name, output.Index, oldOutput.Index) - } - - if oldOutput.DataType != output.DataType { - return fmt.Errorf("signature output %s for model %s has data type %s, and the previous signature has data type %s", output.Name, signature.name, output.DataType, oldOutput.DataType) - } - } - - for expectedOutput := range sigOutputMap { - if _, ok := thisSignatureOutputMap[expectedOutput]; !ok && expectedOutput != r.modelOutputName { - return fmt.Errorf("signature output %s for was not found in model %s signature", expectedOutput, signature.name) - } - } - - thisSignatureInputMap := make(map[string]*domain.Input) - for _, input := range thisSignature.Inputs { - oldInput, ok := sigInputMap[input.Name] - if !ok { - return fmt.Errorf("signature input %s for model %s not found in the previous signature", input.Name, signature.name) - } - - thisSignatureInputMap[input.Name] = &input - - if oldInput.Auxiliary { - continue - } - - // TODO permit this - if oldInput.Index != input.Index { - return fmt.Errorf("signature input %s for model %s has index %d, and the previous signature has index %d", input.Name, signature.name, input.Index, oldInput.Index) - } - - if !oldInput.Type.ConvertibleTo(input.Type) { - return fmt.Errorf("signature input %s for model %s has data type %s, and the previous signature has data type %s", input.Name, signature.name, input.Type.String(), oldInput.Type.String()) - } - } - - for expectedInput := range sigInputMap { - if _, ok := thisSignatureInputMap[expectedInput]; !ok && expectedInput != r.routerInputFieldName { - return fmt.Errorf("signature input %s for was not found in model %s signature", expectedInput, signature.name) 
- } - } - } - - if r.fixedEvaluatorFields != nil { - // TODO this is actually an acceptable case, but needs to be addressed elsewhere first before it is permitted - for field := range r.fixedEvaluatorFields { - if _, ok := sigOutputMap[field]; !ok { - return fmt.Errorf("fixed evaluator field: %s was not found in the signature outputs", field) - } - } - - for _, field := range sigOutputMap { - if _, ok := r.fixedEvaluatorFields[field.Name]; !ok && field.Name != r.modelOutputName { - return fmt.Errorf("signature output %s is not replaced", field.Name) - } - } - } - - ioState.signature = finalSignature - ioState.inputs = sigInputMap - - if globalEvaluator != nil { - if _, exists := newRoutingTable[globalModelName]; !exists { - newRoutingTable[globalModelName] = globalEvaluator - } - } - - func() { - r.routingTableLock.Lock() - defer r.routingTableLock.Unlock() - - r.routingConfig = newConfig - - r.routingMap = newModelMapping - r.routingTable = newRoutingTable - - r.globalModel = globalEvaluator - - if r.ioState == nil { - r.ioState = ioState - } - }() - - for model := range modelsToUnload { - routerModelUnloadGauge.WithLabelValues(r.routerName).Inc() - - go func(modelName string) { - defer routerModelUnloadGauge.WithLabelValues(r.routerName).Dec() - - ctxTo, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - r.debugLogf("request to unload model: %s", modelName) - - if err := r.unloadModel(ctxTo, modelName); err != nil { - r.debugLogf("failed to unload model %s: %v\n", modelName, err) - } - }(model) - } - - return nil -} - -func (r *Router) unloadModel(ctx context.Context, modelName string) error { - defer routerModelUnloadGauge.WithLabelValues(r.routerName).Dec() - if err := r.unloader.UnloadModel(ctx, r.routerName, modelName); err != nil { - return fmt.Errorf("failed to unload model %s: %w", modelName, err) - } - return nil -} diff --git a/service/platform/router/router_test.go b/service/platform/router/router_test.go index 
035195c..e9bd751 100644 --- a/service/platform/router/router_test.go +++ b/service/platform/router/router_test.go @@ -742,3 +742,340 @@ func TestRouter_applyRouterConfig_sharedTritonServer(t *testing.T) { assert.True(t, tritonServer.readyState["modelB"], "modelB should still be loaded") assert.True(t, tritonServer.readyState["modelC"], "modelC should be loaded") } + +// --- Batched Prediction Tests --- + +// testMockEvaluator is a configurable mock that handles batched inputs and tracks calls +type testMockEvaluator struct { + modelName string + signature *domain.Signature + predictCalls int + mu sync.Mutex + err error // if set, Predict returns this error +} + +func newTestMockEvaluator(name string, sig *domain.Signature) *testMockEvaluator { + return &testMockEvaluator{modelName: name, signature: sig} +} + +func (m *testMockEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { + m.mu.Lock() + m.predictCalls++ + m.mu.Unlock() + + if m.err != nil { + return nil, m.err + } + + if len(params) == 0 { + return nil, fmt.Errorf("no params provided") + } + + // Determine batch size from first input + var batchSize int + switch typed := params[0].(type) { + case [][]string: + batchSize = len(typed) + case [][]int32: + batchSize = len(typed) + case [][]int64: + batchSize = len(typed) + case [][]float32: + batchSize = len(typed) + case [][]float64: + batchSize = len(typed) + default: + return nil, fmt.Errorf("unexpected input type: %T", params[0]) + } + + // Compute output: for each row, output the length of the first string input + results := make([][]float32, batchSize) + for i := 0; i < batchSize; i++ { + var length int + if typed, ok := params[0].([][]string); ok { + length = len(typed[i][0]) + } + results[i] = []float32{float32(length)} + } + + return []interface{}{results}, nil +} + +func (m *testMockEvaluator) Signature() *domain.Signature { return m.signature } +func (m *testMockEvaluator) Dictionary() *common.Dictionary { return nil 
} +func (m *testMockEvaluator) Inputs() map[string]*domain.Input { return nil } +func (m *testMockEvaluator) Stats(map[string]interface{}) {} +func (m *testMockEvaluator) Close() error { return nil } +func (m *testMockEvaluator) ReloadIfNeeded(ctx context.Context) error { + return nil +} + +func (m *testMockEvaluator) GetPredictCalls() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.predictCalls +} + +// predictTestSetup holds the common test setup for prediction tests +type predictTestSetup struct { + router *Router + evaluators map[string]*testMockEvaluator + signature *domain.Signature +} + +// predictTestOptions configures the test setup +type predictTestOptions struct { + forceBatchSize1 bool + modelOutputName string + modelNames []string // defaults to ["model1", "model2"] +} + +// setupPredictTest creates a router with mock evaluators for prediction testing +func setupPredictTest(t *testing.T, opts predictTestOptions) *predictTestSetup { + t.Helper() + + if len(opts.modelNames) == 0 { + opts.modelNames = []string{"model1", "model2"} + } + + // Create shared signature + downstreamSig := &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, + } + + // Create evaluators + evaluators := make(map[string]*testMockEvaluator) + for _, name := range opts.modelNames { + evaluators[name] = newTestMockEvaluator(name, downstreamSig) + } + + // Create config + cfg := &config.Model{ + ID: "router_predict_test", + Mode: "router", + Platform: "triton", + Router: &config.RouterConfig{ + ConfigURL: "memory://router-config", + InputName: "router_id", + ForceBatchSize1: opts.forceBatchSize1, + Global: config.GlobalModelConfig{Exists: true}, + }, + Triton: &config.TritonConfig{ServerID: "test_server"}, + } + + if opts.modelOutputName != "" { + cfg.Router.Output.FieldName = opts.modelOutputName + } + + cfg.Init(nil) + cfg.Router.MaxQueueSize = 
1000 + cfg.Router.Workers = 10 + + // Create router + router, err := newRouter(cfg, nil, map[string]UnloadService{ + "test_server": &triton.Service{}, + }, func(modelName string) (platform.PlatformEvaluator, error) { + if eval, ok := evaluators[modelName]; ok { + return eval, nil + } + return evaluators[opts.modelNames[0]], nil + }) + if err != nil { + t.Fatalf("newRouter error: %v", err) + } + + // Setup routing map (1->model1, 2->model2, etc.) + router.routingMap = make(map[int]string) + for i, name := range opts.modelNames { + router.routingMap[i+1] = name + } + + // Setup routing table + router.routingTable = make(map[string]platform.PlatformEvaluator) + for name, eval := range evaluators { + router.routingTable[name] = eval + } + + // Build signature outputs + outputs := []domain.Output{{Name: "score", Index: 0, DataType: "float32"}} + if opts.modelOutputName != "" { + outputs = append(outputs, domain.Output{Name: opts.modelOutputName, Index: 1, DataType: "string"}) + } + + // Setup ioState + routerInputName := cfg.Router.InputName + routerInput := domain.Input{Name: routerInputName, Index: 1, Type: reflect.TypeOf(int64(0))} + downstreamInput := domain.Input{Name: "text", Index: 0, Type: reflect.TypeOf("")} + + routerSig := &domain.Signature{ + Inputs: []domain.Input{downstreamInput, routerInput}, + Outputs: outputs, + } + + router.ioState = &IOState{ + inputs: map[string]*domain.Input{ + routerInputName: &routerInput, + downstreamInput.Name: &downstreamInput, + }, + signature: routerSig, + routerInputOffset: 1, + } + + return &predictTestSetup{ + router: router, + evaluators: evaluators, + signature: routerSig, + } +} + +func TestRouter_Predict_BatchingBehavior(t *testing.T) { + tests := []struct { + name string + forceBatchSize1 bool + inputs []string + routingIDs []int64 + wantScores [][]float32 + wantCallCounts map[string]int + }{ + { + name: "batched_groups_by_model", + forceBatchSize1: false, + inputs: []string{"a", "bb", "ccc", "dddd", "eeeee", 
"ffffff"}, + routingIDs: []int64{1, 2, 1, 2, 1, 2}, // alternating model1/model2 + wantScores: [][]float32{{1}, {2}, {3}, {4}, {5}, {6}}, + wantCallCounts: map[string]int{"model1": 1, "model2": 1}, // each model called once + }, + { + name: "batched_single_model", + forceBatchSize1: false, + inputs: []string{"a", "bb", "ccc", "dddd"}, + routingIDs: []int64{1, 1, 1, 1}, // all to model1 + wantScores: [][]float32{{1}, {2}, {3}, {4}}, + wantCallCounts: map[string]int{"model1": 1, "model2": 0}, + }, + { + name: "force_batch_size_1", + forceBatchSize1: true, + inputs: []string{"a", "bb", "ccc", "dddd"}, + routingIDs: []int64{1, 1, 1, 1}, // all to model1 + wantScores: [][]float32{{1}, {2}, {3}, {4}}, + wantCallCounts: map[string]int{"model1": 4, "model2": 0}, // 4 individual calls + }, + { + name: "force_batch_size_1_multiple_models", + forceBatchSize1: true, + inputs: []string{"a", "bb", "ccc", "dddd"}, + routingIDs: []int64{1, 2, 1, 2}, // alternating + wantScores: [][]float32{{1}, {2}, {3}, {4}}, + wantCallCounts: map[string]int{"model1": 2, "model2": 2}, // 2 calls each + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setup := setupPredictTest(t, predictTestOptions{ + forceBatchSize1: tt.forceBatchSize1, + }) + + // Build input params + stringInputs := make([][]string, len(tt.inputs)) + for i, s := range tt.inputs { + stringInputs[i] = []string{s} + } + routingInputs := make([][]int64, len(tt.routingIDs)) + for i, id := range tt.routingIDs { + routingInputs[i] = []int64{id} + } + + params := []interface{}{stringInputs, routingInputs} + + results, err := setup.router.Predict(context.Background(), params) + if err != nil { + t.Fatalf("Predict error: %v", err) + } + + // Verify scores + scores, ok := results[0].([][]float32) + if !ok { + t.Fatalf("expected [][]float32, got %T", results[0]) + } + if !reflect.DeepEqual(scores, tt.wantScores) { + t.Errorf("scores mismatch:\n got: %v\n want: %v", scores, tt.wantScores) + } + + // Verify call 
counts + for modelName, wantCount := range tt.wantCallCounts { + if eval, ok := setup.evaluators[modelName]; ok { + gotCount := eval.GetPredictCalls() + if gotCount != wantCount { + t.Errorf("%s call count: got %d, want %d", modelName, gotCount, wantCount) + } + } + } + }) + } +} + +func TestRouter_Predict_WithModelOutput(t *testing.T) { + setup := setupPredictTest(t, predictTestOptions{ + modelOutputName: "model_id", + }) + + params := []interface{}{ + [][]string{{"a"}, {"bb"}, {"ccc"}, {"dddd"}}, + [][]int64{{1}, {2}, {1}, {2}}, + } + + results, err := setup.router.Predict(context.Background(), params) + if err != nil { + t.Fatalf("Predict error: %v", err) + } + + if len(results) != 2 { + t.Fatalf("expected 2 outputs, got %d", len(results)) + } + + // Verify scores + scores, ok := results[0].([][]float32) + if !ok { + t.Fatalf("expected [][]float32, got %T", results[0]) + } + wantScores := [][]float32{{1}, {2}, {3}, {4}} + assert.Equal(t, wantScores, scores, "scores mismatch") + + // Verify model IDs + modelIDs, ok := results[1].([][]string) + if !ok { + t.Fatalf("expected [][]string for model_id, got %T", results[1]) + } + wantModelIDs := [][]string{{"model1"}, {"model2"}, {"model1"}, {"model2"}} + assert.Equal(t, wantModelIDs, modelIDs, "model_id mismatch") +} + +func TestRouter_Predict_ErrorPropagation(t *testing.T) { + setup := setupPredictTest(t, predictTestOptions{ + modelNames: []string{"model1"}, + }) + + // Configure evaluator to return error + setup.evaluators["model1"].err = errors.New("model prediction failed") + + params := []interface{}{ + [][]string{{"a"}, {"bb"}}, + [][]int64{{1}, {1}}, + } + + _, err := setup.router.Predict(context.Background(), params) + if err == nil { + t.Fatal("expected error but got nil") + } + + if !strings.Contains(err.Error(), "model prediction failed") { + t.Errorf("expected error to contain 'model prediction failed', got: %v", err) + } +} diff --git a/service/platform/router/worker.go b/service/platform/router/worker.go 
deleted file mode 100644 index 09bc073..0000000 --- a/service/platform/router/worker.go +++ /dev/null @@ -1,63 +0,0 @@ -package router - -import ( - "context" - "fmt" - "log" - "sync" - "time" - - "github.com/prometheus/client_golang/prometheus" - "github.com/viant/mly/service/platform" -) - -type workRequest struct { - wg *sync.WaitGroup - - predictor platform.Predictor - ctx context.Context - request []interface{} - - queuedTime time.Time - offset int - modelOutputEnabled bool - routingValueString string - - responseCh chan offsetResults - errCh chan error -} - -type offsetResults struct { - offset int - results []interface{} -} - -func handleWorkRequests(workCh chan *workRequest, observer prometheus.Observer) { - for request := range workCh { - if request == nil { - log.Println("work request is nil, stopping") - break - } - - func(request workRequest) { - observer.Observe(float64(time.Since(request.queuedTime).Microseconds())) - - defer request.wg.Done() - results, err := request.predictor.Predict(request.ctx, request.request) - if err != nil { - request.errCh <- fmt.Errorf("failed to predict for row %d: %w", request.offset, err) - return - } - - if request.modelOutputEnabled { - // TODO fix ordering - results = append(results, [][]string{{request.routingValueString}}) - } - - request.responseCh <- offsetResults{ - offset: request.offset, - results: results, - } - }(*request) - } -} diff --git a/service/request/shape/batch.go b/service/request/shape/batch.go index 3f62604..4351b4a 100644 --- a/service/request/shape/batch.go +++ b/service/request/shape/batch.go @@ -1,6 +1,9 @@ package shape -import "fmt" +import ( + "fmt" + "reflect" +) // DetermineBatchSize determines batch size from a service.Request.Feeds slice. 
func DetermineBatchSize(inputs []interface{}) (int, error) { @@ -66,7 +69,96 @@ func Debatch(untypedBatch interface{}, i int) (interface{}, error) { return nil, fmt.Errorf("unexpected batch type: %T", untypedBatch) } -// concatAxis0 concatenates two tensors along axis 0 (batch dimension). +// AppendRowToBatch appends a single debatched row to an accumulating batch. +// If the accumulator is nil, it initializes it with the row's type. +// The row is expected to be in debatched form: [][]T with shape [1][1]. +// The accumulator will grow to shape [N][1] after N appends. +func AppendRowToBatch(accumulator interface{}, row interface{}) (interface{}, error) { + if accumulator == nil { + // Initialize with the row (already in correct shape [1][1]) + return row, nil + } + + switch accTyped := accumulator.(type) { + case [][]int32: + rowTyped, ok := row.([][]int32) + if !ok { + return nil, fmt.Errorf("type mismatch: accumulator is [][]int32, row is %T", row) + } + return append(accTyped, rowTyped...), nil + case [][]int64: + rowTyped, ok := row.([][]int64) + if !ok { + return nil, fmt.Errorf("type mismatch: accumulator is [][]int64, row is %T", row) + } + return append(accTyped, rowTyped...), nil + case [][]float32: + rowTyped, ok := row.([][]float32) + if !ok { + return nil, fmt.Errorf("type mismatch: accumulator is [][]float32, row is %T", row) + } + return append(accTyped, rowTyped...), nil + case [][]float64: + rowTyped, ok := row.([][]float64) + if !ok { + return nil, fmt.Errorf("type mismatch: accumulator is [][]float64, row is %T", row) + } + return append(accTyped, rowTyped...), nil + case [][]string: + rowTyped, ok := row.([][]string) + if !ok { + return nil, fmt.Errorf("type mismatch: accumulator is [][]string, row is %T", row) + } + return append(accTyped, rowTyped...), nil + default: + return nil, fmt.Errorf("unsupported accumulator type: %T", accumulator) + } +} + +// ExtractRowFromBatch extracts a single row from a batch at the given index. 
+// The batch is expected to have shape [N][M] and the result will have shape [1][M]. +func ExtractRowFromBatch(batch interface{}, index int) (interface{}, error) { + switch typedBatch := batch.(type) { + case [][]int32: + if index >= len(typedBatch) { + return nil, fmt.Errorf("index %d out of range for batch of size %d", index, len(typedBatch)) + } + return [][]int32{typedBatch[index]}, nil + case [][]int64: + if index >= len(typedBatch) { + return nil, fmt.Errorf("index %d out of range for batch of size %d", index, len(typedBatch)) + } + return [][]int64{typedBatch[index]}, nil + case [][]float32: + if index >= len(typedBatch) { + return nil, fmt.Errorf("index %d out of range for batch of size %d", index, len(typedBatch)) + } + return [][]float32{typedBatch[index]}, nil + case [][]float64: + if index >= len(typedBatch) { + return nil, fmt.Errorf("index %d out of range for batch of size %d", index, len(typedBatch)) + } + return [][]float64{typedBatch[index]}, nil + case [][]string: + if index >= len(typedBatch) { + return nil, fmt.Errorf("index %d out of range for batch of size %d", index, len(typedBatch)) + } + return [][]string{typedBatch[index]}, nil + default: + return nil, fmt.Errorf("unsupported batch type: %T", batch) + } +} + +// BatchSize returns the batch size (first dimension) of a batch tensor. +func BatchSize(batch interface{}) (int, error) { + val := reflect.ValueOf(batch) + if val.Kind() != reflect.Slice { + return 0, fmt.Errorf("expected slice, got %T", batch) + } + return val.Len(), nil +} + +// ConcatAxis0 concatenates two tensors along axis 0 (batch dimension). 
func ConcatAxis0(x []interface{}, y []interface{}) ([]interface{}, error) { if len(x) != len(y) { return nil, fmt.Errorf("x and y must have the same length: %d vs %d", len(x), len(y)) diff --git a/service/request/shape/batch_test.go b/service/request/shape/batch_test.go index 9b43628..82706d5 100644 --- a/service/request/shape/batch_test.go +++ b/service/request/shape/batch_test.go @@ -88,3 +88,132 @@ func TestDebatchAndSqueezeBatch_Int64(t *testing.T) { t.Errorf("squeezeBatch() got %v (%T), want 20 (int64)", scalar, scalar) } } + +func TestAppendRowToBatch_String(t *testing.T) { + // Start with nil accumulator + row1 := [][]string{{"hello"}} + acc, err := AppendRowToBatch(nil, row1) + if err != nil { + t.Fatalf("AppendRowToBatch returned error: %v", err) + } + + want1 := [][]string{{"hello"}} + if !reflect.DeepEqual(acc, want1) { + t.Errorf("after first append: got %#v, want %#v", acc, want1) + } + + // Append second row + row2 := [][]string{{"world"}} + acc, err = AppendRowToBatch(acc, row2) + if err != nil { + t.Fatalf("AppendRowToBatch returned error: %v", err) + } + + want2 := [][]string{{"hello"}, {"world"}} + if !reflect.DeepEqual(acc, want2) { + t.Errorf("after second append: got %#v, want %#v", acc, want2) + } + + // Append third row + row3 := [][]string{{"foo"}} + acc, err = AppendRowToBatch(acc, row3) + if err != nil { + t.Fatalf("AppendRowToBatch returned error: %v", err) + } + + want3 := [][]string{{"hello"}, {"world"}, {"foo"}} + if !reflect.DeepEqual(acc, want3) { + t.Errorf("after third append: got %#v, want %#v", acc, want3) + } +} + +func TestAppendRowToBatch_Float32(t *testing.T) { + row1 := [][]float32{{1.5}} + acc, err := AppendRowToBatch(nil, row1) + if err != nil { + t.Fatalf("AppendRowToBatch returned error: %v", err) + } + + row2 := [][]float32{{2.5}} + acc, err = AppendRowToBatch(acc, row2) + if err != nil { + t.Fatalf("AppendRowToBatch returned error: %v", err) + } + + want := [][]float32{{1.5}, {2.5}} + if !reflect.DeepEqual(acc, want) { 
+ t.Errorf("got %#v, want %#v", acc, want) + } +} + +func TestAppendRowToBatch_TypeMismatch(t *testing.T) { + acc := [][]string{{"hello"}} + row := [][]int64{{123}} + + _, err := AppendRowToBatch(acc, row) + if err == nil { + t.Fatal("expected type mismatch error, got nil") + } +} + +func TestExtractRowFromBatch_String(t *testing.T) { + batch := [][]string{{"a"}, {"b"}, {"c"}} + + row, err := ExtractRowFromBatch(batch, 1) + if err != nil { + t.Fatalf("ExtractRowFromBatch returned error: %v", err) + } + + want := [][]string{{"b"}} + if !reflect.DeepEqual(row, want) { + t.Errorf("got %#v, want %#v", row, want) + } +} + +func TestExtractRowFromBatch_Float32(t *testing.T) { + batch := [][]float32{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}} + + row, err := ExtractRowFromBatch(batch, 2) + if err != nil { + t.Fatalf("ExtractRowFromBatch returned error: %v", err) + } + + want := [][]float32{{5.0, 6.0}} + if !reflect.DeepEqual(row, want) { + t.Errorf("got %#v, want %#v", row, want) + } +} + +func TestExtractRowFromBatch_OutOfRange(t *testing.T) { + batch := [][]int64{{1}, {2}} + + _, err := ExtractRowFromBatch(batch, 5) + if err == nil { + t.Fatal("expected out of range error, got nil") + } +} + +func TestBatchSize(t *testing.T) { + tests := []struct { + name string + batch interface{} + want int + }{ + {"string batch", [][]string{{"a"}, {"b"}, {"c"}}, 3}, + {"float32 batch", [][]float32{{1}, {2}}, 2}, + {"int64 batch", [][]int64{{1}}, 1}, + {"empty batch", [][]string{}, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := BatchSize(tt.batch) + if err != nil { + t.Fatalf("BatchSize returned error: %v", err) + } + if got != tt.want { + t.Errorf("got %d, want %d", got, tt.want) + } + }) + } +} From f918e0bd6c916d68fc09d8c6988769152c7540dc Mon Sep 17 00:00:00 2001 From: David Choi Date: Wed, 28 Jan 2026 13:09:15 -0800 Subject: [PATCH 18/50] Fix double decrement on unload. 
--- service/platform/router/reload.go | 1 - 1 file changed, 1 deletion(-) diff --git a/service/platform/router/reload.go b/service/platform/router/reload.go index 2d2ba4a..a17c73a 100644 --- a/service/platform/router/reload.go +++ b/service/platform/router/reload.go @@ -505,7 +505,6 @@ func (r *Router) isModified(snapshot *config.Modified) bool { } func (r *Router) unloadModel(ctx context.Context, modelName string) error { - defer routerModelUnloadGauge.WithLabelValues(r.routerName).Dec() if err := r.unloader.UnloadModel(ctx, r.routerName, modelName); err != nil { return fmt.Errorf("failed to unload model %s: %w", modelName, err) } From 6e2e8ad787492bb49d64f4fa38d6566cee714dcb Mon Sep 17 00:00:00 2001 From: David Choi Date: Wed, 28 Jan 2026 14:05:08 -0800 Subject: [PATCH 19/50] Fix nil panics in tests. --- shared/datastore/client/service.go | 1 + shared/datastore/client/service_test.go | 14 +++++++------- shared/datastore/service_test.go | 15 ++++++++++----- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/shared/datastore/client/service.go b/shared/datastore/client/service.go index eecfd70..0bd7c1d 100644 --- a/shared/datastore/client/service.go +++ b/shared/datastore/client/service.go @@ -240,6 +240,7 @@ func NewWithOptionsV2(config *datastore.Connection, gmetrics *gmetric.Service, o } srv.init(options...) 
+ breaker := circut.New(time.Second, srv) srv.Breaker = breaker return srv, srv.connect() diff --git a/shared/datastore/client/service_test.go b/shared/datastore/client/service_test.go index b99136f..64cfdac 100644 --- a/shared/datastore/client/service_test.go +++ b/shared/datastore/client/service_test.go @@ -9,7 +9,7 @@ import ( aero "github.com/aerospike/aerospike-client-go" "github.com/viant/mly/shared/circut" - "golang.org/x/sync/singleflight" + "github.com/viant/mly/shared/config/datastore" ) type MockAero struct { @@ -37,13 +37,13 @@ func TestPut(t *testing.T) { L: mockLock, } - service := &Service{ - Client: mockAero, - group: new(singleflight.Group), - basePolicy: &aero.BasePolicy{ - TotalTimeout: 15 * time.Second, - }, + config := &datastore.Connection{ + ID: "test", } + config.Init() + + service, _ := NewWithOptionsV2(config, nil) + service.Client = mockAero breaker := circut.New(time.Second, service) service.Breaker = breaker diff --git a/shared/datastore/service_test.go b/shared/datastore/service_test.go index 5323cf9..ed69cf8 100644 --- a/shared/datastore/service_test.go +++ b/shared/datastore/service_test.go @@ -3,13 +3,12 @@ package datastore import ( "context" "testing" - "time" "github.com/aerospike/aerospike-client-go" "github.com/stretchr/testify/assert" - "github.com/viant/mly/shared/circut" "github.com/viant/mly/shared/common" + "github.com/viant/mly/shared/config/datastore" "github.com/viant/mly/shared/datastore/client" ) @@ -42,10 +41,16 @@ func TestFromClientMapsRecordAndDoesNotMapHashBin(t *testing.T) { "Field": "value", common.HashBin: 123, }} - clientSvc := &client.Service{ - Client: stubAeroRecord{record: rec}, - Breaker: circut.New(time.Second*10, &stubProber{}), + + config := &datastore.Connection{ + ID: "test", } + config.Init() + + clientSvc, _ := client.New(config) + // ignore err since we're mocking the client and don't need to actually connect + + clientSvc.Client = stubAeroRecord{record: rec} key := &Key{Namespace: "ns", Set:
"set", Value: "key"} type Foo struct { From 3948cbd9f6fdf7c98f32c0660eaf4c66534ef937 Mon Sep 17 00:00:00 2001 From: David Choi Date: Wed, 28 Jan 2026 14:42:18 -0800 Subject: [PATCH 20/50] Fix bug with router registration. --- service/triton/evaluator.go | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/service/triton/evaluator.go b/service/triton/evaluator.go index feb44ef..e287a52 100644 --- a/service/triton/evaluator.go +++ b/service/triton/evaluator.go @@ -39,6 +39,19 @@ type TritonEvaluator struct { // NewTritonEvaluator creates a new Triton evaluator func NewTritonEvaluator(config *config.Model, tritonClients map[string]*Service) (*TritonEvaluator, error) { + evaluator, err := createEvaluator(config, tritonClients) + if err != nil { + return nil, err + } + err = evaluator.registerUsage() + if err != nil { + return nil, fmt.Errorf("failed to register usage for Triton evaluator: %w", err) + } + + return evaluator, nil +} + +func createEvaluator(config *config.Model, tritonClients map[string]*Service) (*TritonEvaluator, error) { var service *Service isPrivateClient := config.URL != "" @@ -81,17 +94,13 @@ func NewTritonEvaluator(config *config.Model, tritonClients map[string]*Service) debug: config.Debug, } - err := evaluator.registerUsage() - if err != nil { - return nil, fmt.Errorf("failed to register usage for Triton evaluator: %w", err) - } - return evaluator, nil + } // Upward dependency, but provides Evaluators as needed for the service/platform/router module. 
func NewRoutedTritonEvaluator(modelName string, config *config.Model, tritonClients map[string]*Service) (*TritonEvaluator, error) { - evaluator, err := NewTritonEvaluator(config, tritonClients) + evaluator, err := createEvaluator(config, tritonClients) if err != nil { return nil, fmt.Errorf("failed to create Triton Routed evaluator: %w", err) } From 882908f6a6cb0efbf3db703867420e0666236cfb Mon Sep 17 00:00:00 2001 From: David Choi Date: Wed, 28 Jan 2026 14:42:33 -0800 Subject: [PATCH 21/50] Fix whitespace. --- shared/datastore/client/service_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shared/datastore/client/service_test.go b/shared/datastore/client/service_test.go index 64cfdac..c1da7d0 100644 --- a/shared/datastore/client/service_test.go +++ b/shared/datastore/client/service_test.go @@ -41,7 +41,7 @@ func TestPut(t *testing.T) { ID: "test", } config.Init() - + service, _ := NewWithOptionsV2(config, nil) service.Client = mockAero From bf981308b4b0884f82f9fabb26e3b25af43e7858 Mon Sep 17 00:00:00 2001 From: David Choi Date: Thu, 29 Jan 2026 10:46:14 -0800 Subject: [PATCH 22/50] Add e2e for router - Add PID based start for services --- example/e2e/check-port.sh | 46 +++++++++++++++++-- example/e2e/regression/app.yaml | 5 +- .../e2e/regression/cases/012_router/test.yaml | 6 +-- example/e2e/start-with-pid.sh | 45 ++++++++++++++++++ example/e2e/system.yaml | 2 + example/server/etc/config.yaml | 40 +++++++++++++++- example/server/etc/router.yaml | 5 ++ 7 files changed, 136 insertions(+), 13 deletions(-) create mode 100644 example/e2e/start-with-pid.sh create mode 100644 example/server/etc/router.yaml diff --git a/example/e2e/check-port.sh b/example/e2e/check-port.sh index dbd1de0..d76b0f5 100644 --- a/example/e2e/check-port.sh +++ b/example/e2e/check-port.sh @@ -4,26 +4,62 @@ set -x -e ADDR=$1 if [ -z "$ADDR" ]; then - echo "usage: $0 ADDRESS [TIMES] [SLEEP]" + echo "usage: $0 ADDRESS [TIMES] [SLEEP] [PIDFILE]" echo "address required" exit 
2 fi TIMES=${2:-30} SLEEP=${3:-1} +PIDFILE=${4:-} + +# Check if process from PID file is still running +# Returns 0 if no pidfile specified, process is running, or pidfile doesn't exist yet +# Returns 1 if pidfile exists but process is dead +check_pid() { + if [ -z "$PIDFILE" ]; then + return 0 + fi + if [ ! -f "$PIDFILE" ]; then + # PID file doesn't exist yet - service may still be starting + return 0 + fi + local pid + pid=$(cat "$PIDFILE" 2>/dev/null) + if [ -z "$pid" ]; then + return 0 + fi + if kill -0 "$pid" 2>/dev/null; then + return 0 + fi + echo "process $pid from $PIDFILE is no longer running" + return 1 +} -LOOPS=0 ERROR=1 set +x for i in $(seq $TIMES); do - sleep 1 + sleep "$SLEEP" + + # Early exit if monitored process died + if ! check_pid; then + echo "service process terminated before becoming ready" + exit 3 + fi + set +e - curl $1 &>/dev/null + curl "$ADDR" &>/dev/null ERROR=$? set -e if [ $ERROR -eq 0 ]; then break fi -done +done + +# Final PID check - ensure process is still alive even if curl succeeded +if [ $ERROR -eq 0 ] && ! 
check_pid; then + echo "service process terminated" + exit 3 +fi exit $ERROR diff --git a/example/e2e/regression/app.yaml b/example/e2e/regression/app.yaml index 50d0e23..9ba9e83 100644 --- a/example/e2e/regression/app.yaml +++ b/example/e2e/regression/app.yaml @@ -18,7 +18,7 @@ pipeline: directory: /tmp/e2e checkError: true immuneToHangups: true - command: ./mly-endly -c=/tmp/e2e/config.yaml 2>&1 >/tmp/e2e/mly-endly.log + command: bash ${appPath}/example/e2e/start-with-pid.sh /tmp/e2e/mly-endly.pid ./mly-endly -c=/tmp/e2e/config.yaml >/tmp/e2e/mly-endly.log 2>&1 env: DEBUG: 'true' LD_LIBRARY_PATH: /usr/local/lib @@ -28,5 +28,4 @@ pipeline: target: $target checkError: true commands: - - bash ${appPath}/example/e2e/check-port.sh localhost:8086 30 - + - bash ${appPath}/example/e2e/check-port.sh localhost:8086 30 1s /tmp/e2e/mly-endly.pid diff --git a/example/e2e/regression/cases/012_router/test.yaml b/example/e2e/regression/cases/012_router/test.yaml index c5d50df..5008960 100644 --- a/example/e2e/regression/cases/012_router/test.yaml +++ b/example/e2e/regression/cases/012_router/test.yaml @@ -7,8 +7,8 @@ pipeline: target: $target checkError: true commands: - - /tmp/e2e/mlyc -m sli_triton -a 'sa:b;sl:a;aux:c1' - - /tmp/e2e/mlyc -m sli_triton -a 'sa:b;sl:a;aux:c2' + - /tmp/e2e/mlyc -m sli_router -a 'routing_id|int:1;sa:b;sl:a' + - /tmp/e2e/mlyc -m sli_router -a 'routing_id|int:2;sa:b;sl:a' # assert: # action: validator:assert @@ -19,5 +19,3 @@ pipeline: # action: validator:assert # actual: $AsJSON($test.Cmd[1].Stdout) # expect: $LoadJSON('${parentPath}/expect-cache.json') - - diff --git a/example/e2e/start-with-pid.sh b/example/e2e/start-with-pid.sh new file mode 100644 index 0000000..b98e52b --- /dev/null +++ b/example/e2e/start-with-pid.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# Starts a command in the background and writes its PID to a file. +# Useful for monitoring service startup with check-port.sh +# +# Usage: start-with-pid.sh PIDFILE COMMAND [ARGS...] 
+# +# Example: +# ./start-with-pid.sh /tmp/myservice.pid ./myservice --port 8080 +# ./check-port.sh http://localhost:8080 30 1 /tmp/myservice.pid + +set -e + +PIDFILE=$1 +if [ -z "$PIDFILE" ]; then + echo "usage: $0 PIDFILE COMMAND [ARGS...]" + echo "pidfile path required" + exit 2 +fi +shift + +if [ $# -eq 0 ]; then + echo "usage: $0 PIDFILE COMMAND [ARGS...]" + echo "command required" + exit 2 +fi + +# Ensure parent directory exists +PIDDIR=$(dirname "$PIDFILE") +if [ ! -d "$PIDDIR" ]; then + mkdir -p "$PIDDIR" +fi + +# Remove stale PID file if it exists +rm -f "$PIDFILE" + +# Start the command in the background +"$@" & +PID=$! + +# Write PID to file +echo "$PID" > "$PIDFILE" + +echo "started process $PID, pidfile: $PIDFILE" +echo "command: $*" diff --git a/example/e2e/system.yaml b/example/e2e/system.yaml index a739970..c9eb299 100644 --- a/example/e2e/system.yaml +++ b/example/e2e/system.yaml @@ -25,6 +25,8 @@ pipeline: - cp -r ${appPath}/example/model/string_lookups_int_model/* ${modelRepo}/sli_1/1/model.savedmodel - mkdir -p ${modelRepo}/sli_2/1/model.savedmodel - cp -r ${appPath}/example/model/string_lookups_int_model/* ${modelRepo}/sli_2/1/model.savedmodel + - mkdir -p ${modelRepo}/sli_exp/1/model.savedmodel + - cp -r ${appPath}/example/model/string_lookups_int_model/* ${modelRepo}/sli_exp/1/model.savedmodel - mkdir -p ${modelRepo}/r2/1/model.savedmodel - cp -r ${appPath}/example/model/r2_output_model/* ${modelRepo}/r2/1/model.savedmodel - mkdir -p ${modelRepo}/ko/1/model.savedmodel diff --git a/example/server/etc/config.yaml b/example/server/etc/config.yaml index 84316fe..fd3bce9 100644 --- a/example/server/etc/config.yaml +++ b/example/server/etc/config.yaml @@ -26,7 +26,7 @@ models: - id: sli_triton_grpc_explicit platform: triton triton: - modelName: sli_2 + modelName: sli_exp serverID: triton_grpc_explicit useDict: true debug: false @@ -47,6 +47,44 @@ models: - name: expand dataType: int + - id: sli_router + platform: triton + mode: router + triton: + 
serverID: triton_grpc_explicit + Router: + ConfigURL: ${appPath}/example/server/etc/router.yaml + InputName: routing_id + Workers: 2 + Global: + Exists: false + PredictionReplacements: + - Name: expand + Type: int64 + Value: 0 + Output: + FieldName: model_id + Datastore: sli_triton + useDict: true + Inputs: + - name: routing_id + datatype: int64 + wildcard: true + - name: sl + datatype: string + wildcard: true + - name: sa + datatype: string + wildcard: true + Outputs: + - name: expand + datatype: int64 + - name: model_id + datatype: string + + Test: + SingleBatch: true + - id: ko_triton platform: triton triton: diff --git a/example/server/etc/router.yaml b/example/server/etc/router.yaml new file mode 100644 index 0000000..9d06850 --- /dev/null +++ b/example/server/etc/router.yaml @@ -0,0 +1,5 @@ +entityMapping: + - entityID: 1 + modelName: sli_1 + - entityID: 2 + modelName: sli_2 \ No newline at end of file From 7af71fd4aaca95edf909a48cbf2ddd7ccc1b1724 Mon Sep 17 00:00:00 2001 From: David Choi Date: Thu, 29 Jan 2026 10:48:07 -0800 Subject: [PATCH 23/50] Add support for reordered IO in downstream Evaluators for router. 
--- service/platform/router/fixed.go | 17 + service/platform/router/reload.go | 28 +- service/platform/router/reload_test.go | 440 ++++++++++++ service/platform/router/router.go | 153 ++-- service/platform/router/router_test.go | 939 +++++++++++++------------ 5 files changed, 1073 insertions(+), 504 deletions(-) create mode 100644 service/platform/router/reload_test.go diff --git a/service/platform/router/fixed.go b/service/platform/router/fixed.go index 216d739..14660cd 100644 --- a/service/platform/router/fixed.go +++ b/service/platform/router/fixed.go @@ -8,10 +8,26 @@ import ( "github.com/viant/mly/service/request/shape" ) +// preparedReplacement holds a pre-parsed replacement value for fixed evaluator outputs +type preparedReplacement struct { + name string + typ string + value interface{} +} + type fixedEvaluator struct { prepared []preparedReplacement } +// OutputNames returns the output names in the order they will be returned by Predict +func (f *fixedEvaluator) OutputNames() []string { + names := make([]string, len(f.prepared)) + for i, p := range f.prepared { + names[i] = p.name + } + return names +} + func newFixedEvaluator(repls []config.PredictionReplacement) (*fixedEvaluator, error) { prepared := make([]preparedReplacement, 0, len(repls)) @@ -105,6 +121,7 @@ func newFixedEvaluator(repls []config.PredictionReplacement) (*fixedEvaluator, e return fmt.Errorf("unsupported router replacement type %q for %q", r.Type, r.Name) } + pr.name = r.Name prepared = append(prepared, pr) } diff --git a/service/platform/router/reload.go b/service/platform/router/reload.go index a17c73a..61c491d 100644 --- a/service/platform/router/reload.go +++ b/service/platform/router/reload.go @@ -307,16 +307,20 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin for signature := range signatureCh { // accept first available signature as the final signature if finalSignature == nil { - // in the creation of the signature, include the routing input - 
// DANGER: this uses the pointer to the signature, so since the signature is modified, the original signature will be modified! - // This doesn't happen in practice, but can cause issues in tests. - finalSignature = signature.signature + srcSig := signature.signature + finalSignature = &domain.Signature{ + Inputs: make([]domain.Input, len(srcSig.Inputs), len(srcSig.Inputs)+1), + Outputs: make([]domain.Output, len(srcSig.Outputs)), + } + copy(finalSignature.Inputs, srcSig.Inputs) + copy(finalSignature.Outputs, srcSig.Outputs) inputOffset := len(finalSignature.Inputs) routerInput := domain.Input{ - Name: r.routerInputFieldName, - Type: reflect.TypeOf(int64(0)), + Name: r.routerInputFieldName, + Type: reflect.TypeOf(int64(0)), + Index: inputOffset, } ioState.routerInputOffset = inputOffset @@ -375,11 +379,7 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin thisSignatureOutputMap[output.Name] = &output - // TODO permit this - if oldOutput.Index != output.Index { - return fmt.Errorf("signature output %s for model %s has index %d, and the previous signature has index %d", output.Name, signature.name, output.Index, oldOutput.Index) - } - + // Note: Index differences are permitted - outputs are matched by name if oldOutput.DataType != output.DataType { return fmt.Errorf("signature output %s for model %s has data type %s, and the previous signature has data type %s", output.Name, signature.name, output.DataType, oldOutput.DataType) } @@ -404,11 +404,7 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin continue } - // TODO permit this - if oldInput.Index != input.Index { - return fmt.Errorf("signature input %s for model %s has index %d, and the previous signature has index %d", input.Name, signature.name, input.Index, oldInput.Index) - } - + // Note: Index differences are permitted - inputs are reordered by name at dispatch time if !oldInput.Type.ConvertibleTo(input.Type) { return fmt.Errorf("signature input 
%s for model %s has data type %s, and the previous signature has data type %s", input.Name, signature.name, input.Type.String(), oldInput.Type.String()) } diff --git a/service/platform/router/reload_test.go b/service/platform/router/reload_test.go new file mode 100644 index 0000000..1b6d780 --- /dev/null +++ b/service/platform/router/reload_test.go @@ -0,0 +1,440 @@ +package router + +import ( + "context" + "errors" + "reflect" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/viant/mly/service/config" + "github.com/viant/mly/service/domain" + "github.com/viant/mly/service/platform" + "github.com/viant/mly/service/triton" + sharedrouter "github.com/viant/mly/shared/config/router" +) + +func TestRouter_applyRouterConfig_LoadsAndSwaps(t *testing.T) { + ctx := context.Background() + mockClient := &mockUnloader{ + unloadCh: make(chan string, 2), + } + + oldConfig := &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelA"}, + {EntityID: 2, ModelName: "modelB"}, + }, + GlobalModelName: "global-old", + } + + makeSig := func() *domain.Signature { + return &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, + } + } + + router := &Router{ + unloader: &triton.Service{Unloader: mockClient, Repository: triton.NewRepository()}, + routingConfig: oldConfig, + routingMap: map[int]string{ + 1: "modelA", + 2: "modelB", + }, + routingTable: map[string]platform.PlatformEvaluator{ + "modelA": &mockEvaluator{signature: makeSig}, + "modelB": &mockEvaluator{signature: makeSig}, + }, + globalModel: &mockEvaluator{}, + debug: true, + makeRoutedEvaluator: func(modelName string) (platform.PlatformEvaluator, error) { + return &mockEvaluator{signature: makeSig}, nil + }, + } + + reusedModelB := router.routingTable["modelB"] + + newConfig := &sharedrouter.RoutingConfig{ + 
EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelB"}, + {EntityID: 3, ModelName: "modelC"}, + }, + GlobalModelName: "global-new", + } + + if err := router.applyRouterConfig(ctx, newConfig); err != nil { + t.Fatalf("applyRouterConfig returned error: %v", err) + } + + waitForCalls(t, mockClient.unloadCh, 2) + + if router.routingConfig != newConfig { + t.Fatalf("routerConfig pointer not updated") + } + + expectedRouting := map[int]string{ + 1: "modelB", + 3: "modelC", + } + + if !reflect.DeepEqual(router.routingMap, expectedRouting) { + t.Fatalf("routingMap mismatch, got %#v", router.routingMap) + } + + if router.globalModel == nil { + t.Fatalf("globalModel was not set") + } + + if _, ok := router.routingTable["modelB"]; !ok { + t.Fatalf("routingTable missing modelB") + } + + if router.routingTable["modelB"] != reusedModelB { + t.Fatalf("modelB evaluator was not reused") + } + + if _, ok := router.routingTable["modelC"]; !ok { + t.Fatalf("routingTable missing modelC") + } + + if _, ok := router.routingTable["modelA"]; ok { + t.Fatalf("routingTable still contains modelA") + } +} + +func TestRouter_applyRouterConfig_LoadError(t *testing.T) { + ctx := context.Background() + + loadErr := errors.New("load failure") + + tritonServer := new(mockTritonServer) + tritonServer.modelLoadErr = map[string]error{ + "modelX": loadErr, + } + + mockClient := &mockUnloader{} + + oldConfig := &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelA"}, + }, + } + + signature := &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, + } + + router := &Router{ + debug: true, + routerName: "load_error", + + unloader: &triton.Service{Unloader: mockClient}, + routingConfig: oldConfig, + routingMap: map[int]string{ + 1: "modelA", + }, + + makeRoutedEvaluator: func(modelName string) 
(platform.PlatformEvaluator, error) { + return &mockEvaluator{ + tritonServer: tritonServer, + modelName: modelName, + signature: func() *domain.Signature { return signature }, + }, nil + }, + } + + newConfig := &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 2, ModelName: "modelX"}, + }, + } + + err := router.applyRouterConfig(ctx, newConfig) + if err == nil { + t.Fatalf("expected error but got nil") + } + + if !strings.Contains(err.Error(), "modelX") { + t.Fatalf("expected error mentioning modelX, got %v", err) + } + + if router.routingConfig != oldConfig { + t.Fatalf("routerConfig should remain unchanged on error") + } + + if !reflect.DeepEqual(router.routingMap, map[int]string{1: "modelA"}) { + t.Fatalf("routingMap should remain unchanged on error") + } + + if router.routingTable != nil { + t.Fatalf("routingTable should not be replaced on error") + } +} + +func TestRouter_applyRouterConfig_signature(t *testing.T) { + ctx := context.Background() + mockClient := &mockUnloader{ + unloadCh: make(chan string, 1), + tritonServer: &mockTritonServer{}, + } + + makeSig := func() *domain.Signature { + return &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, + } + } + + cfg := &config.Model{ + ID: "test_signature", + Debug: true, + Mode: "router", + Platform: "triton", + Router: &config.RouterConfig{ + ConfigURL: "memory://router-config", + InputName: "router_id", + Global: config.GlobalModelConfig{ + Exists: true, + }, + Output: config.OutputConfig{ + FieldName: "model_id", + }, + }, + Triton: &config.TritonConfig{ + ServerID: "test_server", + }, + } + + cfg.Init(nil) + + var router *Router + var err error + + router, err = newRouter(cfg, nil, map[string]UnloadService{ + "test_server": &triton.Service{Unloader: mockClient}, + }, func(modelName string) (platform.PlatformEvaluator, error) { + return 
&mockEvaluator{signature: makeSig}, nil + }) + + if err != nil { + t.Fatalf("NewRouter error: %v", err) + } + + newConfig := &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "model1"}, + {EntityID: 2, ModelName: "model2"}, + }, + GlobalModelName: "global-model", + } + + if err := router.applyRouterConfig(ctx, newConfig); err != nil { + t.Fatalf("applyRouterConfig returned error: %v", err) + } + + detectedInputs := map[string]struct{}{} + for _, input := range router.ioState.signature.Inputs { + detectedInputs[input.Name] = struct{}{} + } + + expectedInputs := map[string]struct{}{ + "text": {}, + "router_id": {}, + } + + assert.Equal(t, expectedInputs, detectedInputs) + assert.Equal(t, 1, router.ioState.routerInputOffset) + + params := []interface{}{ + [][]string{{"a"}, {"abcd"}}, // text + [][]int64{{1}, {2}}, // router_id + } + + results, err := router.Predict(ctx, params) + if err != nil { + t.Fatalf("Predict error: %v", err) + } + + assert.Equal(t, 2, len(results)) +} + +type wrappedUnloader struct { + tritonService *triton.Service + + wg *sync.WaitGroup +} + +func (w *wrappedUnloader) UnloadModel(ctx context.Context, mlyModelID string, tritonModelName string) error { + defer w.wg.Done() + return w.tritonService.UnloadModel(ctx, mlyModelID, tritonModelName) +} + +func TestRouter_applyRouterConfig_sharedTritonServer(t *testing.T) { + tritonServer := &mockTritonServer{} + modelUnloader := &mockUnloader{ + tritonServer: tritonServer, + } + + repository := triton.NewRepository() + tritonService := &triton.Service{ + Unloader: modelUnloader, + Repository: repository, + } + + wrappedService := &wrappedUnloader{ + tritonService: tritonService, + wg: &sync.WaitGroup{}, + } + + unloaders := map[string]UnloadService{ + "test_server": wrappedService, + } + + cfgA := &config.Model{ + ID: "test_shared_a", + Debug: true, + Mode: "router", + Platform: "triton", + Router: &config.RouterConfig{ + ConfigURL: 
"memory://router-config", + Global: config.GlobalModelConfig{ + PredictionReplacements: []config.PredictionReplacement{ + { + Name: "score", + Type: "float32", + Value: 0.0, + }, + }, + }, + }, + Triton: &config.TritonConfig{ + ServerID: "test_server", + }, + } + + cfgB := &config.Model{ + ID: "test_shared_b", + Debug: true, + Mode: "router", + Platform: "triton", + Router: &config.RouterConfig{ + ConfigURL: "memory://router-config", + Global: config.GlobalModelConfig{ + PredictionReplacements: []config.PredictionReplacement{ + { + Name: "score", + Type: "float32", + Value: 0.0, + }, + }, + }, + }, + Triton: &config.TritonConfig{ + ServerID: "test_server", + }, + } + + cfgA.Init(nil) + cfgB.Init(nil) + + newSig := func() *domain.Signature { + return &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, + } + } + + routerA, err := newRouter(cfgA, nil, unloaders, func(modelName string) (platform.PlatformEvaluator, error) { + tritonService.RegisterUsage(cfgA.ID, modelName) + return &mockEvaluator{modelName: modelName, signature: newSig, tritonServer: tritonServer}, nil + }) + + if err != nil { + t.Fatalf("newRouter returned error: %v", err) + } + + routerB, err := newRouter(cfgB, nil, unloaders, func(modelName string) (platform.PlatformEvaluator, error) { + tritonService.RegisterUsage(cfgB.ID, modelName) + return &mockEvaluator{modelName: modelName, signature: newSig, tritonServer: tritonServer}, nil + }) + + if err != nil { + t.Fatalf("newRouter returned error: %v", err) + } + + ctx := context.Background() + + // establish mappings for A and B using the same models + + if err = routerA.applyRouterConfig(ctx, &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelA"}, + {EntityID: 2, ModelName: "modelC"}, + }, + }); err != nil { + t.Fatalf("applyRouterConfig A initial returned error: %v", 
err) + } + + if err = routerB.applyRouterConfig(ctx, &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelA"}, + }, + }); err != nil { + t.Fatalf("applyRouterConfig B inital returned error: %v", err) + } + + // we expect model A and model C to be attempted to be unloaded + wrappedService.wg.Add(2) + + // routerA will now get a different mapping + if err = routerA.applyRouterConfig(ctx, &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelB"}, + }, + }); err != nil { + t.Fatalf("applyRouterConfig A reload returned error: %v", err) + } + + wrappedService.wg.Wait() + + assert.True(t, tritonServer.readyState["modelA"], "modelA should still be loaded") + assert.False(t, tritonServer.readyState["modelC"], "modelC should be unloaded") + + // we expect model A to be attempted to be unloaded + wrappedService.wg.Add(1) + + if err = routerB.applyRouterConfig(ctx, &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelB"}, + {EntityID: 2, ModelName: "modelC"}, + }, + }); err != nil { + t.Fatalf("applyRouterConfig B reload returned error: %v", err) + } + + wrappedService.wg.Wait() + + assert.False(t, tritonServer.readyState["modelA"], "modelA should be unloaded") + assert.True(t, tritonServer.readyState["modelB"], "modelB should still be loaded") + assert.True(t, tritonServer.readyState["modelC"], "modelC should be loaded") +} diff --git a/service/platform/router/router.go b/service/platform/router/router.go index bcb9a6f..1bae49b 100644 --- a/service/platform/router/router.go +++ b/service/platform/router/router.go @@ -22,7 +22,7 @@ type IOState struct { inputs map[string]*domain.Input signature *domain.Signature - // router input offset is the index of the router input in the inputs array + // router input offset is the index of the routing input in the inputs array routerInputOffset int } @@ -62,7 +62,7 @@ type Router struct { 
globalModel platform.PlatformEvaluator // fixedEvaluator is non-nil IFF there is no global model configured - fixedEvaluator platform.Predictor + fixedEvaluator *fixedEvaluator // fixedEvaluatorFields is for checking all outputs in the signature are replaced fixedEvaluatorFields map[string]struct{} @@ -157,38 +157,36 @@ func newRouter(cfg *config.Model, fs afs.Service, unloaders map[string]UnloadSer return r, nil } -type preparedReplacement struct { - typ string - value interface{} -} - // modelBatch holds accumulated rows destined for a single model evaluator type modelBatch struct { - evaluator platform.Predictor - modelName string - inputs []interface{} // [numInputs][]interface{} - accumulated batched inputs - rowOffsets []int // original positions in the incoming batch + evaluator platform.PlatformEvaluator // need Signature() for input reordering + isFixedEval bool // true if using fixedEvaluator (no reordering needed) + modelName string + inputsByName map[string]interface{} // keyed by input name - accumulated batched inputs + rowOffsets []int // original positions in the incoming batch } // batchResult holds the result from a batched model prediction type batchResult struct { - modelName string - results []interface{} - offsets []int - err error + modelName string + results []interface{} + offsets []int + err error + outputNames []string // output names in the order returned by evaluator (for reordering) } // Predict performs model inference with the given parameters. // params is expected to be [numInputs]([batchSize][1]T) (see service/request.Request.Feeds). // // Rows are grouped into batches based on their target model evaluator. -// When ForceBatchSize1 is true, each row forms its own batch (batch size 1). -// When ForceBatchSize1 is false (default), rows destined for the same model are grouped together. +// When forceBatchSize1 is true, each row forms its own batch (batch size 1). 
+// When forceBatchSize1 is false (default), rows destined for the same model are batched together. func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { if len(params) == 0 { return nil, fmt.Errorf("no input parameters provided") } + // metricFixedOnly is true if the request is only using the fixedEvaluator metricFixedOnly := true start := time.Now() defer func() { @@ -213,9 +211,7 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface r.routingTableLock.RLock() defer r.routingTableLock.RUnlock() - // Phase 1: Group rows into batches - // When forceBatchSize1 is true, each row gets a unique batch key (row offset as string) - // When false, rows are grouped by model name + // Phase 1: Group rows into batches by name err = func() error { if r.ioState == nil { return fmt.Errorf("ioState was not initialized") @@ -224,12 +220,11 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface signature = r.ioState.signature routerInputOffset := r.ioState.routerInputOffset - globalExists := r.fixedEvaluator != nil + hasFixedEvaluator := r.fixedEvaluator != nil reportedGlobalModelName := r.outputConfig.GlobalModelOverride noModelName := r.outputConfig.NoModelID numInputs := len(params) - numModelInputs := numInputs - 1 // exclude router input for batchOffset := range expectedBatchSize { // Extract routing value for this row @@ -257,17 +252,19 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface routingValueString, ok := r.routingMap[routingValueInt] - var evaluator platform.Predictor + var evaluator platform.PlatformEvaluator + isFixedEval := false if !ok { - if globalExists { + if hasFixedEvaluator { + // No global model, use fixed evaluator (returns constant values) + routingValueString = noModelName + isFixedEval = true + } else { metricFixedOnly = false evaluator = r.globalModel if reportedGlobalModelName != "" { routingValueString = 
reportedGlobalModelName } - } else { - routingValueString = noModelName - evaluator = r.fixedEvaluator } } else { metricFixedOnly = false @@ -287,31 +284,33 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface batch, exists := batches[batchKey] if !exists { batch = &modelBatch{ - evaluator: evaluator, - modelName: routingValueString, - inputs: make([]interface{}, numModelInputs), - rowOffsets: make([]int, 0, 1), + evaluator: evaluator, + isFixedEval: isFixedEval, + modelName: routingValueString, + inputsByName: make(map[string]interface{}), + rowOffsets: make([]int, 0, 1), } batches[batchKey] = batch } // Append this row's inputs to the batch (excluding router input) - inputIdx := 0 + // Use signature to resolve input names for name-based accumulation for paramOffset := range numInputs { if paramOffset == routerInputOffset { continue } + inputName := signature.Inputs[paramOffset].Name + debatched, err := shape.Debatch(params[paramOffset], batchOffset) if err != nil { - return fmt.Errorf("failed to debatch for row %d, input %d: %w", batchOffset, paramOffset, err) + return fmt.Errorf("failed to debatch for row %d, input %s: %w", batchOffset, inputName, err) } - batch.inputs[inputIdx], err = shape.AppendRowToBatch(batch.inputs[inputIdx], debatched) + batch.inputsByName[inputName], err = shape.AppendRowToBatch(batch.inputsByName[inputName], debatched) if err != nil { - return fmt.Errorf("failed to append row %d to batch for input %d: %w", batchOffset, paramOffset, err) + return fmt.Errorf("failed to append row %d to batch for input %s: %w", batchOffset, inputName, err) } - inputIdx++ } batch.rowOffsets = append(batch.rowOffsets, batchOffset) @@ -344,8 +343,45 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface semaphore <- struct{}{} defer func() { <-semaphore }() - // Rely on downstream for timeouts - results, err := b.evaluator.Predict(ctx, b.inputs) + // Reorder inputs to match each evaluator's expected 
order before calling Predict + var results []interface{} + var err error + var outputNames []string + + if b.isFixedEval { + // Fixed evaluator returns constant values + // Pass any accumulated input so it can determine batch size + var fixedInputs []interface{} + for _, inputData := range b.inputsByName { + fixedInputs = append(fixedInputs, inputData) + break // only need one input for batch size + } + results, err = r.fixedEvaluator.Predict(ctx, fixedInputs) + outputNames = r.fixedEvaluator.OutputNames() + } else { + // Reorder inputs to match this evaluator's expected order + evalSig := b.evaluator.Signature() + orderedInputs := make([]interface{}, len(evalSig.Inputs)) + for i, sigInput := range evalSig.Inputs { + inputData, exists := b.inputsByName[sigInput.Name] + if !exists { + err = fmt.Errorf("input %s not found in batch for model %s", sigInput.Name, b.modelName) + break + } + orderedInputs[i] = inputData + } + + if err == nil { + // Rely on downstream for timeouts + results, err = b.evaluator.Predict(ctx, orderedInputs) + + // Capture output names for reordering in Phase 3 + outputNames = make([]string, len(evalSig.Outputs)) + for i, out := range evalSig.Outputs { + outputNames[i] = out.Name + } + } + } // Append model name to results if configured if r.modelOutputName != "" && err == nil { @@ -354,13 +390,15 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface modelNames[i] = []string{b.modelName} } results = append(results, modelNames) + outputNames = append(outputNames, r.modelOutputName) } resultCh <- batchResult{ - modelName: b.modelName, - results: results, - offsets: b.rowOffsets, - err: err, + modelName: b.modelName, + results: results, + offsets: b.rowOffsets, + err: err, + outputNames: outputNames, } }(batch) } @@ -369,6 +407,12 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface close(resultCh) // Phase 3: Reassemble results in original order + // Build router output name -> index 
mapping for reordering + routerOutputIndex := make(map[string]int, len(signature.Outputs)) + for i, out := range signature.Outputs { + routerOutputIndex[out.Name] = i + } + allResults := make([][]interface{}, expectedBatchSize) for res := range resultCh { @@ -377,15 +421,32 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface } // Extract individual rows from the batched result and place at original offsets + // Reorder outputs to match router's expected output order for localIdx, originalOffset := range res.offsets { - rowResult := make([]interface{}, len(res.results)) - for outputIdx, outputBatch := range res.results { + rowResult := make([]interface{}, len(signature.Outputs)) + for evalOutputIdx, outputBatch := range res.results { extracted, err := shape.ExtractRowFromBatch(outputBatch, localIdx) if err != nil { return nil, fmt.Errorf("failed to extract row %d from model %s output %d: %w", - localIdx, res.modelName, outputIdx, err) + localIdx, res.modelName, evalOutputIdx, err) } - rowResult[outputIdx] = extracted + + // Map evaluator output index to router output index by name + var routerIdx int + if res.outputNames == nil { + // Fallback: assume same order (shouldn't happen in normal operation) + routerIdx = evalOutputIdx + } else { + outputName := res.outputNames[evalOutputIdx] + var exists bool + routerIdx, exists = routerOutputIndex[outputName] + if !exists { + return nil, fmt.Errorf("output %s from model %s not found in router signature", + outputName, res.modelName) + } + } + + rowResult[routerIdx] = extracted } allResults[originalOffset] = rowResult } diff --git a/service/platform/router/router_test.go b/service/platform/router/router_test.go index e9bd751..c2ecbf5 100644 --- a/service/platform/router/router_test.go +++ b/service/platform/router/router_test.go @@ -16,7 +16,6 @@ import ( "github.com/viant/mly/service/platform" "github.com/viant/mly/service/triton" "github.com/viant/mly/shared/common" - sharedrouter 
"github.com/viant/mly/shared/config/router" ) // --- Router Predict scaffolds --- @@ -137,7 +136,6 @@ func TestRouter_Predict(t *testing.T) { { name: "with global model", routerConfig: &config.RouterConfig{ - ConfigURL: "memory://router-config", InputName: "router_id", Global: config.GlobalModelConfig{ Exists: true, // avoid fixed replacements path @@ -160,7 +158,6 @@ func TestRouter_Predict(t *testing.T) { { name: "without global model", routerConfig: &config.RouterConfig{ - ConfigURL: "memory://router-config", InputName: "router_id", Global: config.GlobalModelConfig{ PredictionReplacements: []config.PredictionReplacement{ @@ -189,7 +186,6 @@ func TestRouter_Predict(t *testing.T) { { name: "with model output name", routerConfig: &config.RouterConfig{ - ConfigURL: "memory://router-config", InputName: "router_id", Global: config.GlobalModelConfig{ Exists: true, @@ -243,6 +239,10 @@ func TestRouter_Predict(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + if test.routerConfig.ConfigURL == "" { + test.routerConfig.ConfigURL = "memory://router-config" + } + cfg := &config.Model{ ID: "router_test", Mode: "router", @@ -320,433 +320,10 @@ func TestRouter_Predict(t *testing.T) { } } -func TestRouter_applyRouterConfig_LoadsAndSwaps(t *testing.T) { - ctx := context.Background() - mockClient := &mockUnloader{ - unloadCh: make(chan string, 2), - } - - oldConfig := &sharedrouter.RoutingConfig{ - EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 1, ModelName: "modelA"}, - {EntityID: 2, ModelName: "modelB"}, - }, - GlobalModelName: "global-old", - } - - makeSig := func() *domain.Signature { - return &domain.Signature{ - Inputs: []domain.Input{ - {Name: "text", Index: 0, Type: reflect.TypeOf("")}, - }, - Outputs: []domain.Output{ - {Name: "score", Index: 0, DataType: "float32"}, - }, - } - } - - router := &Router{ - unloader: &triton.Service{Unloader: mockClient, Repository: triton.NewRepository()}, - routingConfig: oldConfig, - routingMap: 
map[int]string{ - 1: "modelA", - 2: "modelB", - }, - routingTable: map[string]platform.PlatformEvaluator{ - "modelA": &mockEvaluator{signature: makeSig}, - "modelB": &mockEvaluator{signature: makeSig}, - }, - globalModel: &mockEvaluator{}, - debug: true, - makeRoutedEvaluator: func(modelName string) (platform.PlatformEvaluator, error) { - return &mockEvaluator{signature: makeSig}, nil - }, - } - - reusedModelB := router.routingTable["modelB"] - - newConfig := &sharedrouter.RoutingConfig{ - EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 1, ModelName: "modelB"}, - {EntityID: 3, ModelName: "modelC"}, - }, - GlobalModelName: "global-new", - } - - if err := router.applyRouterConfig(ctx, newConfig); err != nil { - t.Fatalf("applyRouterConfig returned error: %v", err) - } - - waitForCalls(t, mockClient.unloadCh, 2) - - if router.routingConfig != newConfig { - t.Fatalf("routerConfig pointer not updated") - } - - expectedRouting := map[int]string{ - 1: "modelB", - 3: "modelC", - } - - if !reflect.DeepEqual(router.routingMap, expectedRouting) { - t.Fatalf("routingMap mismatch, got %#v", router.routingMap) - } - - if router.globalModel == nil { - t.Fatalf("globalModel was not set") - } - - if _, ok := router.routingTable["modelB"]; !ok { - t.Fatalf("routingTable missing modelB") - } - - if router.routingTable["modelB"] != reusedModelB { - t.Fatalf("modelB evaluator was not reused") - } - - if _, ok := router.routingTable["modelC"]; !ok { - t.Fatalf("routingTable missing modelC") - } - - if _, ok := router.routingTable["modelA"]; ok { - t.Fatalf("routingTable still contains modelA") - } -} - -func TestRouter_applyRouterConfig_LoadError(t *testing.T) { - ctx := context.Background() - - loadErr := errors.New("load failure") - - tritonServer := new(mockTritonServer) - tritonServer.modelLoadErr = map[string]error{ - "modelX": loadErr, - } - - mockClient := &mockUnloader{} - - oldConfig := &sharedrouter.RoutingConfig{ - EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 1, 
ModelName: "modelA"}, - }, - } - - signature := &domain.Signature{ - Inputs: []domain.Input{ - {Name: "text", Index: 0, Type: reflect.TypeOf("")}, - }, - Outputs: []domain.Output{ - {Name: "score", Index: 0, DataType: "float32"}, - }, - } - - router := &Router{ - debug: true, - routerName: "load_error", - - unloader: &triton.Service{Unloader: mockClient}, - routingConfig: oldConfig, - routingMap: map[int]string{ - 1: "modelA", - }, - - makeRoutedEvaluator: func(modelName string) (platform.PlatformEvaluator, error) { - return &mockEvaluator{ - tritonServer: tritonServer, - modelName: modelName, - signature: func() *domain.Signature { return signature }, - }, nil - }, - } - - newConfig := &sharedrouter.RoutingConfig{ - EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 2, ModelName: "modelX"}, - }, - } - - err := router.applyRouterConfig(ctx, newConfig) - if err == nil { - t.Fatalf("expected error but got nil") - } - - if !strings.Contains(err.Error(), "modelX") { - t.Fatalf("expected error mentioning modelX, got %v", err) - } - - if router.routingConfig != oldConfig { - t.Fatalf("routerConfig should remain unchanged on error") - } - - if !reflect.DeepEqual(router.routingMap, map[int]string{1: "modelA"}) { - t.Fatalf("routingMap should remain unchanged on error") - } - - if router.routingTable != nil { - t.Fatalf("routingTable should not be replaced on error") - } -} - -func TestRouter_applyRouterConfig_signature(t *testing.T) { - ctx := context.Background() - mockClient := &mockUnloader{ - unloadCh: make(chan string, 1), - tritonServer: &mockTritonServer{}, - } - - makeSig := func() *domain.Signature { - return &domain.Signature{ - Inputs: []domain.Input{ - {Name: "text", Index: 0, Type: reflect.TypeOf("")}, - }, - Outputs: []domain.Output{ - {Name: "score", Index: 0, DataType: "float32"}, - }, - } - } - - cfg := &config.Model{ - ID: "test_signature", - Debug: true, - Mode: "router", - Platform: "triton", - Router: &config.RouterConfig{ - ConfigURL: 
"memory://router-config", - InputName: "router_id", - Global: config.GlobalModelConfig{ - Exists: true, - }, - Output: config.OutputConfig{ - FieldName: "model_id", - }, - }, - Triton: &config.TritonConfig{ - ServerID: "test_server", - }, - } - - cfg.Init(nil) - - var router *Router - var err error - - router, err = newRouter(cfg, nil, map[string]UnloadService{ - "test_server": &triton.Service{Unloader: mockClient}, - }, func(modelName string) (platform.PlatformEvaluator, error) { - return &mockEvaluator{signature: makeSig}, nil - }) - - if err != nil { - t.Fatalf("NewRouter error: %v", err) - } - - newConfig := &sharedrouter.RoutingConfig{ - EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 1, ModelName: "model1"}, - {EntityID: 2, ModelName: "model2"}, - }, - GlobalModelName: "global-model", - } - - if err := router.applyRouterConfig(ctx, newConfig); err != nil { - t.Fatalf("applyRouterConfig returned error: %v", err) - } - - detectedInputs := map[string]struct{}{} - for _, input := range router.ioState.signature.Inputs { - detectedInputs[input.Name] = struct{}{} - } - - expectedInputs := map[string]struct{}{ - "text": {}, - "router_id": {}, - } - - assert.Equal(t, expectedInputs, detectedInputs) - assert.Equal(t, 1, router.ioState.routerInputOffset) - - params := []interface{}{ - [][]string{{"a"}, {"abcd"}}, // text - [][]int64{{1}, {2}}, // router_id - } - - results, err := router.Predict(ctx, params) - if err != nil { - t.Fatalf("Predict error: %v", err) - } - - assert.Equal(t, 2, len(results)) -} - -type wrappedUnloader struct { - tritonService *triton.Service - - wg *sync.WaitGroup -} - -func (w *wrappedUnloader) UnloadModel(ctx context.Context, mlyModelID string, tritonModelName string) error { - defer w.wg.Done() - return w.tritonService.UnloadModel(ctx, mlyModelID, tritonModelName) -} - -func TestRouter_applyRouterConfig_sharedTritonServer(t *testing.T) { - tritonServer := &mockTritonServer{} - modelUnloader := &mockUnloader{ - tritonServer: 
tritonServer, - } - - repository := triton.NewRepository() - tritonService := &triton.Service{ - Unloader: modelUnloader, - Repository: repository, - } - - wrappedService := &wrappedUnloader{ - tritonService: tritonService, - wg: &sync.WaitGroup{}, - } - - unloaders := map[string]UnloadService{ - "test_server": wrappedService, - } - - cfgA := &config.Model{ - ID: "test_shared_a", - Debug: true, - Mode: "router", - Platform: "triton", - Router: &config.RouterConfig{ - ConfigURL: "memory://router-config", - Global: config.GlobalModelConfig{ - PredictionReplacements: []config.PredictionReplacement{ - { - Name: "score", - Type: "float32", - Value: 0.0, - }, - }, - }, - }, - Triton: &config.TritonConfig{ - ServerID: "test_server", - }, - } - - cfgB := &config.Model{ - ID: "test_shared_b", - Debug: true, - Mode: "router", - Platform: "triton", - Router: &config.RouterConfig{ - ConfigURL: "memory://router-config", - Global: config.GlobalModelConfig{ - PredictionReplacements: []config.PredictionReplacement{ - { - Name: "score", - Type: "float32", - Value: 0.0, - }, - }, - }, - }, - Triton: &config.TritonConfig{ - ServerID: "test_server", - }, - } - - cfgA.Init(nil) - cfgB.Init(nil) - - newSig := func() *domain.Signature { - return &domain.Signature{ - Inputs: []domain.Input{ - {Name: "text", Index: 0, Type: reflect.TypeOf("")}, - }, - Outputs: []domain.Output{ - {Name: "score", Index: 0, DataType: "float32"}, - }, - } - } - - routerA, err := newRouter(cfgA, nil, unloaders, func(modelName string) (platform.PlatformEvaluator, error) { - tritonService.RegisterUsage(cfgA.ID, modelName) - return &mockEvaluator{modelName: modelName, signature: newSig, tritonServer: tritonServer}, nil - }) - - if err != nil { - t.Fatalf("newRouter returned error: %v", err) - } - - routerB, err := newRouter(cfgB, nil, unloaders, func(modelName string) (platform.PlatformEvaluator, error) { - tritonService.RegisterUsage(cfgB.ID, modelName) - return &mockEvaluator{modelName: modelName, signature: 
newSig, tritonServer: tritonServer}, nil - }) - - if err != nil { - t.Fatalf("newRouter returned error: %v", err) - } - - ctx := context.Background() - - // establish mappings for A and B using the same models - - if err = routerA.applyRouterConfig(ctx, &sharedrouter.RoutingConfig{ - EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 1, ModelName: "modelA"}, - {EntityID: 2, ModelName: "modelC"}, - }, - }); err != nil { - t.Fatalf("applyRouterConfig A initial returned error: %v", err) - } - - if err = routerB.applyRouterConfig(ctx, &sharedrouter.RoutingConfig{ - EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 1, ModelName: "modelA"}, - }, - }); err != nil { - t.Fatalf("applyRouterConfig B inital returned error: %v", err) - } - - // we expect model A and model C to be attempted to be unloaded - wrappedService.wg.Add(2) - - // routerA will now get a different mapping - if err = routerA.applyRouterConfig(ctx, &sharedrouter.RoutingConfig{ - EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 1, ModelName: "modelB"}, - }, - }); err != nil { - t.Fatalf("applyRouterConfig A reload returned error: %v", err) - } - - wrappedService.wg.Wait() - - assert.True(t, tritonServer.readyState["modelA"], "modelA should still be loaded") - assert.False(t, tritonServer.readyState["modelC"], "modelC should be unloaded") - - // we expect model A to be attempted to be unloaded - wrappedService.wg.Add(1) - - if err = routerB.applyRouterConfig(ctx, &sharedrouter.RoutingConfig{ - EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 1, ModelName: "modelB"}, - {EntityID: 2, ModelName: "modelC"}, - }, - }); err != nil { - t.Fatalf("applyRouterConfig B reload returned error: %v", err) - } - - wrappedService.wg.Wait() - - assert.False(t, tritonServer.readyState["modelA"], "modelA should be unloaded") - assert.True(t, tritonServer.readyState["modelB"], "modelB should still be loaded") - assert.True(t, tritonServer.readyState["modelC"], "modelC should be loaded") -} - // --- Batched 
Prediction Tests --- -// testMockEvaluator is a configurable mock that handles batched inputs and tracks calls -type testMockEvaluator struct { +// countingMockEvaluator is a configurable mock that handles batched inputs and tracks calls +type countingMockEvaluator struct { modelName string signature *domain.Signature predictCalls int @@ -754,11 +331,11 @@ type testMockEvaluator struct { err error // if set, Predict returns this error } -func newTestMockEvaluator(name string, sig *domain.Signature) *testMockEvaluator { - return &testMockEvaluator{modelName: name, signature: sig} +func newTestMockEvaluator(name string, sig *domain.Signature) *countingMockEvaluator { + return &countingMockEvaluator{modelName: name, signature: sig} } -func (m *testMockEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { +func (m *countingMockEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { m.mu.Lock() m.predictCalls++ m.mu.Unlock() @@ -801,16 +378,16 @@ func (m *testMockEvaluator) Predict(ctx context.Context, params []interface{}) ( return []interface{}{results}, nil } -func (m *testMockEvaluator) Signature() *domain.Signature { return m.signature } -func (m *testMockEvaluator) Dictionary() *common.Dictionary { return nil } -func (m *testMockEvaluator) Inputs() map[string]*domain.Input { return nil } -func (m *testMockEvaluator) Stats(map[string]interface{}) {} -func (m *testMockEvaluator) Close() error { return nil } -func (m *testMockEvaluator) ReloadIfNeeded(ctx context.Context) error { +func (m *countingMockEvaluator) Signature() *domain.Signature { return m.signature } +func (m *countingMockEvaluator) Dictionary() *common.Dictionary { return nil } +func (m *countingMockEvaluator) Inputs() map[string]*domain.Input { return nil } +func (m *countingMockEvaluator) Stats(map[string]interface{}) {} +func (m *countingMockEvaluator) Close() error { return nil } +func (m *countingMockEvaluator) ReloadIfNeeded(ctx 
context.Context) error { return nil } -func (m *testMockEvaluator) GetPredictCalls() int { +func (m *countingMockEvaluator) getPredictCalls() int { m.mu.Lock() defer m.mu.Unlock() return m.predictCalls @@ -819,7 +396,7 @@ func (m *testMockEvaluator) GetPredictCalls() int { // predictTestSetup holds the common test setup for prediction tests type predictTestSetup struct { router *Router - evaluators map[string]*testMockEvaluator + evaluators map[string]*countingMockEvaluator signature *domain.Signature } @@ -849,7 +426,7 @@ func setupPredictTest(t *testing.T, opts predictTestOptions) *predictTestSetup { } // Create evaluators - evaluators := make(map[string]*testMockEvaluator) + evaluators := make(map[string]*countingMockEvaluator) for _, name := range opts.modelNames { evaluators[name] = newTestMockEvaluator(name, downstreamSig) } @@ -1011,7 +588,7 @@ func TestRouter_Predict_BatchingBehavior(t *testing.T) { // Verify call counts for modelName, wantCount := range tt.wantCallCounts { if eval, ok := setup.evaluators[modelName]; ok { - gotCount := eval.GetPredictCalls() + gotCount := eval.getPredictCalls() if gotCount != wantCount { t.Errorf("%s call count: got %d, want %d", modelName, gotCount, wantCount) } @@ -1079,3 +656,481 @@ func TestRouter_Predict_ErrorPropagation(t *testing.T) { t.Errorf("expected error to contain 'model prediction failed', got: %v", err) } } + +// orderVerifyingEvaluator verifies inputs arrive in expected order and computes output +type orderVerifyingEvaluator struct { + t *testing.T + modelName string + signature *domain.Signature + expectedOrder []string // expected input names in order +} + +func (e *orderVerifyingEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { + if len(params) != len(e.expectedOrder) { + return nil, fmt.Errorf("expected %d inputs, got %d", len(e.expectedOrder), len(params)) + } + + // Verify we received inputs in the expected order by checking types match signature + for i, param := 
range params { + expectedName := e.expectedOrder[i] + + // Verify the type is a slice (basic sanity check) + paramType := reflect.TypeOf(param) + if paramType.Kind() != reflect.Slice { + return nil, fmt.Errorf("input %d (%s): expected slice, got %v", i, expectedName, paramType) + } + } + + // Compute output: concatenate first input values (assumes string inputs) + var batchSize int + switch typed := params[0].(type) { + case [][]string: + batchSize = len(typed) + case [][]int64: + batchSize = len(typed) + default: + return nil, fmt.Errorf("unexpected first input type: %T", params[0]) + } + + // Output: sum of string lengths from input_a + input_b + results := make([][]float32, batchSize) + for i := 0; i < batchSize; i++ { + var sum int + for _, param := range params { + if typed, ok := param.([][]string); ok { + sum += len(typed[i][0]) + } + } + results[i] = []float32{float32(sum)} + } + + return []interface{}{results}, nil +} + +func (e *orderVerifyingEvaluator) Signature() *domain.Signature { return e.signature } +func (e *orderVerifyingEvaluator) Dictionary() *common.Dictionary { return nil } +func (e *orderVerifyingEvaluator) Inputs() map[string]*domain.Input { return nil } +func (e *orderVerifyingEvaluator) Stats(map[string]interface{}) {} +func (e *orderVerifyingEvaluator) Close() error { return nil } +func (e *orderVerifyingEvaluator) ReloadIfNeeded(ctx context.Context) error { return nil } + +// TestRouter_Predict_DifferentInputOrdering verifies that models with different +// input orderings receive their inputs in the correct order +func TestRouter_Predict_DifferentInputOrdering(t *testing.T) { + // Model1 expects: [input_a, input_b] (indices 0, 1) + // Model2 expects: [input_b, input_a] (indices 0, 1) - REVERSED ORDER + model1Sig := &domain.Signature{ + Inputs: []domain.Input{ + {Name: "input_a", Index: 0, Type: reflect.TypeOf("")}, + {Name: "input_b", Index: 1, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score", Index: 0, 
DataType: "float32"}, + }, + } + + model2Sig := &domain.Signature{ + Inputs: []domain.Input{ + {Name: "input_b", Index: 0, Type: reflect.TypeOf("")}, // REVERSED + {Name: "input_a", Index: 1, Type: reflect.TypeOf("")}, // REVERSED + }, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, + } + + model1Eval := &orderVerifyingEvaluator{ + t: t, + modelName: "model1", + signature: model1Sig, + expectedOrder: []string{"input_a", "input_b"}, + } + + model2Eval := &orderVerifyingEvaluator{ + t: t, + modelName: "model2", + signature: model2Sig, + expectedOrder: []string{"input_b", "input_a"}, // expects reversed order + } + + cfg := &config.Model{ + ID: "router_ordering_test", + Mode: "router", + Platform: "triton", + Router: &config.RouterConfig{ + ConfigURL: "memory://router-config", + InputName: "router_id", + Global: config.GlobalModelConfig{Exists: true}, + }, + Triton: &config.TritonConfig{ServerID: "test_server"}, + } + cfg.Init(nil) + cfg.Router.MaxQueueSize = 1000 + cfg.Router.Workers = 10 + + router, err := newRouter(cfg, nil, map[string]UnloadService{ + "test_server": &triton.Service{}, + }, func(modelName string) (platform.PlatformEvaluator, error) { + if modelName == "model1" { + return model1Eval, nil + } + return model2Eval, nil + }) + if err != nil { + t.Fatalf("newRouter error: %v", err) + } + + router.routingMap = map[int]string{ + 1: "model1", + 2: "model2", + } + router.routingTable = map[string]platform.PlatformEvaluator{ + "model1": model1Eval, + "model2": model2Eval, + } + + // Router's signature: [input_a, input_b, router_id] + // This is the order the router receives inputs from the request + routerSig := &domain.Signature{ + Inputs: []domain.Input{ + {Name: "input_a", Index: 0, Type: reflect.TypeOf("")}, + {Name: "input_b", Index: 1, Type: reflect.TypeOf("")}, + {Name: "router_id", Index: 2, Type: reflect.TypeOf(int64(0))}, + }, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, + } + 
+ router.ioState = &IOState{ + inputs: map[string]*domain.Input{ + "input_a": &routerSig.Inputs[0], + "input_b": &routerSig.Inputs[1], + "router_id": &routerSig.Inputs[2], + }, + signature: routerSig, + routerInputOffset: 2, // router_id is at index 2 + } + + // Input data: + // Row 0: input_a="aa", input_b="bbbb", router_id=1 (-> model1) + // Row 1: input_a="ccc", input_b="dd", router_id=2 (-> model2) + // Row 2: input_a="e", input_b="ffffff", router_id=1 (-> model1) + // Row 3: input_a="gggg", input_b="h", router_id=2 (-> model2) + params := []interface{}{ + [][]string{{"aa"}, {"ccc"}, {"e"}, {"gggg"}}, // input_a + [][]string{{"bbbb"}, {"dd"}, {"ffffff"}, {"h"}}, // input_b + [][]int64{{1}, {2}, {1}, {2}}, // router_id + } + + results, err := router.Predict(context.Background(), params) + if err != nil { + t.Fatalf("Predict error: %v", err) + } + + // Verify results + // Row 0: len("aa") + len("bbbb") = 2 + 4 = 6 + // Row 1: len("ccc") + len("dd") = 3 + 2 = 5 + // Row 2: len("e") + len("ffffff") = 1 + 6 = 7 + // Row 3: len("gggg") + len("h") = 4 + 1 = 5 + scores, ok := results[0].([][]float32) + if !ok { + t.Fatalf("expected [][]float32, got %T", results[0]) + } + + wantScores := [][]float32{{6}, {5}, {7}, {5}} + if !reflect.DeepEqual(scores, wantScores) { + t.Errorf("scores mismatch:\n got: %v\n want: %v", scores, wantScores) + } +} + +// multiOutputEvaluator returns multiple outputs in the order specified by its signature +type multiOutputEvaluator struct { + modelName string + signature *domain.Signature +} + +func (e *multiOutputEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { + // Get batch size from first input + var batchSize int + switch typed := params[0].(type) { + case [][]string: + batchSize = len(typed) + default: + return nil, fmt.Errorf("unexpected input type: %T", params[0]) + } + + // Return outputs in the order defined by this evaluator's signature + // For each row: score_a = len(input), score_b = 
len(input) * 2 + results := make([]interface{}, len(e.signature.Outputs)) + for outIdx, outDef := range e.signature.Outputs { + outputData := make([][]float32, batchSize) + for i := 0; i < batchSize; i++ { + inputStr := params[0].([][]string)[i][0] + var value float32 + switch outDef.Name { + case "score_a": + value = float32(len(inputStr)) + case "score_b": + value = float32(len(inputStr) * 2) + } + outputData[i] = []float32{value} + } + results[outIdx] = outputData + } + + return results, nil +} + +func (e *multiOutputEvaluator) Signature() *domain.Signature { return e.signature } +func (e *multiOutputEvaluator) Dictionary() *common.Dictionary { return nil } +func (e *multiOutputEvaluator) Inputs() map[string]*domain.Input { return nil } +func (e *multiOutputEvaluator) Stats(map[string]interface{}) {} +func (e *multiOutputEvaluator) Close() error { return nil } +func (e *multiOutputEvaluator) ReloadIfNeeded(ctx context.Context) error { return nil } + +// TestRouter_Predict_DifferentOutputOrdering verifies that models with different +// output orderings have their outputs correctly reordered to match the router's signature +func TestRouter_Predict_DifferentOutputOrdering(t *testing.T) { + // Model1 returns: [score_a, score_b] (indices 0, 1) + // Model2 returns: [score_b, score_a] (indices 0, 1) - REVERSED ORDER + model1Sig := &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score_a", Index: 0, DataType: "float32"}, + {Name: "score_b", Index: 1, DataType: "float32"}, + }, + } + + model2Sig := &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score_b", Index: 0, DataType: "float32"}, // REVERSED + {Name: "score_a", Index: 1, DataType: "float32"}, // REVERSED + }, + } + + model1Eval := &multiOutputEvaluator{modelName: "model1", signature: model1Sig} + model2Eval := 
&multiOutputEvaluator{modelName: "model2", signature: model2Sig} + + cfg := &config.Model{ + ID: "router_output_ordering_test", + Mode: "router", + Platform: "triton", + Router: &config.RouterConfig{ + ConfigURL: "memory://router-config", + InputName: "router_id", + Global: config.GlobalModelConfig{Exists: true}, + }, + Triton: &config.TritonConfig{ServerID: "test_server"}, + } + cfg.Init(nil) + cfg.Router.MaxQueueSize = 1000 + cfg.Router.Workers = 10 + + router, err := newRouter(cfg, nil, map[string]UnloadService{ + "test_server": &triton.Service{}, + }, func(modelName string) (platform.PlatformEvaluator, error) { + if modelName == "model1" { + return model1Eval, nil + } + return model2Eval, nil + }) + if err != nil { + t.Fatalf("newRouter error: %v", err) + } + + router.routingMap = map[int]string{ + 1: "model1", + 2: "model2", + } + router.routingTable = map[string]platform.PlatformEvaluator{ + "model1": model1Eval, + "model2": model2Eval, + } + + // Router's signature: outputs are [score_a, score_b] (this is the canonical order) + routerSig := &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + {Name: "router_id", Index: 1, Type: reflect.TypeOf(int64(0))}, + }, + Outputs: []domain.Output{ + {Name: "score_a", Index: 0, DataType: "float32"}, + {Name: "score_b", Index: 1, DataType: "float32"}, + }, + } + + router.ioState = &IOState{ + inputs: map[string]*domain.Input{ + "text": &routerSig.Inputs[0], + "router_id": &routerSig.Inputs[1], + }, + signature: routerSig, + routerInputOffset: 1, + } + + // Input data: + // Row 0: text="aa", router_id=1 (-> model1) => score_a=2, score_b=4 + // Row 1: text="bbb", router_id=2 (-> model2) => score_a=3, score_b=6 + // Row 2: text="c", router_id=1 (-> model1) => score_a=1, score_b=2 + // Row 3: text="dddd", router_id=2 (-> model2) => score_a=4, score_b=8 + params := []interface{}{ + [][]string{{"aa"}, {"bbb"}, {"c"}, {"dddd"}}, // text + [][]int64{{1}, {2}, {1}, {2}}, // 
router_id + } + + results, err := router.Predict(context.Background(), params) + if err != nil { + t.Fatalf("Predict error: %v", err) + } + + if len(results) != 2 { + t.Fatalf("expected 2 outputs, got %d", len(results)) + } + + // Verify score_a (output index 0 in router's signature) + scoreA, ok := results[0].([][]float32) + if !ok { + t.Fatalf("expected [][]float32 for score_a, got %T", results[0]) + } + wantScoreA := [][]float32{{2}, {3}, {1}, {4}} + if !reflect.DeepEqual(scoreA, wantScoreA) { + t.Errorf("score_a mismatch:\n got: %v\n want: %v", scoreA, wantScoreA) + } + + // Verify score_b (output index 1 in router's signature) + scoreB, ok := results[1].([][]float32) + if !ok { + t.Fatalf("expected [][]float32 for score_b, got %T", results[1]) + } + wantScoreB := [][]float32{{4}, {6}, {2}, {8}} + if !reflect.DeepEqual(scoreB, wantScoreB) { + t.Errorf("score_b mismatch:\n got: %v\n want: %v", scoreB, wantScoreB) + } +} + +// TestRouter_Predict_FixedEvaluatorOutputOrdering verifies that the fixed evaluator +// correctly reorders its outputs to match the router's signature order, even when +// the PredictionReplacements are in a different order than the model outputs +func TestRouter_Predict_FixedEvaluatorOutputOrdering(t *testing.T) { + // Model signature has outputs: [score_a, score_b] + // But PredictionReplacements are configured in reverse order: [score_b, score_a] + // The fixed evaluator should reorder to match the router signature + + modelSig := &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score_a", Index: 0, DataType: "float32"}, + {Name: "score_b", Index: 1, DataType: "float32"}, + }, + } + + modelEval := &multiOutputEvaluator{modelName: "model1", signature: modelSig} + + // Create config with PredictionReplacements in REVERSE order from model outputs + cfg := &config.Model{ + ID: "router_fixed_ordering_test", + Mode: "router", + Platform: "triton", + 
Router: &config.RouterConfig{ + ConfigURL: "memory://router-config", + InputName: "router_id", + Global: config.GlobalModelConfig{ + Exists: false, // This enables fixed evaluator + PredictionReplacements: []config.PredictionReplacement{ + {Name: "score_b", Type: "float32", Value: 99.0}, // REVERSED ORDER + {Name: "score_a", Type: "float32", Value: 42.0}, // REVERSED ORDER + }, + }, + }, + Triton: &config.TritonConfig{ServerID: "test_server"}, + } + cfg.Init(nil) + cfg.Router.MaxQueueSize = 1000 + cfg.Router.Workers = 10 + + router, err := newRouter(cfg, nil, map[string]UnloadService{ + "test_server": &triton.Service{}, + }, func(modelName string) (platform.PlatformEvaluator, error) { + return modelEval, nil + }) + if err != nil { + t.Fatalf("newRouter error: %v", err) + } + + router.routingMap = map[int]string{ + 1: "model1", + } + router.routingTable = map[string]platform.PlatformEvaluator{ + "model1": modelEval, + } + + // Router signature: outputs are [score_a, score_b] + routerSig := &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + {Name: "router_id", Index: 1, Type: reflect.TypeOf(int64(0))}, + }, + Outputs: []domain.Output{ + {Name: "score_a", Index: 0, DataType: "float32"}, + {Name: "score_b", Index: 1, DataType: "float32"}, + }, + } + + router.ioState = &IOState{ + inputs: map[string]*domain.Input{ + "text": &routerSig.Inputs[0], + "router_id": &routerSig.Inputs[1], + }, + signature: routerSig, + routerInputOffset: 1, + } + + // Input data: + // Row 0: text="aa", router_id=1 (-> model1) + // Row 1: text="bbb", router_id=999 (-> fixed evaluator, ID not in routingMap) + params := []interface{}{ + [][]string{{"aa"}, {"bbb"}}, + [][]int64{{1}, {999}}, // 999 not in routingMap, uses fixed evaluator + } + + results, err := router.Predict(context.Background(), params) + if err != nil { + t.Fatalf("Predict error: %v", err) + } + + if len(results) != 2 { + t.Fatalf("expected 2 outputs, got %d", len(results)) + } + 
+ // Verify score_a (output index 0 in router's signature) + // Row 0: from model, len("aa") = 2 + // Row 1: from fixed evaluator, should be 42.0 (NOT 99.0) + scoreA, ok := results[0].([][]float32) + if !ok { + t.Fatalf("expected [][]float32 for score_a, got %T", results[0]) + } + wantScoreA := [][]float32{{2}, {42}} + if !reflect.DeepEqual(scoreA, wantScoreA) { + t.Errorf("score_a mismatch:\n got: %v\n want: %v", scoreA, wantScoreA) + } + + // Verify score_b (output index 1 in router's signature) + // Row 0: from model, len("aa") * 2 = 4 + // Row 1: from fixed evaluator, should be 99.0 (NOT 42.0) + scoreB, ok := results[1].([][]float32) + if !ok { + t.Fatalf("expected [][]float32 for score_b, got %T", results[1]) + } + wantScoreB := [][]float32{{4}, {99}} + if !reflect.DeepEqual(scoreB, wantScoreB) { + t.Errorf("score_b mismatch:\n got: %v\n want: %v", scoreB, wantScoreB) + } +} From 259c0395f5dd2e51af150fef851484aec02591b5 Mon Sep 17 00:00:00 2001 From: David Choi Date: Thu, 29 Jan 2026 14:42:27 -0800 Subject: [PATCH 24/50] Fix ineffective queue and throttling. Minor code clarifications. 
--- service/platform/router/fixed.go | 10 +-- service/platform/router/reload.go | 46 +++++++----- service/platform/router/router.go | 119 +++++++++++++++++------------- 3 files changed, 96 insertions(+), 79 deletions(-) diff --git a/service/platform/router/fixed.go b/service/platform/router/fixed.go index 14660cd..3971ad3 100644 --- a/service/platform/router/fixed.go +++ b/service/platform/router/fixed.go @@ -1,11 +1,9 @@ package router import ( - "context" "fmt" "github.com/viant/mly/service/config" - "github.com/viant/mly/service/request/shape" ) // preparedReplacement holds a pre-parsed replacement value for fixed evaluator outputs @@ -19,7 +17,6 @@ type fixedEvaluator struct { prepared []preparedReplacement } -// OutputNames returns the output names in the order they will be returned by Predict func (f *fixedEvaluator) OutputNames() []string { names := make([]string, len(f.prepared)) for i, p := range f.prepared { @@ -134,12 +131,7 @@ func newFixedEvaluator(repls []config.PredictionReplacement) (*fixedEvaluator, e return &fixedEvaluator{prepared: prepared}, nil } -func (f *fixedEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { - batchSize, err := shape.DetermineBatchSize(params) - if err != nil { - return nil, err - } - +func (f *fixedEvaluator) Predict(batchSize int) ([]interface{}, error) { makeString := func(v string) [][]string { out := make([][]string, batchSize) for i := 0; i < batchSize; i++ { diff --git a/service/platform/router/reload.go b/service/platform/router/reload.go index 61c491d..ab02186 100644 --- a/service/platform/router/reload.go +++ b/service/platform/router/reload.go @@ -94,6 +94,7 @@ func (r *Router) ReloadIfNeeded(ctx context.Context) error { return err } + // see defer above isFullReload = true // otherwise just abandon the routing table status checks @@ -146,11 +147,13 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin reuseEvaluators := 
make(map[string]platform.PlatformEvaluator) var reuseGlobal platform.PlatformEvaluator + // copy members to local scope var finalSignature *domain.Signature var oldConfig *router.RoutingConfig func() { r.routingTableLock.RLock() defer r.routingTableLock.RUnlock() + if r.ioState != nil { finalSignature = r.ioState.signature } @@ -304,10 +307,11 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin } } - for signature := range signatureCh { + // dsmi stands for DownStream Model Information + for dsmi := range signatureCh { // accept first available signature as the final signature if finalSignature == nil { - srcSig := signature.signature + srcSig := dsmi.signature finalSignature = &domain.Signature{ Inputs: make([]domain.Input, len(srcSig.Inputs), len(srcSig.Inputs)+1), Outputs: make([]domain.Output, len(srcSig.Outputs)), @@ -331,6 +335,7 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin sigInputMap[input.Name] = &input } + // add configured (aux) inputs to signature for _, input := range r.configuredInputs { _, ok := sigInputMap[input.Name] @@ -340,7 +345,7 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin } if !input.Auxiliary { - return fmt.Errorf("non-auxiliary input %s for model %s was not in model inputs", input.Name, signature.name) + return fmt.Errorf("non-auxiliary input %s for model %s was not in model inputs", input.Name, dsmi.name) } sigInputMap[input.Name] = &domain.Input{ @@ -351,7 +356,7 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin } if r.modelOutputName != "" { - // also, add the selected model output + // add the selected model output modelOutput := domain.Output{ Name: r.modelOutputName, Index: len(finalSignature.Outputs), @@ -368,63 +373,64 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin continue } - thisSignature := signature.signature + dsSignature := dsmi.signature // 
validate signature consistency + // Note: Index differences are permitted - IOs are matched by name + + // check that the new signature has no new outputs thisSignatureOutputMap := make(map[string]*domain.Output) - for _, output := range thisSignature.Outputs { + for _, output := range dsSignature.Outputs { oldOutput, ok := sigOutputMap[output.Name] if !ok { - return fmt.Errorf("signature output %s for model %s not found in the previous signature", output.Name, signature.name) + return fmt.Errorf("signature output %s for model %s not found in the previous signature", output.Name, dsmi.name) } thisSignatureOutputMap[output.Name] = &output - // Note: Index differences are permitted - outputs are matched by name if oldOutput.DataType != output.DataType { - return fmt.Errorf("signature output %s for model %s has data type %s, and the previous signature has data type %s", output.Name, signature.name, output.DataType, oldOutput.DataType) + return fmt.Errorf("signature output %s for model %s has data type %s, and the previous signature has data type %s", output.Name, dsmi.name, output.DataType, oldOutput.DataType) } } + // check that the new signature has no new outputs except the model name output for expectedOutput := range sigOutputMap { if _, ok := thisSignatureOutputMap[expectedOutput]; !ok && expectedOutput != r.modelOutputName { - return fmt.Errorf("signature output %s for was not found in model %s signature", expectedOutput, signature.name) + return fmt.Errorf("signature output %s for was not found in model %s signature", expectedOutput, dsmi.name) } } + // check that the new signature has no new inputs thisSignatureInputMap := make(map[string]*domain.Input) - for _, input := range thisSignature.Inputs { + for _, input := range dsSignature.Inputs { oldInput, ok := sigInputMap[input.Name] if !ok { - return fmt.Errorf("signature input %s for model %s not found in the previous signature", input.Name, signature.name) + return fmt.Errorf("signature input %s for model 
%s not found in the previous signature", input.Name, dsmi.name) } thisSignatureInputMap[input.Name] = &input - if oldInput.Auxiliary { - continue - } - - // Note: Index differences are permitted - inputs are reordered by name at dispatch time if !oldInput.Type.ConvertibleTo(input.Type) { - return fmt.Errorf("signature input %s for model %s has data type %s, and the previous signature has data type %s", input.Name, signature.name, input.Type.String(), oldInput.Type.String()) + return fmt.Errorf("signature input %s for model %s has data type %s, and the previous signature has data type %s", input.Name, dsmi.name, input.Type.String(), oldInput.Type.String()) } } + // check that the new signature has all expected inputs except for the routing input for expectedInput := range sigInputMap { if _, ok := thisSignatureInputMap[expectedInput]; !ok && expectedInput != r.routerInputFieldName { - return fmt.Errorf("signature input %s for was not found in model %s signature", expectedInput, signature.name) + return fmt.Errorf("signature input %s for was not found in model %s signature", expectedInput, dsmi.name) } } } if r.fixedEvaluatorFields != nil { - // TODO this is actually an acceptable case, but needs to be addressed elsewhere first before it is permitted + // TODO this is actually an acceptable case, we can simply ignore fixed evaluator fields that aren't applicable for field := range r.fixedEvaluatorFields { if _, ok := sigOutputMap[field]; !ok { return fmt.Errorf("fixed evaluator field: %s was not found in the signature outputs", field) } } + // check that the fixed evaluator fields have all expected outputs for _, field := range sigOutputMap { if _, ok := r.fixedEvaluatorFields[field.Name]; !ok && field.Name != r.modelOutputName { return fmt.Errorf("signature output %s is not replaced", field.Name) diff --git a/service/platform/router/router.go b/service/platform/router/router.go index 1bae49b..c6af1c0 100644 --- a/service/platform/router/router.go +++ 
b/service/platform/router/router.go @@ -5,6 +5,7 @@ import ( "fmt" "strconv" "sync" + "sync/atomic" "time" "github.com/viant/afs" @@ -79,10 +80,14 @@ type Router struct { // forceBatchSize1 when true uses legacy per-sample dispatch; when false (default) uses batched dispatch forceBatchSize1 bool - // workers limits concurrent model evaluations (used as semaphore capacity in batch mode) - workers int + + // workerSemaphore limits concurrent model evaluations + workerSemaphore chan struct{} + // maxQueueSize limits queued batches before rejection - maxQueueSize int + maxQueueSize uint64 + + queued *atomic.Uint64 } // NewRouter creates a new Router instance. @@ -99,7 +104,8 @@ func NewRouter(cfg *config.Model, fs afs.Service, tritonServices map[string]*tri // newRouter uses a map[string]ModelUnloader, where ModelUnloader is-a triton.TritonClient, for testing. func newRouter(cfg *config.Model, fs afs.Service, unloaders map[string]UnloadService, makeEvaluator func(modelName string) (platform.PlatformEvaluator, error)) (*Router, error) { - if cfg.Router == nil { + rtCfg := cfg.Router + if rtCfg == nil { return nil, fmt.Errorf("router configuration is required") } @@ -110,15 +116,15 @@ func newRouter(cfg *config.Model, fs afs.Service, unloaders map[string]UnloadSer var fixedEvaluator *fixedEvaluator var fixedEvaluatorFields map[string]struct{} - if !cfg.Router.Global.Exists { + if !rtCfg.Global.Exists { replacementsByName := make(map[string]config.PredictionReplacement) - for _, repl := range cfg.Router.Global.PredictionReplacements { + for _, repl := range rtCfg.Global.PredictionReplacements { replacementsByName[repl.Name] = repl } var err error - fixedEvaluator, err = newFixedEvaluator(cfg.Router.Global.PredictionReplacements) + fixedEvaluator, err = newFixedEvaluator(rtCfg.Global.PredictionReplacements) if err != nil { return nil, fmt.Errorf("failed to create fixed evaluator: %w", err) } @@ -133,25 +139,26 @@ func newRouter(cfg *config.Model, fs afs.Service, unloaders 
map[string]UnloadSer debug: cfg.Debug, routerName: cfg.ID, - configURL: cfg.Router.ConfigURL, + configURL: rtCfg.ConfigURL, fs: fs, makeRoutedEvaluator: makeEvaluator, unloader: unloader, - outputConfig: cfg.Router.Output, - hasGlobalModel: cfg.Router.Global.Exists, + outputConfig: rtCfg.Output, + hasGlobalModel: rtCfg.Global.Exists, - modelOutputName: cfg.Router.Output.FieldName, + modelOutputName: rtCfg.Output.FieldName, fixedEvaluator: fixedEvaluator, fixedEvaluatorFields: fixedEvaluatorFields, configuredInputs: cfg.Inputs, - routerInputFieldName: cfg.Router.InputName, + routerInputFieldName: rtCfg.InputName, - forceBatchSize1: cfg.Router.ForceBatchSize1, - workers: cfg.Router.Workers, - maxQueueSize: cfg.Router.MaxQueueSize, + forceBatchSize1: rtCfg.ForceBatchSize1, + workerSemaphore: make(chan struct{}, rtCfg.Workers), + maxQueueSize: uint64(rtCfg.MaxQueueSize), + queued: &atomic.Uint64{}, } return r, nil @@ -204,14 +211,13 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface return nil, err } - var signature *domain.Signature - batches := make(map[string]*modelBatch) - - // Hold read lock to ensure evaluator references remain valid during prediction. 
r.routingTableLock.RLock() defer r.routingTableLock.RUnlock() - // Phase 1: Group rows into batches by name + var signature *domain.Signature + batches := make(map[string]*modelBatch) + + // Phase 1: Group rows into batches by model name err = func() error { if r.ioState == nil { return fmt.Errorf("ioState was not initialized") @@ -226,9 +232,10 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface numInputs := len(params) + routerInputBatch := params[routerInputOffset] for batchOffset := range expectedBatchSize { // Extract routing value for this row - routingValueBatched, err := shape.Debatch(params[routerInputOffset], batchOffset) + routingValueBatched, err := shape.Debatch(routerInputBatch, batchOffset) if err != nil { return fmt.Errorf("failed to debatch routing value for row %d: %w", batchOffset, err) } @@ -277,10 +284,9 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface // Determine batch key: unique per row when forceBatchSize1, otherwise by model name batchKey := routingValueString if r.forceBatchSize1 { - batchKey = routingValueString + "#" + strconv.Itoa(batchOffset) + batchKey = strconv.Itoa(batchOffset) } - // Get or create batch batch, exists := batches[batchKey] if !exists { batch = &modelBatch{ @@ -290,18 +296,17 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface inputsByName: make(map[string]interface{}), rowOffsets: make([]int, 0, 1), } + batches[batchKey] = batch } // Append this row's inputs to the batch (excluding router input) - // Use signature to resolve input names for name-based accumulation for paramOffset := range numInputs { if paramOffset == routerInputOffset { continue } inputName := signature.Inputs[paramOffset].Name - debatched, err := shape.Debatch(params[paramOffset], batchOffset) if err != nil { return fmt.Errorf("failed to debatch for row %d, input %s: %w", batchOffset, inputName, err) @@ -323,40 +328,49 @@ func (r *Router) 
Predict(ctx context.Context, params []interface{}) ([]interface return nil, err } - // Check queue size limit - if len(batches) > r.maxQueueSize { + // early queue size check + currentQ := r.queued.Load() + if uint64(len(batches))+currentQ > r.maxQueueSize { routerPredictDroppedCounter.WithLabelValues(r.routerName).Inc() return nil, fmt.Errorf("too many batches (%d) exceeds max queue size (%d)", len(batches), r.maxQueueSize) } // Phase 2: Execute predictions in parallel with bounded concurrency resultCh := make(chan batchResult, len(batches)) - semaphore := make(chan struct{}, r.workers) var wg sync.WaitGroup for _, batch := range batches { wg.Add(1) + + // this must be decremented if queue is full and once no longer in queue + nowQueued := r.queued.Add(1) + + if nowQueued > r.maxQueueSize { + r.queued.Add(^uint64(0)) + return nil, fmt.Errorf("queue size exceeded") + } + go func(b *modelBatch) { defer wg.Done() // Acquire semaphore slot - semaphore <- struct{}{} - defer func() { <-semaphore }() + r.workerSemaphore <- struct{}{} + r.queued.Add(^uint64(0)) + + defer func() { + <-r.workerSemaphore + }() // Reorder inputs to match each evaluator's expected order before calling Predict var results []interface{} var err error + + // Capture output names for reordering in Phase 3 var outputNames []string + bs := len(b.rowOffsets) if b.isFixedEval { - // Fixed evaluator returns constant values - // Pass any accumulated input so it can determine batch size - var fixedInputs []interface{} - for _, inputData := range b.inputsByName { - fixedInputs = append(fixedInputs, inputData) - break // only need one input for batch size - } - results, err = r.fixedEvaluator.Predict(ctx, fixedInputs) + results, err = r.fixedEvaluator.Predict(bs) outputNames = r.fixedEvaluator.OutputNames() } else { // Reorder inputs to match this evaluator's expected order @@ -375,7 +389,6 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface // Rely on downstream for 
timeouts results, err = b.evaluator.Predict(ctx, orderedInputs) - // Capture output names for reordering in Phase 3 outputNames = make([]string, len(evalSig.Outputs)) for i, out := range evalSig.Outputs { outputNames[i] = out.Name @@ -385,10 +398,11 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface // Append model name to results if configured if r.modelOutputName != "" && err == nil { - modelNames := make([][]string, len(b.rowOffsets)) + modelNames := make([][]string, bs) for i := range modelNames { modelNames[i] = []string{b.modelName} } + results = append(results, modelNames) outputNames = append(outputNames, r.modelOutputName) } @@ -408,11 +422,13 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface // Phase 3: Reassemble results in original order // Build router output name -> index mapping for reordering + // TODO see if memoizing this provides material performance boosts routerOutputIndex := make(map[string]int, len(signature.Outputs)) for i, out := range signature.Outputs { routerOutputIndex[out.Name] = i } + // allResults will be [expectedBatchSize][len(signature.Outputs)] allResults := make([][]interface{}, expectedBatchSize) for res := range resultCh { @@ -421,38 +437,41 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface } // Extract individual rows from the batched result and place at original offsets - // Reorder outputs to match router's expected output order - for localIdx, originalOffset := range res.offsets { + for evalOffset, originalOffset := range res.offsets { rowResult := make([]interface{}, len(signature.Outputs)) + + // Reorder outputs to match router's expected output order for evalOutputIdx, outputBatch := range res.results { - extracted, err := shape.ExtractRowFromBatch(outputBatch, localIdx) + extracted, err := shape.ExtractRowFromBatch(outputBatch, evalOffset) if err != nil { - return nil, fmt.Errorf("failed to extract row %d from model %s 
output %d: %w", - localIdx, res.modelName, evalOutputIdx, err) + return nil, fmt.Errorf("failed to extract row %d from model %s output index %d: %w", + evalOffset, res.modelName, evalOutputIdx, err) } // Map evaluator output index to router output index by name - var routerIdx int + var originalOutputIdx int if res.outputNames == nil { // Fallback: assume same order (shouldn't happen in normal operation) - routerIdx = evalOutputIdx + originalOutputIdx = evalOutputIdx } else { outputName := res.outputNames[evalOutputIdx] + var exists bool - routerIdx, exists = routerOutputIndex[outputName] + originalOutputIdx, exists = routerOutputIndex[outputName] if !exists { return nil, fmt.Errorf("output %s from model %s not found in router signature", outputName, res.modelName) } } - rowResult[routerIdx] = extracted + rowResult[originalOutputIdx] = extracted } + allResults[originalOffset] = rowResult } } - // Concatenate all rows into final output + // Reshape all values into [outputs][batch][M] endResults := make([]interface{}, len(signature.Outputs)) for i, results := range allResults { endResults, err = shape.ConcatAxis0(endResults, results) From 73b1b36285a7a6d7679cf53f7088f5c1990c7500 Mon Sep 17 00:00:00 2001 From: David Choi Date: Thu, 29 Jan 2026 17:21:16 -0800 Subject: [PATCH 25/50] Add auxiliary in test case for router. 
--- example/e2e/regression/cases/012_router/test.yaml | 4 ++-- example/server/etc/config.yaml | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/example/e2e/regression/cases/012_router/test.yaml b/example/e2e/regression/cases/012_router/test.yaml index 5008960..c3ee331 100644 --- a/example/e2e/regression/cases/012_router/test.yaml +++ b/example/e2e/regression/cases/012_router/test.yaml @@ -7,8 +7,8 @@ pipeline: target: $target checkError: true commands: - - /tmp/e2e/mlyc -m sli_router -a 'routing_id|int:1;sa:b;sl:a' - - /tmp/e2e/mlyc -m sli_router -a 'routing_id|int:2;sa:b;sl:a' + - /tmp/e2e/mlyc -m sli_router -a 'routing_id|int:1;sa:b;sl:a;aux:1' + - /tmp/e2e/mlyc -m sli_router -a 'routing_id|int:2;sa:b;sl:a;aux:2' # assert: # action: validator:assert diff --git a/example/server/etc/config.yaml b/example/server/etc/config.yaml index fd3bce9..2d9164d 100644 --- a/example/server/etc/config.yaml +++ b/example/server/etc/config.yaml @@ -76,6 +76,9 @@ models: - name: sa datatype: string wildcard: true + - name: aux + datatype: string + auxiliary: true Outputs: - name: expand datatype: int64 From 7ec1d4fe214d9773aeccd1aaf844bb03f33805a3 Mon Sep 17 00:00:00 2001 From: David Choi Date: Thu, 29 Jan 2026 17:21:42 -0800 Subject: [PATCH 26/50] Minor comment reformatting. --- service/triton/evaluator.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/service/triton/evaluator.go b/service/triton/evaluator.go index e287a52..5831bed 100644 --- a/service/triton/evaluator.go +++ b/service/triton/evaluator.go @@ -98,7 +98,8 @@ func createEvaluator(config *config.Model, tritonClients map[string]*Service) (* } -// Upward dependency, but provides Evaluators as needed for the service/platform/router module. +// Upward dependency. +// Provides Evaluators as needed for the service/platform/router module. 
func NewRoutedTritonEvaluator(modelName string, config *config.Model, tritonClients map[string]*Service) (*TritonEvaluator, error) { evaluator, err := createEvaluator(config, tritonClients) if err != nil { From c8689685980f0932e12b1d5ad3f1e0708ec72ec4 Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 2 Feb 2026 14:42:39 -0800 Subject: [PATCH 27/50] Add back queueing, minor fix in router config init, refactor router tests. --- service/config/router.go | 8 +- service/platform/router/prometheus.go | 11 + service/platform/router/reload.go | 19 +- service/platform/router/reload_test.go | 89 +- service/platform/router/router.go | 39 +- service/platform/router/router_test.go | 1550 +++++++++--------------- 6 files changed, 738 insertions(+), 978 deletions(-) diff --git a/service/config/router.go b/service/config/router.go index 062d2ad..82638bd 100644 --- a/service/config/router.go +++ b/service/config/router.go @@ -66,6 +66,10 @@ func (o *RouterConfig) Init() { if o.MaxQueueSize == 0 { o.MaxQueueSize = 1000 } + + if o.Output.NoModelID == "" { + o.Output.NoModelID = "none" + } } func (o *RouterConfig) Validate() error { @@ -89,9 +93,5 @@ func (o *RouterConfig) Validate() error { return fmt.Errorf("global model does not exist but no prediction replacements were provided") } - if o.Output.NoModelID == "" { - o.Output.NoModelID = "none" - } - return nil } diff --git a/service/platform/router/prometheus.go b/service/platform/router/prometheus.go index 8fa9538..5598690 100644 --- a/service/platform/router/prometheus.go +++ b/service/platform/router/prometheus.go @@ -38,6 +38,16 @@ var ( []string{"router"}, ) + routerQueueDurationMicrosSummary = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "router", + Name: "queue_duration_summary_us", + Help: "Duration of router queueing.", + }, + []string{"router"}, + ) + routerModelUnloadGauge = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: "mly", @@ -53,5 +63,6 @@ func init() { 
prometheus.MustRegister(routerPredictDurationMicrosSummary) prometheus.MustRegister(routerReloadDurationMicrosSummary) prometheus.MustRegister(routerModelUnloadGauge) + prometheus.MustRegister(routerQueueDurationMicrosSummary) prometheus.MustRegister(routerPredictDroppedCounter) } diff --git a/service/platform/router/reload.go b/service/platform/router/reload.go index ab02186..9c27e28 100644 --- a/service/platform/router/reload.go +++ b/service/platform/router/reload.go @@ -35,6 +35,7 @@ func (r *Router) ReloadIfNeeded(ctx context.Context) error { } else { mode = "checks" } + routerReloadDurationMicrosSummary.WithLabelValues(r.routerName, mode).Observe(float64(time.Since(start).Microseconds())) }() @@ -414,9 +415,17 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin } } - // check that the new signature has all expected inputs except for the routing input + // check that the new signature has all expected inputs except for the routing and auxiliary inputs for expectedInput := range sigInputMap { - if _, ok := thisSignatureInputMap[expectedInput]; !ok && expectedInput != r.routerInputFieldName { + if sigInputMap[expectedInput].Auxiliary { + continue + } + + if expectedInput == r.routerInputFieldName { + continue + } + + if _, ok := thisSignatureInputMap[expectedInput]; !ok { return fmt.Errorf("signature input %s for was not found in model %s signature", expectedInput, dsmi.name) } } @@ -464,12 +473,12 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin }() for model := range modelsToUnload { - routerModelUnloadGauge.WithLabelValues(r.routerName).Inc() + r.unloadGauge.Inc() go func(modelName string) { - defer routerModelUnloadGauge.WithLabelValues(r.routerName).Dec() + defer r.unloadGauge.Dec() - ctxTo, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctxTo, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() r.debugLogf("request to unload model: %s", 
modelName) diff --git a/service/platform/router/reload_test.go b/service/platform/router/reload_test.go index 1b6d780..3c91dbc 100644 --- a/service/platform/router/reload_test.go +++ b/service/platform/router/reload_test.go @@ -7,6 +7,7 @@ import ( "strings" "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/viant/mly/service/config" @@ -16,6 +17,67 @@ import ( sharedrouter "github.com/viant/mly/shared/config/router" ) +type mockUnloader struct { + tritonServer *mockTritonServer + unloadCh chan string +} + +func (m *mockUnloader) ModelUnload(ctx context.Context, tritonModelName string) error { + if m.tritonServer != nil { + m.tritonServer.mu.Lock() + defer m.tritonServer.mu.Unlock() + + if m.tritonServer.readyState == nil { + m.tritonServer.readyState = make(map[string]bool) + } + + m.tritonServer.readyState[tritonModelName] = false + } + + ch := m.unloadCh + + if ch != nil { + ch <- tritonModelName + } + + return nil +} + +type wrappedUnloader struct { + tritonService *triton.Service + + wg *sync.WaitGroup +} + +func (w *wrappedUnloader) UnloadModel(ctx context.Context, mlyModelID string, tritonModelName string) error { + defer w.wg.Done() + return w.tritonService.UnloadModel(ctx, mlyModelID, tritonModelName) +} + +type mockTritonServer struct { + mu sync.Mutex + + readyState map[string]bool + modelLoadErr map[string]error +} + +func (m *mockTritonServer) ModelLoad(ctx context.Context, modelName string) error { + m.mu.Lock() + defer m.mu.Unlock() + + if err := m.modelLoadErr[modelName]; err != nil { + return err + } + + if m.readyState == nil { + m.readyState = make(map[string]bool) + } + + m.readyState[modelName] = true + + return nil +} + func TestRouter_applyRouterConfig_LoadsAndSwaps(t *testing.T) { ctx := context.Background() mockClient := &mockUnloader{ @@ -57,6 +119,7 @@ func TestRouter_applyRouterConfig_LoadsAndSwaps(t *testing.T) { makeRoutedEvaluator: func(modelName string) (platform.PlatformEvaluator, error) { return 
&mockEvaluator{signature: makeSig}, nil }, + unloadGauge: routerModelUnloadGauge.WithLabelValues("test_router"), } reusedModelB := router.routingTable["modelB"] @@ -153,6 +216,7 @@ func TestRouter_applyRouterConfig_LoadError(t *testing.T) { signature: func() *domain.Signature { return signature }, }, nil }, + unloadGauge: routerModelUnloadGauge.WithLabelValues("load_error"), } newConfig := &sharedrouter.RoutingConfig{ @@ -274,17 +338,6 @@ func TestRouter_applyRouterConfig_signature(t *testing.T) { assert.Equal(t, 2, len(results)) } -type wrappedUnloader struct { - tritonService *triton.Service - - wg *sync.WaitGroup -} - -func (w *wrappedUnloader) UnloadModel(ctx context.Context, mlyModelID string, tritonModelName string) error { - defer w.wg.Done() - return w.tritonService.UnloadModel(ctx, mlyModelID, tritonModelName) -} - func TestRouter_applyRouterConfig_sharedTritonServer(t *testing.T) { tritonServer := &mockTritonServer{} modelUnloader := &mockUnloader{ @@ -438,3 +491,17 @@ func TestRouter_applyRouterConfig_sharedTritonServer(t *testing.T) { assert.True(t, tritonServer.readyState["modelB"], "modelB should still be loaded") assert.True(t, tritonServer.readyState["modelC"], "modelC should be loaded") } + +func waitForCalls(t *testing.T, ch <-chan string, count int) []string { + t.Helper() + var out []string + for i := 0; i < count; i++ { + select { + case v := <-ch: + out = append(out, v) + case <-time.After(time.Second): + t.Fatalf("timeout waiting for call %d/%d", i+1, count) + } + } + return out +} diff --git a/service/platform/router/router.go b/service/platform/router/router.go index c6af1c0..64305cd 100644 --- a/service/platform/router/router.go +++ b/service/platform/router/router.go @@ -8,6 +8,7 @@ import ( "sync/atomic" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/viant/afs" "github.com/viant/mly/service/config" "github.com/viant/mly/service/domain" @@ -19,6 +20,8 @@ import ( "github.com/viant/mly/shared/config/router" ) +const 
queueSizeExceededError = "queue size exceeded" + type IOState struct { inputs map[string]*domain.Input signature *domain.Signature @@ -71,9 +74,10 @@ type Router struct { modelOutputName string - routerName string - debug bool - unloader UnloadService + routerName string + debug bool + unloader UnloadService + unloadGauge prometheus.Gauge configuredInputs []*shared.Field ioState *IOState @@ -87,7 +91,9 @@ type Router struct { // maxQueueSize limits queued batches before rejection maxQueueSize uint64 - queued *atomic.Uint64 + queued *atomic.Uint64 + queueDurationObserver prometheus.Observer + droppedCounter prometheus.Counter } // NewRouter creates a new Router instance. @@ -135,15 +141,18 @@ func newRouter(cfg *config.Model, fs afs.Service, unloaders map[string]UnloadSer } } + routerName := cfg.ID r := &Router{ debug: cfg.Debug, - routerName: cfg.ID, + routerName: routerName, configURL: rtCfg.ConfigURL, fs: fs, makeRoutedEvaluator: makeEvaluator, - unloader: unloader, + unloader: unloader, + unloadGauge: routerModelUnloadGauge.WithLabelValues(routerName), + outputConfig: rtCfg.Output, hasGlobalModel: rtCfg.Global.Exists, @@ -156,9 +165,13 @@ func newRouter(cfg *config.Model, fs afs.Service, unloaders map[string]UnloadSer routerInputFieldName: rtCfg.InputName, forceBatchSize1: rtCfg.ForceBatchSize1, + workerSemaphore: make(chan struct{}, rtCfg.Workers), maxQueueSize: uint64(rtCfg.MaxQueueSize), queued: &atomic.Uint64{}, + + queueDurationObserver: routerQueueDurationMicrosSummary.WithLabelValues(routerName), + droppedCounter: routerPredictDroppedCounter.WithLabelValues(routerName), } return r, nil @@ -167,7 +180,7 @@ func newRouter(cfg *config.Model, fs afs.Service, unloaders map[string]UnloadSer // modelBatch holds accumulated rows destined for a single model evaluator type modelBatch struct { evaluator platform.PlatformEvaluator // need Signature() for input reordering - isFixedEval bool // true if using fixedEvaluator (no reordering needed) + isFixedEval bool // 
true skips input reordering modelName string inputsByName map[string]interface{} // keyed by input name - accumulated batched inputs rowOffsets []int // original positions in the incoming batch @@ -263,7 +276,7 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface isFixedEval := false if !ok { if hasFixedEvaluator { - // No global model, use fixed evaluator (returns constant values) + // No global model, use fixed evaluator routingValueString = noModelName isFixedEval = true } else { @@ -331,8 +344,8 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface // early queue size check currentQ := r.queued.Load() if uint64(len(batches))+currentQ > r.maxQueueSize { - routerPredictDroppedCounter.WithLabelValues(r.routerName).Inc() - return nil, fmt.Errorf("too many batches (%d) exceeds max queue size (%d)", len(batches), r.maxQueueSize) + r.droppedCounter.Inc() + return nil, fmt.Errorf(queueSizeExceededError) } // Phase 2: Execute predictions in parallel with bounded concurrency @@ -344,10 +357,12 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface // this must be decremented if queue is full and once no longer in queue nowQueued := r.queued.Add(1) + startQueueTime := time.Now() if nowQueued > r.maxQueueSize { r.queued.Add(^uint64(0)) - return nil, fmt.Errorf("queue size exceeded") + r.droppedCounter.Inc() + return nil, fmt.Errorf(queueSizeExceededError) } go func(b *modelBatch) { @@ -355,7 +370,9 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface // Acquire semaphore slot r.workerSemaphore <- struct{}{} + r.queued.Add(^uint64(0)) + r.queueDurationObserver.Observe(float64(time.Since(startQueueTime).Microseconds())) defer func() { <-r.workerSemaphore diff --git a/service/platform/router/router_test.go b/service/platform/router/router_test.go index c2ecbf5..6de05c7 100644 --- a/service/platform/router/router_test.go +++ 
b/service/platform/router/router_test.go @@ -2,73 +2,79 @@ package router import ( "context" - "errors" "fmt" + "log" "reflect" + "slices" + "strconv" "strings" "sync" + "sync/atomic" "testing" - "time" - "github.com/stretchr/testify/assert" "github.com/viant/mly/service/config" "github.com/viant/mly/service/domain" "github.com/viant/mly/service/platform" + "github.com/viant/mly/service/request/shape" "github.com/viant/mly/service/triton" "github.com/viant/mly/shared/common" ) -// --- Router Predict scaffolds --- +// mockEvaluator currently handles all test case behaviors. +// This may be a sign that the router itself needs to be refactored into separate parts. +type mockEvaluator struct { + // reload-related objects + modelName string + + // see reload_test.go + tritonServer *mockTritonServer + + // used in batching tests + predictCalls int + mu sync.Mutex -type mockTritonServer struct { - mu sync.Mutex + // force an error + err error - readyState map[string]bool - modelLoadErr map[string]error + // for queueing tests + waitFor *sync.WaitGroup + doneGroup *sync.WaitGroup + + predictor func(params []interface{}, signature *domain.Signature) ([]interface{}, error) + signature func() *domain.Signature } -func (m *mockTritonServer) ModelLoad(ctx context.Context, modelName string) error { +func (m *mockEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { m.mu.Lock() - defer m.mu.Unlock() + m.predictCalls++ + m.mu.Unlock() - if err := m.modelLoadErr[modelName]; err != nil { - return err + if m.waitFor != nil { + m.waitFor.Wait() + m.waitFor = nil } - if m.readyState == nil { - m.readyState = make(map[string]bool) + if m.doneGroup != nil { + defer m.doneGroup.Done() + m.doneGroup = nil } - m.readyState[modelName] = true - - return nil -} - -type mockEvaluator struct { - tritonServer *mockTritonServer - - modelName string - signature func() *domain.Signature -} + if m.err != nil { + return nil, m.err + } -func (m *mockEvaluator) 
Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { - inputs := m.signature().Inputs + sig := m.signature() + inputs := sig.Inputs if len(inputs) != len(params) { return nil, fmt.Errorf("mock error: expected %d inputs, got %d", len(inputs), len(params)) } - // params is expected to have a single non-router input in this test: [][]string with shape [1][1] - var v string - switch typed := params[0].(type) { - case [][]string: - v = typed[0][0] - default: - tval := reflect.TypeOf(params[0]) - panic("unexpected input type in mock Predict(): " + tval.String()) + if m.predictor == nil { + // for test cases that do not validate results like reload-centric cases + return nil, nil } - // simple function: length of string as float32 - out := [][]float32{{float32(len(v))}} - return []interface{}{out}, nil + + return m.predictor(params, sig) } func (m *mockEvaluator) Signature() *domain.Signature { return m.signature() } @@ -85,1052 +91,702 @@ func (m *mockEvaluator) ReloadIfNeeded(ctx context.Context) error { return m.tritonServer.ModelLoad(ctx, m.modelName) } -type mockUnloader struct { - tritonServer *mockTritonServer - unloadCh chan string -} - -func (m *mockUnloader) ModelUnload(ctx context.Context, tritonModelName string) error { - if m.tritonServer != nil { - m.tritonServer.mu.Lock() - defer m.tritonServer.mu.Unlock() - - if m.tritonServer.readyState == nil { - m.tritonServer.readyState = make(map[string]bool) - } - - m.tritonServer.readyState[tritonModelName] = false - } - - ch := m.unloadCh - - if ch != nil { - ch <- tritonModelName - } - - return nil +func (m *mockEvaluator) getPredictCalls() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.predictCalls } -func waitForCalls(t *testing.T, ch <-chan string, count int) []string { - t.Helper() - var out []string - for i := 0; i < count; i++ { - select { - case v := <-ch: - out = append(out, v) - case <-time.After(time.Second): - t.Fatalf("timeout waiting for call %d/%d", i+1, count) - } +// 
appendFloatPredict expects []interface{[][]string} and returns []interface{}{[][]float32} +// The output field names are expected to be an integer in string form, which will be parsed and appended. +func appendFloatPredict(params []interface{}, signature *domain.Signature) ([]interface{}, error) { + batchSize, err := shape.BatchSize(params[0]) + if err != nil { + return nil, fmt.Errorf("could not determine batch size: %w", err) } - return out -} - -func TestRouter_Predict(t *testing.T) { - ctx := context.Background() - tests := []struct { - name string - routerConfig *config.RouterConfig - verifier func(t *testing.T, results []interface{}) - }{ - { - name: "with global model", - routerConfig: &config.RouterConfig{ - InputName: "router_id", - Global: config.GlobalModelConfig{ - Exists: true, // avoid fixed replacements path - }, - }, - verifier: func(t *testing.T, results []interface{}) { - if len(results) != 1 { - t.Fatalf("expected 1 output, got %d", len(results)) - } - out, ok := results[0].([][]float32) - if !ok { - t.Fatalf("expected [][]float32, got %T", results[0]) - } - want := [][]float32{{1}, {4}} - if !reflect.DeepEqual(out, want) { - t.Errorf("output mismatch: got %#v, want %#v", out, want) - } - }, - }, - { - name: "without global model", - routerConfig: &config.RouterConfig{ - InputName: "router_id", - Global: config.GlobalModelConfig{ - PredictionReplacements: []config.PredictionReplacement{ - { - Name: "score", - Type: "float32", - Value: 1.0, - }, - }, - }, - }, - verifier: func(t *testing.T, results []interface{}) { - if len(results) != 1 { - t.Fatalf("expected 1 output, got %d", len(results)) + batchConcats := make([]*strings.Builder, batchSize) + for _, param := range params { + switch batch := param.(type) { + case [][]string: + for bi, s := range batch { + if batchConcats[bi] == nil { + batchConcats[bi] = &strings.Builder{} } - out, ok := results[0].([][]float32) - if !ok { - t.Fatalf("expected [][]float32, got %T", results[0]) - } - want := 
[][]float32{{1}, {4}} - if !reflect.DeepEqual(out, want) { - t.Errorf("output mismatch: got %#v, want %#v", out, want) - } - }, - }, - { - name: "with model output name", - routerConfig: &config.RouterConfig{ - InputName: "router_id", - Global: config.GlobalModelConfig{ - Exists: true, - }, - Output: config.OutputConfig{ - FieldName: "model_output", - }, - }, - verifier: func(t *testing.T, results []interface{}) { - if len(results) != 2 { - t.Fatalf("expected 2 outputs, got %d, %v", len(results), results) - } - - func() { - out, ok := results[0].([][]float32) - if !ok { - t.Fatalf("expected [][]float32, got %T", results[0]) - } - want := [][]float32{{1}, {4}} - if !reflect.DeepEqual(out, want) { - t.Errorf("output mismatch: got %#v, want %#v", out, want) - } - }() - - func() { - out, ok := results[1].([][]string) - if !ok { - t.Fatalf("expected [][]string, got %T", results[1]) - } - want := [][]string{{"model1"}, {"model2"}} - if !reflect.DeepEqual(out, want) { - t.Errorf("output mismatch: got %#v, want %#v", out, want) - } - }() - }, - }, - } - downstreamInput := domain.Input{ - Name: "text", - Index: 1, - Type: reflect.TypeOf(""), - } - - downstreamSignature := &domain.Signature{ - Inputs: []domain.Input{downstreamInput}, - Outputs: []domain.Output{ - {Name: "score", Index: 0, DataType: "float32"}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.routerConfig.ConfigURL == "" { - test.routerConfig.ConfigURL = "memory://router-config" + sb := batchConcats[bi] + sb.WriteString(s[0]) } + default: + return nil, fmt.Errorf("unexpected input type: %T", param) + } + } - cfg := &config.Model{ - ID: "router_test", - Mode: "router", - Platform: "triton", - Router: test.routerConfig, - Triton: &config.TritonConfig{ - ServerID: "test_server", - }, - } - - cfg.Init(nil) - cfg.Router.MaxQueueSize = 1000 - cfg.Router.Workers = 3 - - router, err := newRouter(cfg, nil, map[string]UnloadService{ - "test_server": &triton.Service{}, - }, 
func(modelName string) (platform.PlatformEvaluator, error) { - return &mockEvaluator{signature: func() *domain.Signature { return downstreamSignature }}, nil - }) - + numOutputs := len(signature.Outputs) + retVal := make([]interface{}, numOutputs) + for oi := range numOutputs { + outputBatch := make([][]float32, batchSize) + for i, v := range batchConcats { + outputName := signature.Outputs[oi].Name + floatStr := outputName + v.String() + floatVal, err := strconv.ParseFloat(floatStr, 32) if err != nil { - t.Fatalf("NewRouter error: %v", err) - } - - router.routingMap = map[int]string{ - 1: "model1", - 2: "model2", - } - - routerOutputs := []domain.Output{ - {Name: "score", Index: 0, DataType: "float32"}, - } - - if test.routerConfig.Output.FieldName != "" { - routerOutputs = append(routerOutputs, domain.Output{Name: test.routerConfig.Output.FieldName, Index: 1, DataType: "string"}) - } - - routerInputName := cfg.Router.InputName - routerInput := domain.Input{Name: routerInputName, Index: 0, Type: reflect.TypeOf(int64(0))} - - router.ioState = &IOState{ - inputs: map[string]*domain.Input{ - routerInputName: &routerInput, - downstreamInput.Name: &downstreamInput, - }, - - signature: &domain.Signature{ - Inputs: []domain.Input{ - routerInput, - downstreamInput, - }, - Outputs: routerOutputs, - }, - } - - mockEval := &mockEvaluator{signature: func() *domain.Signature { return downstreamSignature }} - router.routingTable = map[string]platform.PlatformEvaluator{ - "model1": mockEval, - "model2": mockEval, - } - - // batch of 2 - params := []interface{}{ - [][]int64{{1}, {2}}, // router id - [][]string{{"a"}, {"abcd"}}, // backend input + return nil, fmt.Errorf("could not parse float: %v", floatStr) } - results, err := router.Predict(ctx, params) - if err != nil { - t.Fatalf("Predict error: %v", err) - } + outputBatch[i] = []float32{float32(floatVal)} + } - test.verifier(t, results) - }) + retVal[oi] = outputBatch } -} - -// --- Batched Prediction Tests --- -// 
countingMockEvaluator is a configurable mock that handles batched inputs and tracks calls -type countingMockEvaluator struct { - modelName string - signature *domain.Signature - predictCalls int - mu sync.Mutex - err error // if set, Predict returns this error + return retVal, nil } -func newTestMockEvaluator(name string, sig *domain.Signature) *countingMockEvaluator { - return &countingMockEvaluator{modelName: name, signature: sig} +type configVariant struct { + numOutputs int + hasGlobalModel bool + hasOutputName bool + forceBatchSize1 bool + reverseFPOutputs bool } -func (m *countingMockEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { - m.mu.Lock() - m.predictCalls++ - m.mu.Unlock() +func makeConfig(cv configVariant) *config.RouterConfig { + var prs []config.PredictionReplacement + if !cv.hasGlobalModel { + prs = make([]config.PredictionReplacement, cv.numOutputs) + for i := range cv.numOutputs { + prs[i] = config.PredictionReplacement{ + Name: strconv.Itoa(i), + Type: "float32", + Value: float32(i) + 0.1, + } + } + } - if m.err != nil { - return nil, m.err + if cv.reverseFPOutputs { + slices.Reverse(prs) } - if len(params) == 0 { - return nil, fmt.Errorf("no params provided") + var outputName = "" + if cv.hasOutputName { + outputName = "model_id" } - // Determine batch size from first input - var batchSize int - switch typed := params[0].(type) { - case [][]string: - batchSize = len(typed) - case [][]int32: - batchSize = len(typed) - case [][]int64: - batchSize = len(typed) - case [][]float32: - batchSize = len(typed) - case [][]float64: - batchSize = len(typed) - default: - return nil, fmt.Errorf("unexpected input type: %T", params[0]) + gmo := "" + if cv.hasGlobalModel { + gmo = "global" } - // Compute output: for each row, output the length of the first string input - results := make([][]float32, batchSize) - for i := 0; i < batchSize; i++ { - var length int - if typed, ok := params[0].([][]string); ok { - length = 
len(typed[i][0]) - } - results[i] = []float32{float32(length)} + rc := &config.RouterConfig{ + InputName: "router_id", + ConfigURL: "memory://router-config", + ForceBatchSize1: cv.forceBatchSize1, + Global: config.GlobalModelConfig{ + Exists: cv.hasGlobalModel, + PredictionReplacements: prs, + }, + Output: config.OutputConfig{ + FieldName: outputName, + GlobalModelOverride: gmo, + }, + MaxQueueSize: 100, + Workers: 3, } - return []interface{}{results}, nil + return rc } -func (m *countingMockEvaluator) Signature() *domain.Signature { return m.signature } -func (m *countingMockEvaluator) Dictionary() *common.Dictionary { return nil } -func (m *countingMockEvaluator) Inputs() map[string]*domain.Input { return nil } -func (m *countingMockEvaluator) Stats(map[string]interface{}) {} -func (m *countingMockEvaluator) Close() error { return nil } -func (m *countingMockEvaluator) ReloadIfNeeded(ctx context.Context) error { - return nil -} +type predictTestCase struct { + name string + routerConfig *config.RouterConfig -func (m *countingMockEvaluator) getPredictCalls() int { - m.mu.Lock() - defer m.mu.Unlock() - return m.predictCalls -} + reverseInputs bool + reverseOutputs bool -// predictTestSetup holds the common test setup for prediction tests -type predictTestSetup struct { - router *Router - evaluators map[string]*countingMockEvaluator - signature *domain.Signature -} + stringInputs [][]string + routerInputs []int -// predictTestOptions configures the test setup -type predictTestOptions struct { - forceBatchSize1 bool - modelOutputName string - modelNames []string // defaults to ["model1", "model2"] + // outputs + expectedOutputs [][]float32 + expectCallCounts map[string]int } -// setupPredictTest creates a router with mock evaluators for prediction testing -func setupPredictTest(t *testing.T, opts predictTestOptions) *predictTestSetup { - t.Helper() +// predictTest creates the signature and evaluators, runs Predict, and runs the tests +func predictTest(t *testing.T, 
test predictTestCase) { + routerInputs, router, mockEvaluators, globalModelName, params := prepareTestRouter(t, test) - if len(opts.modelNames) == 0 { - opts.modelNames = []string{"model1", "model2"} - } - - // Create shared signature - downstreamSig := &domain.Signature{ - Inputs: []domain.Input{ - {Name: "text", Index: 0, Type: reflect.TypeOf("")}, - }, - Outputs: []domain.Output{ - {Name: "score", Index: 0, DataType: "float32"}, - }, - } - - // Create evaluators - evaluators := make(map[string]*countingMockEvaluator) - for _, name := range opts.modelNames { - evaluators[name] = newTestMockEvaluator(name, downstreamSig) - } - - // Create config - cfg := &config.Model{ - ID: "router_predict_test", - Mode: "router", - Platform: "triton", - Router: &config.RouterConfig{ - ConfigURL: "memory://router-config", - InputName: "router_id", - ForceBatchSize1: opts.forceBatchSize1, - Global: config.GlobalModelConfig{Exists: true}, - }, - Triton: &config.TritonConfig{ServerID: "test_server"}, - } - - if opts.modelOutputName != "" { - cfg.Router.Output.FieldName = opts.modelOutputName + dl, _ := t.Deadline() + ctx, cancel := context.WithDeadline(context.Background(), dl) + defer cancel() + results, err := router.Predict(ctx, params) + if err != nil { + t.Fatalf("Predict error: %v", err) } - cfg.Init(nil) - cfg.Router.MaxQueueSize = 1000 - cfg.Router.Workers = 10 + // model name is appended at the end, first check normal outputs + for oi, outputBatch := range test.expectedOutputs { + actualOutput := results[oi] + switch aob := actualOutput.(type) { + case [][]float32: + for obi, ov := range outputBatch { + actualValue := aob[obi][0] + log.Printf("model output %d expected:%f actual:%f", obi, ov, actualValue) - // Create router - router, err := newRouter(cfg, nil, map[string]UnloadService{ - "test_server": &triton.Service{}, - }, func(modelName string) (platform.PlatformEvaluator, error) { - if eval, ok := evaluators[modelName]; ok { - return eval, nil + if actualValue != ov { + 
t.Fatalf("input %d offset %d expected %f, got %f", oi, obi, ov, actualValue) + } + } + default: + t.Fatalf("input %d expected [][]float32, got %T", oi, actualOutput) } - return evaluators[opts.modelNames[0]], nil - }) - if err != nil { - t.Fatalf("newRouter error: %v", err) } - // Setup routing map (1->model1, 2->model2, etc.) - router.routingMap = make(map[int]string) - for i, name := range opts.modelNames { - router.routingMap[i+1] = name - } + // then check model name outputs + if test.routerConfig.Output.FieldName != "" { + actualOutput := results[len(test.expectedOutputs)] + switch aob := actualOutput.(type) { + case [][]string: + for obi, ov := range aob { + routedNumber := routerInputs[obi] + routedModel, ok := router.routingMap[routedNumber] + if !ok { + if router.globalModel == nil { + routedModel = test.routerConfig.Output.NoModelID + } else { + routedModel = globalModelName + } + } - // Setup routing table - router.routingTable = make(map[string]platform.PlatformEvaluator) - for name, eval := range evaluators { - router.routingTable[name] = eval - } + log.Printf("model name expected:%s actual:%s", routedModel, ov[0]) - // Build signature outputs - outputs := []domain.Output{{Name: "score", Index: 0, DataType: "float32"}} - if opts.modelOutputName != "" { - outputs = append(outputs, domain.Output{Name: opts.modelOutputName, Index: 1, DataType: "string"}) + if ov[0] != routedModel { + t.Fatalf("model name output expected %s, got %s", "model"+strconv.Itoa(obi), ov[0]) + } + } + default: + t.Fatalf("model name output expected [][]string, got %T", actualOutput) + } } - // Setup ioState - routerInputName := cfg.Router.InputName - routerInput := domain.Input{Name: routerInputName, Index: 1, Type: reflect.TypeOf(int64(0))} - downstreamInput := domain.Input{Name: "text", Index: 0, Type: reflect.TypeOf("")} - - routerSig := &domain.Signature{ - Inputs: []domain.Input{downstreamInput, routerInput}, - Outputs: outputs, - } + // check number of predict calls + for _, 
evaluator := range mockEvaluators { + expectCallCount, hasExpect := test.expectCallCounts[evaluator.modelName] + if !hasExpect { + continue + } - router.ioState = &IOState{ - inputs: map[string]*domain.Input{ - routerInputName: &routerInput, - downstreamInput.Name: &downstreamInput, - }, - signature: routerSig, - routerInputOffset: 1, - } + predictCalls := evaluator.getPredictCalls() + log.Printf("predict calls %s expected:%d, actual:%d", evaluator.modelName, expectCallCount, predictCalls) - return &predictTestSetup{ - router: router, - evaluators: evaluators, - signature: routerSig, + if predictCalls != expectCallCount { + t.Fatalf("predict calls expected %d, got %d", expectCallCount, predictCalls) + } } } -func TestRouter_Predict_BatchingBehavior(t *testing.T) { - tests := []struct { - name string - forceBatchSize1 bool - inputs []string - routingIDs []int64 - wantScores [][]float32 - wantCallCounts map[string]int - }{ - { - name: "batched_groups_by_model", - forceBatchSize1: false, - inputs: []string{"a", "bb", "ccc", "dddd", "eeeee", "ffffff"}, - routingIDs: []int64{1, 2, 1, 2, 1, 2}, // alternating model1/model2 - wantScores: [][]float32{{1}, {2}, {3}, {4}, {5}, {6}}, - wantCallCounts: map[string]int{"model1": 1, "model2": 1}, // each model called once - }, - { - name: "batched_single_model", - forceBatchSize1: false, - inputs: []string{"a", "bb", "ccc", "dddd"}, - routingIDs: []int64{1, 1, 1, 1}, // all to model1 - wantScores: [][]float32{{1}, {2}, {3}, {4}}, - wantCallCounts: map[string]int{"model1": 1, "model2": 0}, - }, - { - name: "force_batch_size_1", - forceBatchSize1: true, - inputs: []string{"a", "bb", "ccc", "dddd"}, - routingIDs: []int64{1, 1, 1, 1}, // all to model1 - wantScores: [][]float32{{1}, {2}, {3}, {4}}, - wantCallCounts: map[string]int{"model1": 4, "model2": 0}, // 4 individual calls - }, - { - name: "force_batch_size_1_multiple_models", - forceBatchSize1: true, - inputs: []string{"a", "bb", "ccc", "dddd"}, - routingIDs: []int64{1, 2, 1, 
2}, // alternating - wantScores: [][]float32{{1}, {2}, {3}, {4}}, - wantCallCounts: map[string]int{"model1": 2, "model2": 2}, // 2 calls each +func prepareTestRouter(t *testing.T, test predictTestCase) ([]int, *Router, map[string]*mockEvaluator, string, []interface{}) { + tritonServerID := "test_server" + cfg := &config.Model{ + ID: test.name, + Mode: "router", + Platform: "triton", + Router: test.routerConfig, + Triton: &config.TritonConfig{ + ServerID: tritonServerID, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - setup := setupPredictTest(t, predictTestOptions{ - forceBatchSize1: tt.forceBatchSize1, - }) - - // Build input params - stringInputs := make([][]string, len(tt.inputs)) - for i, s := range tt.inputs { - stringInputs[i] = []string{s} - } - routingInputs := make([][]int64, len(tt.routingIDs)) - for i, id := range tt.routingIDs { - routingInputs[i] = []int64{id} - } - - params := []interface{}{stringInputs, routingInputs} - - results, err := setup.router.Predict(context.Background(), params) - if err != nil { - t.Fatalf("Predict error: %v", err) - } - - // Verify scores - scores, ok := results[0].([][]float32) - if !ok { - t.Fatalf("expected [][]float32, got %T", results[0]) - } - if !reflect.DeepEqual(scores, tt.wantScores) { - t.Errorf("scores mismatch:\n got: %v\n want: %v", scores, tt.wantScores) - } + cfg.Init(nil) - // Verify call counts - for modelName, wantCount := range tt.wantCallCounts { - if eval, ok := setup.evaluators[modelName]; ok { - gotCount := eval.getPredictCalls() - if gotCount != wantCount { - t.Errorf("%s call count: got %d, want %d", modelName, gotCount, wantCount) - } - } - } + signature := &domain.Signature{} + for i := range test.stringInputs { + signature.Inputs = append(signature.Inputs, domain.Input{ + Name: strconv.Itoa(i), + Index: i, + Type: reflect.TypeOf(""), }) } -} -func TestRouter_Predict_WithModelOutput(t *testing.T) { - setup := setupPredictTest(t, predictTestOptions{ - modelOutputName: 
"model_id", - }) - - params := []interface{}{ - [][]string{{"a"}, {"bb"}, {"ccc"}, {"dddd"}}, - [][]int64{{1}, {2}, {1}, {2}}, + for i := range test.expectedOutputs { + signature.Outputs = append(signature.Outputs, domain.Output{ + Name: strconv.Itoa(i), + Index: i, + DataType: "float32", + }) } - results, err := setup.router.Predict(context.Background(), params) - if err != nil { - t.Fatalf("Predict error: %v", err) + var routerInputs []int = test.routerInputs + if routerInputs == nil { + sampledInput := test.stringInputs[0] + routerInputs = make([]int, len(sampledInput)) + for j := range len(sampledInput) { + routerInputs[j] = j + } } - if len(results) != 2 { - t.Fatalf("expected 2 outputs, got %d", len(results)) - } + router, err := newRouter(cfg, nil, map[string]UnloadService{ + tritonServerID: &triton.Service{}, + }, nil) - // Verify scores - scores, ok := results[0].([][]float32) - if !ok { - t.Fatalf("expected [][]float32, got %T", results[0]) + if err != nil { + t.Fatalf("NewRouter error: %v", err) } - wantScores := [][]float32{{1}, {2}, {3}, {4}} - assert.Equal(t, wantScores, scores, "scores mismatch") - // Verify model IDs - modelIDs, ok := results[1].([][]string) - if !ok { - t.Fatalf("expected [][]string for model_id, got %T", results[1]) + // manually generate router configuration + + // see how model names are constructed later + model0Name := "model0" + model1Name := "model1" + router.routingMap = map[int]string{ + 0: model0Name, + 1: model1Name, } - wantModelIDs := [][]string{{"model1"}, {"model2"}, {"model1"}, {"model2"}} - assert.Equal(t, wantModelIDs, modelIDs, "model_id mismatch") -} -func TestRouter_Predict_ErrorPropagation(t *testing.T) { - setup := setupPredictTest(t, predictTestOptions{ - modelNames: []string{"model1"}, - }) + routerInputName := cfg.Router.InputName + routerInput := domain.Input{Name: routerInputName, Index: 0, Type: reflect.TypeOf(int64(0))} - // Configure evaluator to return error - setup.evaluators["model1"].err = 
errors.New("model prediction failed") + routerOutputs := make([]domain.Output, len(signature.Outputs)) + copy(routerOutputs, signature.Outputs) - params := []interface{}{ - [][]string{{"a"}, {"bb"}}, - [][]int64{{1}, {1}}, + if test.routerConfig.Output.FieldName != "" { + routerOutputs = append(routerOutputs, domain.Output{ + Name: test.routerConfig.Output.FieldName, + Index: len(routerOutputs), + DataType: "string", + }) } - _, err := setup.router.Predict(context.Background(), params) - if err == nil { - t.Fatal("expected error but got nil") + // initialize ioState with base outputs + router.ioState = &IOState{ + inputs: map[string]*domain.Input{ + routerInputName: &routerInput, + }, + signature: &domain.Signature{ + Inputs: []domain.Input{ + routerInput, + }, + Outputs: routerOutputs, + }, + routerInputOffset: 0, } - if !strings.Contains(err.Error(), "model prediction failed") { - t.Errorf("expected error to contain 'model prediction failed', got: %v", err) + // add inputs + for i, sigInput := range signature.Inputs { + router.ioState.inputs[sigInput.Name] = &signature.Inputs[i] + router.ioState.signature.Inputs = append(router.ioState.signature.Inputs, sigInput) } -} - -// orderVerifyingEvaluator verifies inputs arrive in expected order and computes output -type orderVerifyingEvaluator struct { - t *testing.T - modelName string - signature *domain.Signature - expectedOrder []string // expected input names in order -} -func (e *orderVerifyingEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { - if len(params) != len(e.expectedOrder) { - return nil, fmt.Errorf("expected %d inputs, got %d", len(e.expectedOrder), len(params)) + model1Signature := &domain.Signature{ + Inputs: make([]domain.Input, len(signature.Inputs)), + Outputs: make([]domain.Output, len(signature.Outputs)), } - // Verify we received inputs in the expected order by checking types match signature - for i, param := range params { - expectedName := e.expectedOrder[i] 
- - // Verify the type is a slice (basic sanity check) - paramType := reflect.TypeOf(param) - if paramType.Kind() != reflect.Slice { - return nil, fmt.Errorf("input %d (%s): expected slice, got %v", i, expectedName, paramType) - } - } + copy(model1Signature.Inputs, signature.Inputs) + copy(model1Signature.Outputs, signature.Outputs) - // Compute output: concatenate first input values (assumes string inputs) - var batchSize int - switch typed := params[0].(type) { - case [][]string: - batchSize = len(typed) - case [][]int64: - batchSize = len(typed) - default: - return nil, fmt.Errorf("unexpected first input type: %T", params[0]) + if test.reverseInputs { + slices.Reverse(model1Signature.Inputs) } - // Output: sum of string lengths from input_a + input_b - results := make([][]float32, batchSize) - for i := 0; i < batchSize; i++ { - var sum int - for _, param := range params { - if typed, ok := param.([][]string); ok { - sum += len(typed[i][0]) - } - } - results[i] = []float32{float32(sum)} + if test.reverseOutputs { + slices.Reverse(model1Signature.Outputs) } - return []interface{}{results}, nil -} - -func (e *orderVerifyingEvaluator) Signature() *domain.Signature { return e.signature } -func (e *orderVerifyingEvaluator) Dictionary() *common.Dictionary { return nil } -func (e *orderVerifyingEvaluator) Inputs() map[string]*domain.Input { return nil } -func (e *orderVerifyingEvaluator) Stats(map[string]interface{}) {} -func (e *orderVerifyingEvaluator) Close() error { return nil } -func (e *orderVerifyingEvaluator) ReloadIfNeeded(ctx context.Context) error { return nil } - -// TestRouter_Predict_DifferentInputOrdering verifies that models with different -// input orderings receive their inputs in the correct order -func TestRouter_Predict_DifferentInputOrdering(t *testing.T) { - // Model1 expects: [input_a, input_b] (indices 0, 1) - // Model2 expects: [input_b, input_a] (indices 0, 1) - REVERSED ORDER - model1Sig := &domain.Signature{ - Inputs: []domain.Input{ - 
{Name: "input_a", Index: 0, Type: reflect.TypeOf("")}, - {Name: "input_b", Index: 1, Type: reflect.TypeOf("")}, + mockEvaluators := map[string]*mockEvaluator{ + model0Name: { + signature: func() *domain.Signature { return signature }, + predictor: appendFloatPredict, + modelName: model0Name, }, - Outputs: []domain.Output{ - {Name: "score", Index: 0, DataType: "float32"}, + model1Name: { + signature: func() *domain.Signature { return model1Signature }, + predictor: appendFloatPredict, + modelName: model1Name, }, } - model2Sig := &domain.Signature{ - Inputs: []domain.Input{ - {Name: "input_b", Index: 0, Type: reflect.TypeOf("")}, // REVERSED - {Name: "input_a", Index: 1, Type: reflect.TypeOf("")}, // REVERSED - }, - Outputs: []domain.Output{ - {Name: "score", Index: 0, DataType: "float32"}, - }, + router.routingTable = make(map[string]platform.PlatformEvaluator) + for modelName, evaluator := range mockEvaluators { + router.routingTable[modelName] = evaluator } - model1Eval := &orderVerifyingEvaluator{ - t: t, - modelName: "model1", - signature: model1Sig, - expectedOrder: []string{"input_a", "input_b"}, - } + globalModelName := cfg.Router.Output.GlobalModelOverride + if router.hasGlobalModel { + router.globalModel = &mockEvaluator{ + signature: func() *domain.Signature { return signature }, + predictor: appendFloatPredict, + modelName: globalModelName, + } - model2Eval := &orderVerifyingEvaluator{ - t: t, - modelName: "model2", - signature: model2Sig, - expectedOrder: []string{"input_b", "input_a"}, // expects reversed order + router.routingTable[globalModelName] = router.globalModel } - cfg := &config.Model{ - ID: "router_ordering_test", - Mode: "router", - Platform: "triton", - Router: &config.RouterConfig{ - ConfigURL: "memory://router-config", - InputName: "router_id", - Global: config.GlobalModelConfig{Exists: true}, - }, - Triton: &config.TritonConfig{ServerID: "test_server"}, + // reshape inputs + params := []interface{}{} + + paramRouterInputs := 
make([][]int64, len(routerInputs)) + for pri, ri := range routerInputs { + paramRouterInputs[pri] = []int64{int64(ri)} } - cfg.Init(nil) - cfg.Router.MaxQueueSize = 1000 - cfg.Router.Workers = 10 + params = append(params, paramRouterInputs) - router, err := newRouter(cfg, nil, map[string]UnloadService{ - "test_server": &triton.Service{}, - }, func(modelName string) (platform.PlatformEvaluator, error) { - if modelName == "model1" { - return model1Eval, nil + for _, input := range test.stringInputs { + inputVals := [][]string{} + for _, inputVal := range input { + inputVals = append(inputVals, []string{inputVal}) } - return model2Eval, nil - }) - if err != nil { - t.Fatalf("newRouter error: %v", err) - } - router.routingMap = map[int]string{ - 1: "model1", - 2: "model2", - } - router.routingTable = map[string]platform.PlatformEvaluator{ - "model1": model1Eval, - "model2": model2Eval, + params = append(params, inputVals) } + return routerInputs, router, mockEvaluators, globalModelName, params +} - // Router's signature: [input_a, input_b, router_id] - // This is the order the router receives inputs from the request - routerSig := &domain.Signature{ - Inputs: []domain.Input{ - {Name: "input_a", Index: 0, Type: reflect.TypeOf("")}, - {Name: "input_b", Index: 1, Type: reflect.TypeOf("")}, - {Name: "router_id", Index: 2, Type: reflect.TypeOf(int64(0))}, - }, - Outputs: []domain.Output{ - {Name: "score", Index: 0, DataType: "float32"}, +func TestRouter_Predict_GlobalModel(t *testing.T) { + tests := []predictTestCase{ + { + name: "with_global_model", + routerConfig: makeConfig(configVariant{numOutputs: 1, hasGlobalModel: true, hasOutputName: false}), + stringInputs: [][]string{ + {"1", "2", "3"}, + {"4", "5", "6"}, + }, + expectedOutputs: [][]float32{ + {14, 25, 36}, + }, }, - } - - router.ioState = &IOState{ - inputs: map[string]*domain.Input{ - "input_a": &routerSig.Inputs[0], - "input_b": &routerSig.Inputs[1], - "router_id": &routerSig.Inputs[2], + { + name: 
"without_global_model", + routerConfig: makeConfig(configVariant{numOutputs: 1, hasGlobalModel: false, hasOutputName: false}), + stringInputs: [][]string{ + {"1", "2", "3"}, + {"4", "5", "6"}, + }, + expectedOutputs: [][]float32{ + // third offset should get fixed prediction for output 0 + {14, 25, 0.1}, + }, }, - signature: routerSig, - routerInputOffset: 2, // router_id is at index 2 } - // Input data: - // Row 0: input_a="aa", input_b="bbbb", router_id=1 (-> model1) - // Row 1: input_a="ccc", input_b="dd", router_id=2 (-> model2) - // Row 2: input_a="e", input_b="ffffff", router_id=1 (-> model1) - // Row 3: input_a="gggg", input_b="h", router_id=2 (-> model2) - params := []interface{}{ - [][]string{{"aa"}, {"ccc"}, {"e"}, {"gggg"}}, // input_a - [][]string{{"bbbb"}, {"dd"}, {"ffffff"}, {"h"}}, // input_b - [][]int64{{1}, {2}, {1}, {2}}, // router_id - } - - results, err := router.Predict(context.Background(), params) - if err != nil { - t.Fatalf("Predict error: %v", err) - } - - // Verify results - // Row 0: len("aa") + len("bbbb") = 2 + 4 = 6 - // Row 1: len("ccc") + len("dd") = 3 + 2 = 5 - // Row 2: len("e") + len("ffffff") = 1 + 6 = 7 - // Row 3: len("gggg") + len("h") = 4 + 1 = 5 - scores, ok := results[0].([][]float32) - if !ok { - t.Fatalf("expected [][]float32, got %T", results[0]) - } - - wantScores := [][]float32{{6}, {5}, {7}, {5}} - if !reflect.DeepEqual(scores, wantScores) { - t.Errorf("scores mismatch:\n got: %v\n want: %v", scores, wantScores) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + predictTest(t, test) + }) } } -// multiOutputEvaluator returns multiple outputs in the order specified by its signature -type multiOutputEvaluator struct { - modelName string - signature *domain.Signature +func TestRouter_Predict_ModelName(t *testing.T) { + t.Run("no global model", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{numOutputs: 1, hasGlobalModel: false, hasOutputName: true}), + 
stringInputs: [][]string{ + {"1", "2", "3"}, + {"4", "5", "6"}, + }, + expectedOutputs: [][]float32{ + {14, 25, 0.1}, + }, + }) + }) + + t.Run("with global model", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{numOutputs: 1, hasGlobalModel: true, hasOutputName: true}), + stringInputs: [][]string{ + {"1", "2", "3"}, + {"4", "5", "6"}, + }, + expectedOutputs: [][]float32{ + {14, 25, 36}, + }, + }) + }) } -func (e *multiOutputEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { - // Get batch size from first input - var batchSize int - switch typed := params[0].(type) { - case [][]string: - batchSize = len(typed) - default: - return nil, fmt.Errorf("unexpected input type: %T", params[0]) - } +func TestRouter_Predict_BatchingBehavior(t *testing.T) { + t.Run("batch to model", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{ + numOutputs: 1, + hasGlobalModel: false, + hasOutputName: true, + }), + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + routerInputs: []int{ + 0, 1, 0, 1, 0, 1, + }, + expectedOutputs: [][]float32{ + {11, 22, 33, 44, 55, 66}, + }, + expectCallCounts: map[string]int{ + "model0": 1, + "model1": 1, + }, + }) + }) - // Return outputs in the order defined by this evaluator's signature - // For each row: score_a = len(input), score_b = len(input) * 2 - results := make([]interface{}, len(e.signature.Outputs)) - for outIdx, outDef := range e.signature.Outputs { - outputData := make([][]float32, batchSize) - for i := 0; i < batchSize; i++ { - inputStr := params[0].([][]string)[i][0] - var value float32 - switch outDef.Name { - case "score_a": - value = float32(len(inputStr)) - case "score_b": - value = float32(len(inputStr) * 2) - } - outputData[i] = []float32{value} - } - results[outIdx] = outputData - } + t.Run("force batch size 1", func(t *testing.T) { + predictTest(t, 
predictTestCase{ + routerConfig: makeConfig(configVariant{ + numOutputs: 1, + hasGlobalModel: false, + hasOutputName: true, + forceBatchSize1: true, + }), + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + routerInputs: []int{ + 0, 1, 0, 1, 0, 1, + }, + expectedOutputs: [][]float32{ + {11, 22, 33, 44, 55, 66}, + }, + expectCallCounts: map[string]int{ + "model0": 3, + "model1": 3, + }, + }) + }) - return results, nil + t.Run("batched with no global", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{ + numOutputs: 1, + hasGlobalModel: false, + hasOutputName: true, + }), + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + routerInputs: []int{ + 0, 1, 0, 2, 0, 1, + }, + expectedOutputs: [][]float32{ + {11, 22, 33, 0.1, 55, 66}, + }, + expectCallCounts: map[string]int{ + "model0": 1, + "model1": 1, + }, + }) + }) } -func (e *multiOutputEvaluator) Signature() *domain.Signature { return e.signature } -func (e *multiOutputEvaluator) Dictionary() *common.Dictionary { return nil } -func (e *multiOutputEvaluator) Inputs() map[string]*domain.Input { return nil } -func (e *multiOutputEvaluator) Stats(map[string]interface{}) {} -func (e *multiOutputEvaluator) Close() error { return nil } -func (e *multiOutputEvaluator) ReloadIfNeeded(ctx context.Context) error { return nil } - -// TestRouter_Predict_DifferentOutputOrdering verifies that models with different -// output orderings have their outputs correctly reordered to match the router's signature -func TestRouter_Predict_DifferentOutputOrdering(t *testing.T) { - // Model1 returns: [score_a, score_b] (indices 0, 1) - // Model2 returns: [score_b, score_a] (indices 0, 1) - REVERSED ORDER - model1Sig := &domain.Signature{ - Inputs: []domain.Input{ - {Name: "text", Index: 0, Type: reflect.TypeOf("")}, - }, - Outputs: []domain.Output{ - {Name: "score_a", Index: 0, DataType: "float32"}, 
- {Name: "score_b", Index: 1, DataType: "float32"}, - }, - } - - model2Sig := &domain.Signature{ - Inputs: []domain.Input{ - {Name: "text", Index: 0, Type: reflect.TypeOf("")}, - }, - Outputs: []domain.Output{ - {Name: "score_b", Index: 0, DataType: "float32"}, // REVERSED - {Name: "score_a", Index: 1, DataType: "float32"}, // REVERSED - }, - } +func TestRouter_Predict_SignatureReordering(t *testing.T) { + t.Run("reverse inputs", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{ + numOutputs: 1, + hasGlobalModel: false, + hasOutputName: true, + }), + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + routerInputs: []int{ + 0, 1, 0, 1, 0, 1, + }, + expectedOutputs: [][]float32{ + {11, 22, 33, 44, 55, 66}, + }, + expectCallCounts: map[string]int{ + "model0": 1, + "model1": 1, + }, + reverseInputs: true, + }) + }) - model1Eval := &multiOutputEvaluator{modelName: "model1", signature: model1Sig} - model2Eval := &multiOutputEvaluator{modelName: "model2", signature: model2Sig} + t.Run("reverse outputs", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{ + numOutputs: 1, + hasGlobalModel: false, + hasOutputName: true, + }), + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + routerInputs: []int{ + 0, 1, 0, 1, 0, 1, + }, + expectedOutputs: [][]float32{ + {11, 22, 33, 44, 55, 66}, + }, + expectCallCounts: map[string]int{ + "model0": 1, + "model1": 1, + }, + reverseOutputs: true, + }) + }) - cfg := &config.Model{ - ID: "router_output_ordering_test", - Mode: "router", - Platform: "triton", - Router: &config.RouterConfig{ - ConfigURL: "memory://router-config", - InputName: "router_id", - Global: config.GlobalModelConfig{Exists: true}, - }, - Triton: &config.TritonConfig{ServerID: "test_server"}, - } - cfg.Init(nil) - cfg.Router.MaxQueueSize = 1000 - cfg.Router.Workers = 10 + 
t.Run("reversed fixed evaluator", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{ + numOutputs: 2, + hasGlobalModel: false, + hasOutputName: true, + reverseFPOutputs: true, + }), + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + routerInputs: []int{ + 0, 1, 0, 2, 0, 1, + }, + expectedOutputs: [][]float32{ + {11, 22, 33, 0.1, 55, 66}, + {111, 122, 133, 1.1, 155, 166}, + }, + expectCallCounts: map[string]int{ + "model0": 1, + "model1": 1, + }, + }) + }) - router, err := newRouter(cfg, nil, map[string]UnloadService{ - "test_server": &triton.Service{}, - }, func(modelName string) (platform.PlatformEvaluator, error) { - if modelName == "model1" { - return model1Eval, nil - } - return model2Eval, nil + t.Run("reversed everything", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{ + numOutputs: 2, + hasGlobalModel: false, + hasOutputName: true, + reverseFPOutputs: true, + }), + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + routerInputs: []int{ + 0, 1, 0, 2, 0, 1, + }, + reverseInputs: true, + reverseOutputs: true, + expectedOutputs: [][]float32{ + {11, 22, 33, 0.1, 55, 66}, + {111, 122, 133, 1.1, 155, 166}, + }, + expectCallCounts: map[string]int{ + "model0": 1, + "model1": 1, + }, + }) }) - if err != nil { - t.Fatalf("newRouter error: %v", err) - } +} - router.routingMap = map[int]string{ - 1: "model1", - 2: "model2", - } - router.routingTable = map[string]platform.PlatformEvaluator{ - "model1": model1Eval, - "model2": model2Eval, - } +func TestRouter_Predict_Queuing(t *testing.T) { + rtCfg := makeConfig(configVariant{ + numOutputs: 1, + hasGlobalModel: false, + hasOutputName: true, + }) - // Router's signature: outputs are [score_a, score_b] (this is the canonical order) - routerSig := &domain.Signature{ - Inputs: []domain.Input{ - {Name: "text", Index: 0, Type: 
reflect.TypeOf("")}, - {Name: "router_id", Index: 1, Type: reflect.TypeOf(int64(0))}, - }, - Outputs: []domain.Output{ - {Name: "score_a", Index: 0, DataType: "float32"}, - {Name: "score_b", Index: 1, DataType: "float32"}, - }, - } + rtCfg.MaxQueueSize = 5 + rtCfg.Workers = 1 - router.ioState = &IOState{ - inputs: map[string]*domain.Input{ - "text": &routerSig.Inputs[0], - "router_id": &routerSig.Inputs[1], - }, - signature: routerSig, - routerInputOffset: 1, - } + _, router, mockEvaluators, _, params := prepareTestRouter(t, + predictTestCase{ + routerConfig: rtCfg, + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + // hack to work around how output signature is built + expectedOutputs: [][]float32{ + {1}, + }, + }) - // Input data: - // Row 0: text="aa", router_id=1 (-> model1) => score_a=2, score_b=4 - // Row 1: text="bbb", router_id=2 (-> model2) => score_a=3, score_b=6 - // Row 2: text="c", router_id=1 (-> model1) => score_a=1, score_b=2 - // Row 3: text="dddd", router_id=2 (-> model2) => score_a=4, score_b=8 - params := []interface{}{ - [][]string{{"aa"}, {"bbb"}, {"c"}, {"dddd"}}, // text - [][]int64{{1}, {2}, {1}, {2}}, // router_id - } + doneGroup := &sync.WaitGroup{} + for _, evaluator := range mockEvaluators { + doneGroup.Add(1) - results, err := router.Predict(context.Background(), params) - if err != nil { - t.Fatalf("Predict error: %v", err) + evaluator.waitFor = &sync.WaitGroup{} + evaluator.doneGroup = doneGroup + evaluator.waitFor.Add(1) } - if len(results) != 2 { - t.Fatalf("expected 2 outputs, got %d", len(results)) - } + errCh := make(chan error, 10) - // Verify score_a (output index 0 in router's signature) - scoreA, ok := results[0].([][]float32) - if !ok { - t.Fatalf("expected [][]float32 for score_a, got %T", results[0]) - } - wantScoreA := [][]float32{{2}, {3}, {1}, {4}} - if !reflect.DeepEqual(scoreA, wantScoreA) { - t.Errorf("score_a mismatch:\n got: %v\n want: %v", scoreA, wantScoreA) - 
} + var foundError uint32 + foundErrorLock := &sync.WaitGroup{} + foundErrorLock.Add(1) - // Verify score_b (output index 1 in router's signature) - scoreB, ok := results[1].([][]float32) - if !ok { - t.Fatalf("expected [][]float32 for score_b, got %T", results[1]) - } - wantScoreB := [][]float32{{4}, {6}, {2}, {8}} - if !reflect.DeepEqual(scoreB, wantScoreB) { - t.Errorf("score_b mismatch:\n got: %v\n want: %v", scoreB, wantScoreB) - } -} + ctx := context.Background() + runPredictWG := &sync.WaitGroup{} + for pi := range 10 { + runPredictWG.Add(1) + go func() { + defer runPredictWG.Done() + _, err := router.Predict(ctx, params) + log.Printf("predict %d error: %v", pi, err) + if err != nil { + if atomic.CompareAndSwapUint32(&foundError, 0, 1) { + foundErrorLock.Done() + } -// TestRouter_Predict_FixedEvaluatorOutputOrdering verifies that the fixed evaluator -// correctly reorders its outputs to match the router's signature order, even when -// the PredictionReplacements are in a different order than the model outputs -func TestRouter_Predict_FixedEvaluatorOutputOrdering(t *testing.T) { - // Model signature has outputs: [score_a, score_b] - // But PredictionReplacements are configured in reverse order: [score_b, score_a] - // The fixed evaluator should reorder to match the router signature - - modelSig := &domain.Signature{ - Inputs: []domain.Input{ - {Name: "text", Index: 0, Type: reflect.TypeOf("")}, - }, - Outputs: []domain.Output{ - {Name: "score_a", Index: 0, DataType: "float32"}, - {Name: "score_b", Index: 1, DataType: "float32"}, - }, + errCh <- err + } + }() } - modelEval := &multiOutputEvaluator{modelName: "model1", signature: modelSig} - - // Create config with PredictionReplacements in REVERSE order from model outputs - cfg := &config.Model{ - ID: "router_fixed_ordering_test", - Mode: "router", - Platform: "triton", - Router: &config.RouterConfig{ - ConfigURL: "memory://router-config", - InputName: "router_id", - Global: config.GlobalModelConfig{ - Exists: 
false, // This enables fixed evaluator - PredictionReplacements: []config.PredictionReplacement{ - {Name: "score_b", Type: "float32", Value: 99.0}, // REVERSED ORDER - {Name: "score_a", Type: "float32", Value: 42.0}, // REVERSED ORDER - }, - }, - }, - Triton: &config.TritonConfig{ServerID: "test_server"}, - } - cfg.Init(nil) - cfg.Router.MaxQueueSize = 1000 - cfg.Router.Workers = 10 + unlockedCh := make(chan struct{}, 1) - router, err := newRouter(cfg, nil, map[string]UnloadService{ - "test_server": &triton.Service{}, - }, func(modelName string) (platform.PlatformEvaluator, error) { - return modelEval, nil - }) - if err != nil { - t.Fatalf("newRouter error: %v", err) - } + go func() { + foundErrorLock.Wait() + unlockedCh <- struct{}{} + }() - router.routingMap = map[int]string{ - 1: "model1", - } - router.routingTable = map[string]platform.PlatformEvaluator{ - "model1": modelEval, + dl, ok := t.Deadline() + boundCtx := ctx + if ok { + var cancel context.CancelFunc + boundCtx, cancel = context.WithDeadline(ctx, dl) + defer cancel() } - // Router signature: outputs are [score_a, score_b] - routerSig := &domain.Signature{ - Inputs: []domain.Input{ - {Name: "text", Index: 0, Type: reflect.TypeOf("")}, - {Name: "router_id", Index: 1, Type: reflect.TypeOf(int64(0))}, - }, - Outputs: []domain.Output{ - {Name: "score_a", Index: 0, DataType: "float32"}, - {Name: "score_b", Index: 1, DataType: "float32"}, - }, - } + select { + case <-boundCtx.Done(): + t.Fatalf("test timed out") - router.ioState = &IOState{ - inputs: map[string]*domain.Input{ - "text": &routerSig.Inputs[0], - "router_id": &routerSig.Inputs[1], - }, - signature: routerSig, - routerInputOffset: 1, + case <-unlockedCh: + // positive case } - // Input data: - // Row 0: text="aa", router_id=1 (-> model1) - // Row 1: text="bbb", router_id=999 (-> fixed evaluator, ID not in routingMap) - params := []interface{}{ - [][]string{{"aa"}, {"bbb"}}, - [][]int64{{1}, {999}}, // 999 not in routingMap, uses fixed evaluator + 
for _, evaluator := range mockEvaluators { + // unblock evaluators + evaluator.waitFor.Done() } - results, err := router.Predict(context.Background(), params) - if err != nil { - t.Fatalf("Predict error: %v", err) - } + // wait for evaluators to finish + doneGroup.Wait() + runPredictWG.Wait() - if len(results) != 2 { - t.Fatalf("expected 2 outputs, got %d", len(results)) - } + close(errCh) - // Verify score_a (output index 0 in router's signature) - // Row 0: from model, len("aa") = 2 - // Row 1: from fixed evaluator, should be 42.0 (NOT 99.0) - scoreA, ok := results[0].([][]float32) - if !ok { - t.Fatalf("expected [][]float32 for score_a, got %T", results[0]) - } - wantScoreA := [][]float32{{2}, {42}} - if !reflect.DeepEqual(scoreA, wantScoreA) { - t.Errorf("score_a mismatch:\n got: %v\n want: %v", scoreA, wantScoreA) + foundQueueSizeError := false + for e := range errCh { + if e != nil { + if strings.Contains(e.Error(), queueSizeExceededError) { + foundQueueSizeError = true + } else { + t.Fatalf("Predict error: %v", e) + } + } } - // Verify score_b (output index 1 in router's signature) - // Row 0: from model, len("aa") * 2 = 4 - // Row 1: from fixed evaluator, should be 99.0 (NOT 42.0) - scoreB, ok := results[1].([][]float32) - if !ok { - t.Fatalf("expected [][]float32 for score_b, got %T", results[1]) - } - wantScoreB := [][]float32{{4}, {99}} - if !reflect.DeepEqual(scoreB, wantScoreB) { - t.Errorf("score_b mismatch:\n got: %v\n want: %v", scoreB, wantScoreB) + if !foundQueueSizeError { + t.Fatalf("queue size exceeded not found") } } From b6100a812273e3264107fd51df971ce3975919ca Mon Sep 17 00:00:00 2001 From: David Choi Date: Wed, 4 Feb 2026 14:13:51 -0800 Subject: [PATCH 28/50] Add ARCHITECTURE doc, minor comment updates, removed unused method (technically breaking, but should've been an unused API surface). 
--- ARCHITECTURE.md | 193 +++++++++++++++++++++++++++++++ README.md | 7 +- service/request/request.go | 3 +- service/service.go | 2 +- service/transform/transformer.go | 4 - shared/transfer/input.go | 9 +- 6 files changed, 207 insertions(+), 11 deletions(-) create mode 100644 ARCHITECTURE.md diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 0000000..77a5b19 --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,193 @@ +# Code Architecture + +mly consists of the following large conceptual (not strictly programmatically linked) steps and sub-steps: + +1. Configuration processing +2. Model initialization + 1. Platform evaluator creation + 2. Caching support +3. Request handling + 1. Input processing + 2. Model inference + 3. Post-prediction processing + 4. Prediction logging +4. Model reloading + +--- + +## 1 Configuration processing + +This is a standard step in many services. This will read configuration files and populate defaults. +Details are mainly in [CONFIG.md](CONFIG.md). + +*Quirk*: The configuration struct contains both the read configuration and follow-up processed configuration values. For example, `Modified` is populated during model loading, and `DictMeta` is updated when the dictionary is loaded. + +--- + +## 2 Model Initialization + +Model initialization occurs in `service.New()` which orchestrates the creation of platform-specific evaluators and supporting infrastructure. + +### 2.1 Evaluator Creation + +A core concept in mly is the "Evaluator." +An Evaluator is essentially something that can provide some kind of model inference. +All Evaluators implement `platform.PlatformEvaluator`: + +```go +type PlatformEvaluator interface { + Predict(ctx context.Context, params []interface{}) ([]interface{}, error) + Signature() *domain.Signature + Dictionary() *common.Dictionary + Inputs() map[string]*domain.Input + ReloadIfNeeded(ctx context.Context) error + Close() error +} +``` + +There are currently 3 Evaluators: + +1. 
TensorFlow - this Evaluator operates with a `libtensorflow` backend, and has additional logic that supports timeout-based batching. +2. Triton - this Evaluator supports sending prediction requests to a single Triton server via HTTP or gRPC. +3. Router - this Evaluator does not generate any prediction but enables rows in a prediction request to be sent to other Evaluators based on the input. + +*Design issue*: Evaluator overloading and over-abstraction - the Router operates on the same interface as the TensorFlow and Triton evaluators, but vary in their behavioral labels. + +### 2.2 Caching Support + +Caching is implemented via `shared/datastore.Service`, which provides a multi-layer cache: + +1. **Local in-memory cache** ([`scache`](https://github.com/viant/scache)): Fast local cache with TTL expiration +2. **L1 cache** (Aerospike): Primary distributed cache +3. **L2 cache** (Aerospike): Secondary distributed cache for cache warming + +An important concept in mly caching is the *Dictionary hash*. +This is stored with cached values, and is intended to invalidate entries when the model changes (e.g., there is a model weights update). + +*Quirk*: Client-read, server-write - based on the observation that if a client does not find a cache entry, then the server is unlikely to also find a cache entry, and to skip the latency overhead from a remote cache check, the server does not check for a cache entry. + +*Design debt*: The current client-read, server-write introduces a case when multiple clients concurrently find that a cache entry is missing, and send the same request to potentially the same mly server, causing the same server to run the same prediction multiple times. This should be controlled on the server side, to avoid unnecessary compute. + +*Design debt*: Aerospike coupling - the current implementation depends on Aerospike constructs. + +--- + +## 3 Request Handling + +The mly service occupies most of its lifetime serving this purpose.
Currently, mly is designed to focus around HTTP requests, using HTTP/2. + +Data flow: + +``` +HTTP Request +→ service.Handler.ServeHTTP() +→ service.Service.Do() +→ service.Evaluator.Predict() +→ service.domain.Transformer +→ service.Response +``` + +### 3.1 Input Processing + +The Input processing step is primarily focused around logic of pulling data from an HTTP-compliant, JSON or URL-based payload and pushing it into a Go (and CGo) compatible data structure for model inference. + +This step revolves mainly around the `service/request.Request` struct. + +Key components: +- **Feeds**: `[]interface{}` shaped as `[numInputs]([batchSize][1]T)` for model consumption +- **Input**: `*transfer.Input` for Transformer support + +The `UnmarshalJSONObject()` method implements `gojay.UnmarshalerJSONObject` for high-performance JSON parsing. + +The interaction with *Model inference* involves the `Feeds` field. + +*Quirk*: client batching payload - mly provides a convenience / payload reduction feature that permits payloads to have both inputs with a list of 1 values as well as inputs with a list of batch size of values. The server will expand the payload to fit the expected batch size times inputs matrix for the Evaluators. + +*Quirk*: payload reading order - the JSON payload must have the `batch_size` key existing before other input keys, as that is required to know if the parser should be expecting a list of values or scalar values. + +*Design debt*: `Feeds` type - most of the requests are tracked via input names rather than offsets; the intermediate data form should be a `map[string]interface{}` (or even `map[string][]interface{}` to capture a potential batch layer), and the conversion to an offset-based slice should be isolated to TensorFlow graph related code.
+ +### 3.2 Model Inference + +Model inference is delegated to the platform-specific evaluator via `Predict()`: + +**TensorFlow** (`service/tfmodel`): +- Optional batching via `service/tfmodel/batcher.Service` aggregates concurrent requests +- Direct evaluation via `service/tfmodel/evaluator.Service` runs TensorFlow session +- Semaphore-controlled concurrency prevents overload + +For Triton, [concurrency and timeout-based batching is controlled via Triton](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html). + +**Triton** (`service/triton`): +- HTTP or gRPC call to Triton Inference Server +- Input tensors serialized per Triton protocol +- Timeout-controlled requests + +The router is a control layer that can route to various model inference services. +Each downstream evaluator is responsible for controlling their own lifetime and batching concerns. +In theory, a router can also route to another router, but no such implementation yet exists. + +**Router** (`service/platform/router`): +- Extracts routing key from input +- Groups rows by target model +- Parallel dispatch to downstream evaluators +- Result reassembly in original order + +### 3.3 Post-Prediction Processing + +After inference, `Service.buildResponse()` handles output transformation: + +**Transformer execution** +The configured `domain.Transformer` function transforms raw model output into a `common.Storable` for serialization. +The default transformer extracts values keyed by output tensor names. + +The `domain.Transformer` signature: + +```go +type Transformer func(ctx context.Context, signature *Signature, input *gtly.Object, output interface{}) (common.Storable, error) +``` + +**Cache storage** +If caching is enabled, transformed results are stored asynchronously via `datastore.Put()`. + +*Design debt*: Batch-based Transformer - the current Transformer API operates at the request level outputs but at row-level inputs, and is invoked per row. 
+ +### 3.4 Prediction Logging + +If `Stream` is configured, the `stream.Service` logs requests for analytics: +- Request body +- Model output +- Inference duration + +Logging uses `github.com/viant/tapper` for configurable output destinations. + +--- + +## 4 Model Reloading + +Model reloading runs continuously in a background goroutine (`Service.pollModelReload()`), checking for updates at configurable intervals (`ReloadPollIntervalSeconds`). + +The `ReloadIfNeeded()` implementation is platform-specific, and varies similarly to model prediction in how much is implemented vs. delegated: + +**TensorFlow** (`service/tfmodel.Service`): +1. Check file modification times at `URL` +2. If changed: copy model files to `Location`, load SavedModel +3. Extract signature and dictionary from graph +4. Create new `service/tfmodel/evaluator.Service` and optionally `service/tfmodel/batcher.Service` +5. Atomically swap Evaluators under mutex protection + +**Triton** (`service/triton.TritonEvaluator`): +1. Check model health via `ModelReady()` API +2. If not ready and in EXPLICIT mode: call `ModelLoad()` +3. Refresh metadata via `ModelMetadata()` if signature not yet captured + +**Router** (`service/platform/router.Router`): +1. Check routing configuration file modification +2. Reload routing table if changed +3. Create/destroy downstream Evaluators as needed +4. Atomically swap routing table under mutex protection +5. Unload unused models from Triton via Model Control API + +Reload health is tracked via `Service.ReloadOK` for centralized health reporting. + +*Design issue*: Over-abstraction of `ReloadIfNeeded()` - we note that this is a very high-level abstraction that could be broken down into separate concerns e.g., check health, load model, check if reload needed, etc. 
\ No newline at end of file diff --git a/README.md b/README.md index 7756def..e48f79e 100644 --- a/README.md +++ b/README.md @@ -60,9 +60,9 @@ By default, the client will configure itself using the web service cache setting This enables the `mly` client to handle key generation without additional configuration or code. The library supports 3 types of caching: -- in-(process) memory -- external Aerospike cache -- hybrid +- in-(process) memory using [scache](https://github.com/viant/scache) +- external Aerospike cache (supports L1/L2 tiered caching for larger key spaces) +- hybrid (in-memory + external) The in-memory cache uses [scache](https://github.com/viant/scache)'s most-recently-used implementation. @@ -77,6 +77,7 @@ In this scenario, the L2 cache can be a very large SSD-backed Aerospike instance In this case, when we look for a cached value, first the in-memory cache is checked, followed by L1, then L2. Then with a cache miss, the value is calculated then copied to L2 - then from L2 to L1 and L1 to local memory. + **Example of `config.yaml` with both an in-memory and an Aerospike cache** ```yaml diff --git a/service/request/request.go b/service/request/request.go index ed1f3cf..2225070 100644 --- a/service/request/request.go +++ b/service/request/request.go @@ -27,7 +27,8 @@ type Request struct { supplied map[string]struct{} // used to check if the required inputs were provided - Input *transfer.Input // cache metadata + // Input is primarily used with Transformer. 
+ Input *transfer.Input // type metadata from service/tfservice.Service.inputs // see service/tfmodel.(*Service).reconcileIOFromSignature diff --git a/service/service.go b/service/service.go index 949fcb5..c4be09d 100644 --- a/service/service.go +++ b/service/service.go @@ -45,7 +45,7 @@ type Service struct { // continueOnRecover if false, will re-panic on recover continueOnRecover bool - // TODO how does this interact with Service.inputs + // inputProvider is used in transformer.Transform() inputProvider *gtly.Provider // health status for centralized health reporting diff --git a/service/transform/transformer.go b/service/transform/transformer.go index 4120c4b..8437842 100644 --- a/service/transform/transformer.go +++ b/service/transform/transformer.go @@ -19,7 +19,3 @@ func Get(name string) (domain.Transformer, error) { // otherwise return default transformer return domain.Transform, nil } - -func ExecuteTransform() interface{} { - return nil -} diff --git a/shared/transfer/input.go b/shared/transfer/input.go index 22dfd13..6775061 100644 --- a/shared/transfer/input.go +++ b/shared/transfer/input.go @@ -6,9 +6,14 @@ import ( type Input struct { BatchSize int - Keys Strings + + // cache keys + Keys Strings + Values - Unmapped Values // values that are not part of an input + + // values that are not part of an input + Unmapped Values } func (i *Input) BatchMode() bool { From 2ed59cecdb8c17e0cfe7aba41ec0e14332ddb19b Mon Sep 17 00:00:00 2001 From: David Choi Date: Thu, 12 Feb 2026 15:27:15 -0800 Subject: [PATCH 29/50] Update ARCHITECTURE.md to clarify potential design issues and enhance descriptions of client batching payload features. --- ARCHITECTURE.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 77a5b19..9c34e38 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -51,7 +51,7 @@ There are currently 3 Evaluators: 2. 
Triton - this Evaluator supports sending prediction requests to a single Triton server via HTTP or gRPC. 3. Router - this Evaluator does not generate any prediction but enables rows in a prediction request to be sent to other Evaluators based on the input. -*Design issue*: Evaluator overloading and over-abstraction - the Router operates on the same interface as the TensorFlow and Triton evaluators, but vary in their behavioral labels. +*Potential design issue*: Evaluator overloading and over-abstraction - the Router operates on the same interface as the TensorFlow and Triton evaluators, but vary in their behavioral labels. ### 2.2 Caching Support @@ -101,7 +101,7 @@ The `UnmarshalJSONObject()` method implements `gojay.UnmarshalerJSONObject` for The interaction with *Model inference* involves the `Feeds` field. -*Quirk*: client batching payload - mly provides a convenience / payload reduction feature that permits payloads to have both inputs with a list of 1 values as well as inputs with a list of batch size of values. The server will expand the payload to fit the expected batch size times inputs matrix for the Evaluators. +*Quirk*: client batching payload reduction - mly provides a convenience / payload reduction feature that permits payloads to have both inputs with a list of 1 values as well as inputs with a list of batch size of values. The server will expand the payload to fit the expected batch size times inputs matrix for the Evaluators. *Quirk*: payload reading order - the JSON payload must have the `batch_size` key existing before other input keys, as that is required to know if the parser should be expecting a list of values or scalar values. From 3633eec289c75dab5e9436c28ea3028ea1be290c Mon Sep 17 00:00:00 2001 From: David Choi Date: Fri, 27 Feb 2026 09:43:28 -0800 Subject: [PATCH 30/50] Add debugging binaries to gitignore, comment router. 
--- .gitignore | 4 +++- service/platform/router/reload.go | 9 +++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 8408ccc..1b9dda6 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ .DS_Cache .DS_Store +__debug_bin* + /devdata /vendor @@ -15,4 +17,4 @@ /toolsv2/aerospike/aerospike /toolsv2/smasher/cmd/cmd -/toolsv2/toolsv2 \ No newline at end of file +/toolsv2/toolsv2 diff --git a/service/platform/router/reload.go b/service/platform/router/reload.go index 9c27e28..2b0b9bd 100644 --- a/service/platform/router/reload.go +++ b/service/platform/router/reload.go @@ -294,6 +294,8 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin } sigInputMap := make(map[string]*domain.Input) + + // sigOutputMap is for validating output consistency sigOutputMap := make(map[string]*domain.Output) // we only create ioState on the first reload @@ -313,6 +315,8 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin // accept first available signature as the final signature if finalSignature == nil { srcSig := dsmi.signature + + // copy signature from downstream finalSignature = &domain.Signature{ Inputs: make([]domain.Input, len(srcSig.Inputs), len(srcSig.Inputs)+1), Outputs: make([]domain.Output, len(srcSig.Outputs)), @@ -320,8 +324,8 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin copy(finalSignature.Inputs, srcSig.Inputs) copy(finalSignature.Outputs, srcSig.Outputs) + // add router input inputOffset := len(finalSignature.Inputs) - routerInput := domain.Input{ Name: r.routerInputFieldName, Type: reflect.TypeOf(int64(0)), @@ -332,6 +336,7 @@ func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin finalSignature.Inputs = append(finalSignature.Inputs, routerInput) + // sigInputMap is for Request validation for _, input := range finalSignature.Inputs { sigInputMap[input.Name] = &input } @@ -435,7 +440,7 @@ func (r 
*Router) applyRouterConfig(ctx context.Context, newConfig *router.Routin // TODO this is actually an acceptable case, we can simply ignore fixed evaluator fields that aren't applicable for field := range r.fixedEvaluatorFields { if _, ok := sigOutputMap[field]; !ok { - return fmt.Errorf("fixed evaluator field: %s was not found in the signature outputs", field) + return fmt.Errorf("fixed evaluator field: %s was not found in any model outputs", field) } } From 6cb409ed3f584cf1abbf9dc536ae0382f8760a20 Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 2 Mar 2026 15:14:12 -0800 Subject: [PATCH 31/50] Add some comments. --- service/service.go | 1 + shared/client/service.go | 2 +- shared/field.go | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/service/service.go b/service/service.go index c4be09d..34f248a 100644 --- a/service/service.go +++ b/service/service.go @@ -93,6 +93,7 @@ func (s *Service) Config() *config.Model { } // Signature is invoked after at least 1 successful ReloadIfNeeded(). +// Considered hot path. func (s *Service) Signature() *domain.Signature { return s.evaluator.Signature() } diff --git a/shared/client/service.go b/shared/client/service.go index 80ec461..81a370b 100644 --- a/shared/client/service.go +++ b/shared/client/service.go @@ -577,7 +577,7 @@ func (s *Service) discoverConfig(host *Host, URL string) (*config.Remote, error) cfg := &config.Remote{} err = json.Unmarshal(data, cfg) if err != nil { - return nil, fmt.Errorf("failed to parse load %v, config: %s, %v", URL, data, err) + return nil, fmt.Errorf("failed to parse load %v, config: %s, %v", URL, data, err) } if s.Config.Debug { diff --git a/shared/field.go b/shared/field.go index 66c9e43..e05bf2f 100644 --- a/shared/field.go +++ b/shared/field.go @@ -38,8 +38,8 @@ type ( MetaInput struct { Inputs []*Field - // This is used to order inputs and provide extra caching information to the client. - // All inputs from the model will automatically be added here. 
+ // KeyFields is a method of forcing inputs to be part of the key even if not part of the model input. + // The primary use case of this is when there is a Transformer that depends on an Auxiliary input. KeyFields []string `json:",omitempty" yaml:",omitempty"` // Deprecated: use Field.Auxiliary From 10fb18ae9dd1eab0fcc1a1d2fe518ed77779bee5 Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 2 Mar 2026 15:29:23 -0800 Subject: [PATCH 32/50] Fix leak in reloading. --- service/service.go | 49 ++++++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/service/service.go b/service/service.go index 34f248a..15b3a77 100644 --- a/service/service.go +++ b/service/service.go @@ -451,33 +451,36 @@ func (s *Service) GetHealth() int32 { func (s *Service) pollModelReload() { for range s.reloadPollTicker.C { - ctx, cancel := context.WithTimeout(context.Background(), s.reloadTimeout) - defer cancel() - - stats := sstat.NewValues() - if s.reloadMetric != nil { - onDone := s.reloadMetric.Begin(time.Now()) - defer func() { - onDone(time.Now(), stats.Values()...) - }() - } - - var reloadOK int32 - err := s.evaluator.ReloadIfNeeded(ctx) - if err != nil { - stats.AppendError(err) - log.Printf("[%s reload] failed to reload model:%v", s.config.ID, err) - - reloadOK = 0 - } else { - reloadOK = 1 - } - - atomic.StoreInt32(&s.ReloadOK, reloadOK) if atomic.LoadInt32(&s.closed) != 0 { log.Printf("[%s reload] shutting down, stopping reload loop", s.config.ID) return } + + func() { + ctx, cancel := context.WithTimeout(context.Background(), s.reloadTimeout) + defer cancel() + + stats := sstat.NewValues() + if s.reloadMetric != nil { + onDone := s.reloadMetric.Begin(time.Now()) + defer func() { + onDone(time.Now(), stats.Values()...) 
+ }() + } + + var reloadOK int32 + err := s.evaluator.ReloadIfNeeded(ctx) + if err != nil { + stats.AppendError(err) + log.Printf("[%s reload] failed to reload model:%v", s.config.ID, err) + + reloadOK = 0 + } else { + reloadOK = 1 + } + + atomic.StoreInt32(&s.ReloadOK, reloadOK) + }() } } From 75fa79255f955a118e0364d5a684ea4e88e7666f Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 2 Mar 2026 16:33:32 -0800 Subject: [PATCH 33/50] Add Prometheus Gauge for model health/reload success, some refactors. --- service/endpoint/model.go | 28 +++++++--- service/handler.go | 2 +- service/new.go | 105 ++++++++++++++++++++++++++++++++++++++ service/service.go | 65 +++-------------------- 4 files changed, 136 insertions(+), 64 deletions(-) create mode 100644 service/new.go diff --git a/service/endpoint/model.go b/service/endpoint/model.go index cddb771..b065443 100644 --- a/service/endpoint/model.go +++ b/service/endpoint/model.go @@ -84,16 +84,26 @@ func Build( Namespace: "mly", Subsystem: "model", Name: "idletime", - - Help: "measured time between requests in nanoseconds", - - Buckets: buckets, + Help: "measured time between requests in nanoseconds", + Buckets: buckets, }, []string{"model"}) var err error err = promReg.Register(obsv) if err != nil { - return err + return fmt.Errorf("failed to register idletime histogram: %w", err) + } + + healthGauge := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "mly", + Subsystem: "model", + Name: "reload_success", + Help: "successfully reloaded model", + }, []string{"model"}) + + err = promReg.Register(healthGauge) + if err != nil { + return fmt.Errorf("failed to register health gauge: %w", err) } serviceOpts := make([]service.Option, 0) @@ -129,7 +139,13 @@ func Build( var modelSrv *service.Service var err error - modelSrv, err = service.New(context.Background(), model, fs, metrics, datastores, tritonServices, sema, cfge.MaxEvaluatorWait, serviceOpts...) 
+ modelSrv, err = service.NewV2(context.Background(), model, fs, metrics, service.NewArgs{ + Datastores: datastores, + TritonServices: tritonServices, + Semaphore: sema, + MaxEvaluatorWait: cfge.MaxEvaluatorWait, + HealthGauge: healthGauge, + }, serviceOpts...) if err != nil { return fmt.Errorf("failed to create service for model:%v, err:%w", model.ID, err) diff --git a/service/handler.go b/service/handler.go index 5713ea3..6ed53ac 100644 --- a/service/handler.go +++ b/service/handler.go @@ -76,7 +76,7 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques defer func() { onDone(time.Now(), stats.Values()...) }() if err != nil { - stats.Append(sstat.ReadError{err}) + stats.Append(sstat.ReadError{Error: err}) if isDebug { log.Printf("[%v http] read error: %v\n", h.service.config.ID, err) } diff --git a/service/new.go b/service/new.go new file mode 100644 index 0000000..98f54bf --- /dev/null +++ b/service/new.go @@ -0,0 +1,105 @@ +package service + +import ( + "context" + "fmt" + "reflect" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/viant/afs" + "github.com/viant/gmetric" + "github.com/viant/mly/service/config" + "github.com/viant/mly/service/platform/factory" + "github.com/viant/mly/service/stat" + "github.com/viant/mly/service/triton" + "github.com/viant/mly/shared/datastore" + sstat "github.com/viant/mly/shared/stat" + "golang.org/x/sync/semaphore" +) + +// NewArgs is an open-to-extension approach to keeping the NewV2() API invariant +// Potential tech debt: most likely we should be encapsulating parameters more appropriately. +// Natural encapsulation boundaries will emerge when we start seeing what New() also initializes along with just Service. 
+type NewArgs struct { + Datastores map[string]*datastore.Service + TritonServices map[string]*triton.Service + Semaphore *semaphore.Weighted + MaxEvaluatorWait time.Duration + + HealthGauge *prometheus.GaugeVec +} + +// New creates a service with platform router support +func New( + ctx context.Context, + cfg *config.Model, + fs afs.Service, + metrics *gmetric.Service, + datastores map[string]*datastore.Service, + tritonServices map[string]*triton.Service, + sema *semaphore.Weighted, + maxEvaluatorWait time.Duration, + options ...Option, +) (*Service, error) { + return NewV2(ctx, cfg, fs, metrics, NewArgs{ + Datastores: datastores, + TritonServices: tritonServices, + Semaphore: sema, + MaxEvaluatorWait: maxEvaluatorWait, + }, options...) +} + +// New creates a service with platform router support +func NewV2( + ctx context.Context, + cfg *config.Model, + fs afs.Service, + metrics *gmetric.Service, + args NewArgs, + options ...Option, +) (*Service, error) { + + if metrics == nil { + metrics = gmetric.New() + } + + location := reflect.TypeOf(Service{}).PkgPath() + + cfg.Init(nil) + + // Create platform evaluator context + evaluatorContext, err := factory.CreateEvaluator(cfg, fs, metrics, args.Semaphore, args.MaxEvaluatorWait, args.TritonServices) + if err != nil { + return nil, fmt.Errorf("failed to create platform evaluator for model %s: %w", cfg.ID, err) + } + + srv := &Service{ + config: cfg, + evaluator: evaluatorContext, + useDatastore: cfg.UseDictionary() && cfg.DataStore != "", + serviceMetric: metrics.MultiOperationCounter(location, cfg.ID+"Perf", cfg.ID+" service performance", time.Microsecond, time.Minute, 2, stat.NewProvider()), + reloadPollTicker: time.NewTicker(time.Duration(cfg.ReloadPollIntervalSeconds) * time.Second), + reloadTimeout: time.Duration(cfg.ReloadTimeoutSeconds) * time.Second, + } + + if args.HealthGauge != nil { + srv.healthGauge = args.HealthGauge.With(prometheus.Labels{"model": cfg.ID}) + } + + // Set up reload metrics for platforms 
that support reloading + srv.reloadMetric = metrics.MultiOperationCounter(location, cfg.ID+"Reload", cfg.ID+" reloading", time.Microsecond, time.Minute, 1, sstat.NewCtxErrOnly()) + + for _, opt := range options { + opt.Apply(srv) + } + + err = srv.initializeService(ctx, cfg, fs, metrics, args.Datastores) + if err != nil { + return nil, err + } + + go srv.pollModelReload() + + return srv, err +} diff --git a/service/service.go b/service/service.go index 15b3a77..bc9aaae 100644 --- a/service/service.go +++ b/service/service.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/viant/afs" "github.com/viant/gmetric" "github.com/viant/gtly" @@ -18,19 +19,16 @@ import ( serrs "github.com/viant/mly/service/errors" "github.com/viant/mly/service/gtlyop" "github.com/viant/mly/service/platform" - "github.com/viant/mly/service/platform/factory" "github.com/viant/mly/service/request" "github.com/viant/mly/service/stat" "github.com/viant/mly/service/stream" "github.com/viant/mly/service/transform" - "github.com/viant/mly/service/triton" "github.com/viant/mly/shared" "github.com/viant/mly/shared/common" "github.com/viant/mly/shared/common/storable" "github.com/viant/mly/shared/datastore" sstat "github.com/viant/mly/shared/stat" "github.com/viant/xunsafe" - "golang.org/x/sync/semaphore" ) // Service serves as the entrypoint for using the ML model. 
@@ -49,9 +47,11 @@ type Service struct { inputProvider *gtly.Provider // health status for centralized health reporting - // Deprecated: use GetHealth() instead + // Deprecated: use GetHealth() or healthGauge ReloadOK int32 + healthGauge prometheus.Gauge + reloadPollTicker *time.Ticker reloadTimeout time.Duration @@ -342,59 +342,6 @@ func (s *Service) initializeService(ctx context.Context, cfg *config.Model, fs a return nil } -// New creates a service with platform router support -func New( - ctx context.Context, - cfg *config.Model, - fs afs.Service, - metrics *gmetric.Service, - datastores map[string]*datastore.Service, - tritonServices map[string]*triton.Service, - sema *semaphore.Weighted, - maxEvaluatorWait time.Duration, - options ...Option, -) (*Service, error) { - - if metrics == nil { - metrics = gmetric.New() - } - - location := reflect.TypeOf(Service{}).PkgPath() - - cfg.Init(nil) - - // Create platform evaluator context - evaluatorContext, err := factory.CreateEvaluator(cfg, fs, metrics, sema, maxEvaluatorWait, tritonServices) - if err != nil { - return nil, fmt.Errorf("failed to create platform evaluator for model %s: %w", cfg.ID, err) - } - - srv := &Service{ - config: cfg, - evaluator: evaluatorContext, - useDatastore: cfg.UseDictionary() && cfg.DataStore != "", - serviceMetric: metrics.MultiOperationCounter(location, cfg.ID+"Perf", cfg.ID+" service performance", time.Microsecond, time.Minute, 2, stat.NewProvider()), - reloadPollTicker: time.NewTicker(time.Duration(cfg.ReloadPollIntervalSeconds) * time.Second), - reloadTimeout: time.Duration(cfg.ReloadTimeoutSeconds) * time.Second, - } - - // Set up reload metrics for platforms that support reloading - srv.reloadMetric = metrics.MultiOperationCounter(location, cfg.ID+"Reload", cfg.ID+" reloading", time.Microsecond, time.Minute, 1, sstat.NewCtxErrOnly()) - - for _, opt := range options { - opt.Apply(srv) - } - - err = srv.initializeService(ctx, cfg, fs, metrics, datastores) - if err != nil { - return 
nil, err - } - - go srv.pollModelReload() - - return srv, err -} - // NewRequest should be used for Do() func (s *Service) NewRequest() *request.Request { numKeyInputs := s.config.KeysLen() @@ -480,6 +427,10 @@ func (s *Service) pollModelReload() { reloadOK = 1 } + if s.healthGauge != nil { + s.healthGauge.Set(float64(reloadOK)) + } + atomic.StoreInt32(&s.ReloadOK, reloadOK) }() } From 511e00187250a49b2e38b632f09a87e1a0fbacc1 Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 2 Mar 2026 16:37:17 -0800 Subject: [PATCH 34/50] Also make Prom gauge for reload OK 1 on boot. --- service/service.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/service/service.go b/service/service.go index bc9aaae..3b94133 100644 --- a/service/service.go +++ b/service/service.go @@ -307,6 +307,9 @@ func (s *Service) initializeService(ctx context.Context, cfg *config.Model, fs a } atomic.StoreInt32(&s.ReloadOK, 1) + if s.healthGauge != nil { + s.healthGauge.Set(1) + } signature := s.Signature() if signature == nil { From c10e88bfd6a3290c8c2ba15fe6776efddceba44d Mon Sep 17 00:00:00 2001 From: David Choi Date: Wed, 11 Mar 2026 14:19:54 -0700 Subject: [PATCH 35/50] Update documentation for dictionary hash behavior. --- CONFIG.md | 13 +++++----- README.md | 7 +++--- WORKFLOW.md | 12 ++++----- shared/client/dictionary.go | 24 ++++++++++-------- shared/client/host.go | 16 ++++++------ shared/client/message-spec.yaml | 43 +++++++++++++++++++++++++++++++++ shared/common/dictionary.go | 3 +-- 7 files changed, 84 insertions(+), 34 deletions(-) create mode 100644 shared/client/message-spec.yaml diff --git a/CONFIG.md b/CONFIG.md index 2256d13..8118bc6 100644 --- a/CONFIG.md +++ b/CONFIG.md @@ -19,10 +19,10 @@ Properties: * to use GCS, set environment variable `GOOGLE_APPLICATION_CREDENTIALS=true` - `Location`: `string` - optional - where a copy of the models will be stored when loading the model. Defaults to the system temporary directory. 
- `Dir`: `string` - optional - any further path elements in `Location`. Mainly used if using a ZIP file with additional directories. -- `DataStore`: `string` - optional - name of Datastore to cache, should match `Datastores[].ID`. +- `DataStore`: `string` - optional - name of Datastore to use for caching, should match `Datastores[].ID`. Server-side datastore writes are enabled only when `UseDict` is `true` or unset. - `Transformer`: `string` - optional - name of model output transformer. See [#Transformer](#Transformer). - `Batch`: optional - enables or overrides server-side batching configuration. See [`service/tfmodel/batcher/config/config.go`](service/tfmodel/batcher/config/config.go). -- `UseDict`: `bool` - optional - if true, enables capabilities designed to shrink the cache key space by replacing out-of-vocabulary inputs from cache keys with a special token. +- `UseDict`: `bool` - optional - if true or unset, enables dictionary-based cache behavior, including replacing out-of-vocabulary inputs in cache keys with a special token and allowing the server to generate datastore cache entries when `DataStore` is configured. If false, the server will not generate new datastore cache entries for the model. - `Inputs`: used to further provide or define inputs, a list of `shared.Field`. For TensorFlow models, this is automatically populated, but further caching configurations need to be specified. * `Name`: `string` - required - input name, only required if an entry is provided. * `Index`: `int` - optional - used to maintain cache key ordering. @@ -73,7 +73,7 @@ Can be empty - represent a list of caching data stores. 
Properties: -- `ID`: `string` - required - datastore ID (to be matched with `Models[].DataStores[].ID`) +- `ID`: `string` - required - datastore ID (to be matched with `Models[].DataStore`) - `Connection`: `string` - optional - connection ID - `Namespace`: `string` - optional - Aerospike namespace - `Dataset`: `string` - optional - Aerospike dataset @@ -109,9 +109,10 @@ mly := client.New("$modelID", []*client.Host{client.NewHost("mlServiceHost", mlS ``` Where optional `options` can be of, but not limited to, the following: - * `NewCacheSize(sizeOption)` - * `NewCacheScope(CacheScopeLocal|CacheScopeL1|CacheScopeL2)` - * `NewGmetric()` - custom instance of `gmetric` service + * `WithCacheSize(sizeOption)` + * `WithCacheScope(CacheScopeLocal|CacheScopeL1|CacheScopeL2)` + * `WithGmetrics()` - custom instance of `gmetric` service + * `WithHashValidation(true)` - enables client-side rejection of cached entries with a non-zero hash that differs from the client's current dictionary hash See [`shared/client/option.go`](shared/client/option.go) for more options. diff --git a/README.md b/README.md index e48f79e..6807a15 100644 --- a/README.md +++ b/README.md @@ -105,10 +105,11 @@ See [WORKFLOW.md](WORKFLOW.md) for Mermaid diagrams explaining the Client and mo ## Dictionary hash code In caching mode, in order to manage cache and client/server consistency every time a model/dictionary gets re/loaded, `mly` computes a dictionary hash code. -This hash code gets stored in the cache along with model prediction and is passed to the client in every response. -Once a client detects a change in dictionary hash code, it automatically initiates a dictionary reload and invalidates cache entries. +This hash code gets stored in the cache along with model prediction and is passed to the client in every non-cached response. 
+Once a client detects a change in the dictionary hash code, it will initiate a dictionary reload and, if `client.WithHashValidation(true)` was an option on client initialization, reject any cache entry with a non-zero, different hash code. -Note: The dictionary hash code is stored under a special key in Aerospike defined in `shared/common.HashBin`. To prevent conflicts, do not use that same key name for storing your own model predictions. +**Note** The dictionary hash code is stored under a bin in Aerospike defined in `shared/common.HashBin`. +To prevent conflicts, do not use that same bin name for storing your own model predictions. # Configuration diff --git a/WORKFLOW.md b/WORKFLOW.md index d3fd475..03fc0dc 100644 --- a/WORKFLOW.md +++ b/WORKFLOW.md @@ -51,7 +51,7 @@ sequenceDiagram ```mermaid sequenceDiagram participant client as client.Service - + participant mlyserver as mly Server participant serverds as Server datastore.Service @@ -74,7 +74,7 @@ sequenceDiagram aerospike-->>datastore: KEY_NOT_FOUND_ERROR Note over datastore: L1NoSuchKey - alt L2 is configured + alt L2 is configured datastore->>aerospikel2: Get() aerospikel2-->>datastore: KEY_NOT_FOUND_ERROR Note over datastore: L2NoSuchKey @@ -94,14 +94,14 @@ sequenceDiagram alt mly Prediction Required client->>mlyserver: postRequest() - + activate mlyserver Note over mlyserver: Run TensorFlow model graph - par + par mlyserver->>serverds: Put() serverds->>aerospike: Put() - and + and mlyserver-->>client: response end @@ -143,6 +143,6 @@ sequenceDiagram client->>datastore: Put() datastore->>scache: Put() - + deactivate client ``` \ No newline at end of file diff --git a/shared/client/dictionary.go b/shared/client/dictionary.go index 663c411..10594e6 100644 --- a/shared/client/dictionary.go +++ b/shared/client/dictionary.go @@ -11,6 +11,7 @@ type fieldOffset int const ( // oov = out of vocabulary + // TODO - technically the OOV value can be overwritten - [UNK] may be a valid value oovString = "[UNK]" oovInt 
= 0 @@ -19,13 +20,16 @@ const ( unknownKeyField = fieldOffset(-1) ) -// Dictionary helps identify any out-of-vocabulary input values for reducing the cache space - this enables us to leverage any -// dimensionality reduction within the model to optimize wall-clock performance. This is primarily useful for categorical inputs -// as well as any continous inputs with an acceptable quantization. +// Dictionary helps identify any out-of-vocabulary input values for reducing the cache space, as well as an explicit cache-invalidation strategy via hash. +// See shared/common.Dictionary type Dictionary struct { - hash int + hash int + + // registry key is the input name registry map[string]*entry - inputs map[string]*shared.Field + + // inputs is an index, key is the input name + inputs map[string]*shared.Field } func (d *Dictionary) KeysLen() int { @@ -36,10 +40,6 @@ func (d *Dictionary) inputSize() int { return len(d.inputs) } -func (d *Dictionary) size() int { - return len(d.registry) -} - // TODO refactor, this has a singular use case func (d *Dictionary) Fields() map[string]*shared.Field { return d.inputs @@ -73,12 +73,15 @@ func (d *Dictionary) getEntry(n string) *entry { } if elem == nil { + // generally speaking, if d.registry has data, it should have data for ALL columns + // TODO this shouldn't print, it should tick some counter log.Printf("registry entry was nil for %v", n) } return elem } +// lookupString returns the mapped key, or unknownKeyField, meaning no mapping exists func (d *Dictionary) lookupString(key string, value string) (string, fieldOffset) { input := d.getInput(key) if input == nil { @@ -103,7 +106,7 @@ func (d *Dictionary) lookupString(key string, value string) (string, fieldOffset return oovString, ii } -// TODO integration and boundary testing; OOV may depend on vocabulary +// lookupInt returns the mapped key, or unknownKeyField, meaning no mapping exists func (d *Dictionary) lookupInt(key string, value int) (int, fieldOffset) { input := 
d.getInput(key) if input == nil { @@ -128,6 +131,7 @@ func (d *Dictionary) lookupInt(key string, value int) (int, fieldOffset) { return oovInt, ii } +// reduceFloat returns a lower-precision float key, or unknownKeyField, meaning no reduction exists func (d *Dictionary) reduceFloat(key string, value float32) (float32, int, fieldOffset) { input := d.getInput(key) if input == nil { diff --git a/shared/client/host.go b/shared/client/host.go index 5580db3..7c2f92c 100644 --- a/shared/client/host.go +++ b/shared/client/host.go @@ -13,7 +13,7 @@ import ( var defaultRequestTimeout = 50 * time.Millisecond -//Host represents endpoint host +// Host represents endpoint host type Host struct { name string port int @@ -33,22 +33,24 @@ func isSecurePort(port int) bool { return port == 443 || port == 1443 } -//IsSecurePort() returns true if secure port +// IsSecurePort() returns true if secure port func (h *Host) IsSecurePort() bool { return isSecurePort(h.port) } -//URL returns model eval URL +// URL returns model eval URL func (h *Host) evalURL(model string) string { return h.prefix + fmt.Sprintf(common.ModelURI, model) } -//URL returns meta config model eval URL +// URL returns meta config model eval URL +// See service/endpoint/meta.(*metaHandler).ServeHTTP func (h *Host) metaConfigURL(model string) string { return h.prefix + fmt.Sprintf(common.MetaConfigURI, model) } -//URL returns meta config model eval URL +// URL returns meta config model eval URL +// See service/endpoint/meta.(*metaHandler).ServeHTTP func (h *Host) metaDictionaryURL(model string) string { return h.prefix + fmt.Sprintf(common.MetaDictionaryURI, model) } @@ -93,7 +95,7 @@ func (h *Host) Port() int { return h.port } -//NewHost returns new host +// NewHost returns new host func NewHost(name string, port int) *Host { if port <= 0 { port = 80 @@ -106,7 +108,7 @@ func NewHost(name string, port int) *Host { } } -//NewHosts creates hosts +// NewHosts creates hosts func NewHosts(port int, names []string) []*Host { 
var result = make([]*Host, 0) for _, name := range names { diff --git a/shared/client/message-spec.yaml b/shared/client/message-spec.yaml new file mode 100644 index 0000000..c18b352 --- /dev/null +++ b/shared/client/message-spec.yaml @@ -0,0 +1,43 @@ +"$schema": "https://json-schema.org/draft/2020-12/schema" +$title: Message Specification v1.0 +type: object + +properties: + batch_size: + type: integer + default: 0 + description: | + The number of samples in this request. + Must be the first property present if batch mode is to be enabled, with a value greater than 0. + + cache_key: + oneOf: + - type: string + description: A single cache key for this request. + - type: array + items: + type: string + description: | + An array of cache keys for this request, corresponding to each batched sample. + Must be of length equal to batch_size. + +patternProperties: + ".*": + oneOf: + - oneOf: + - type: string + - type: number + description: A single value for this input. + - type: array + items: + oneOf: + - type: string + - type: number + description: | + A batch of values for input. + Must be of length equal to batch_size, or of length 1. + +examples: + - cache_key: "1234567890" + input1: "value1" + input2: "value2" \ No newline at end of file diff --git a/shared/common/dictionary.go b/shared/common/dictionary.go index 715bf0b..24d13ba 100644 --- a/shared/common/dictionary.go +++ b/shared/common/dictionary.go @@ -28,9 +28,8 @@ type ( } ) -// TODO this should use a fixed size integer? // UpdateHash will memoize dictionary hashing. -// Since wildcard fields don't provide an actual dictionary, we use the modification time information to generate a hash based on the file, passed in as fsHash. +// fsHash provides a base hash value, used for cases when a particular model doesn't have a vocabulary to be hashed. 
func (d *Dictionary) UpdateHash(fsHash int64) int { d.Hash = int(fsHash) From 5e6b360b60d209d1eac5dafe0c8be922eb62002d Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 6 Apr 2026 14:01:01 -0700 Subject: [PATCH 36/50] add client support for outputing intended outputs to a file Co-authored-by: Cursor (Opus-4.6) --- example/client/option.go | 5 +++++ example/client/runner.go | 33 ++++++++++++++++++++++++++------- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/example/client/option.go b/example/client/option.go index 3adea75..2d3f045 100644 --- a/example/client/option.go +++ b/example/client/option.go @@ -48,6 +48,11 @@ type Options struct { // Will generate a final JSON object as its only output to stdout. // stderr may have other output if Debug is true or there are other errors. Report bool `long:"report"` + + // OutputFile redirects result output to a file instead of stdout. + // When set, all model output, reports, metrics, and error history + // are written to this file path. + OutputFile string `short:"o" long:"output" description:"write results to file instead of stdout"` } type C uint8 diff --git a/example/client/runner.go b/example/client/runner.go index 662f46b..55032e6 100644 --- a/example/client/runner.go +++ b/example/client/runner.go @@ -3,6 +3,7 @@ package client import ( "context" "fmt" + "io" "log" "os" "sync" @@ -21,6 +22,14 @@ import ( // Use CustomMakerRegistry with --maker to use a specific entity for Response.Data. 
var CustomMakerRegistry *customMakerRegistry = new(customMakerRegistry) +func dumpTo(w io.Writer, data interface{}) { + text, err := toolbox.AsJSONText(data) + if err != nil { + return + } + fmt.Fprintf(w, "%v\n", text) +} + func RunWithOptions(runOpts *Options) error { runOpts.Init() if err := runOpts.Validate(); err != nil { @@ -31,6 +40,16 @@ func RunWithOptions(runOpts *Options) error { return fmt.Errorf("could not determine model") } + output := io.Writer(os.Stdout) + if runOpts.OutputFile != "" { + f, err := os.Create(runOpts.OutputFile) + if err != nil { + return fmt.Errorf("failed to open output file %s: %w", runOpts.OutputFile, err) + } + defer f.Close() + output = f + } + payloads, err := runOpts.Payloads() if err != nil { return err @@ -135,7 +154,7 @@ func RunWithOptions(runOpts *Options) error { for i, pload := range payloads { rs.WPayloads[i] = WorkerPayload{Payload: pload} rd := &rs.WPayloads[i] - payloadedRunner := makePayloadRunner(cli, pload, runOpts, dataSetter) + payloadedRunner := makePayloadRunner(cli, pload, runOpts, dataSetter, output) fchan <- runContext{ WP: rd, @@ -183,13 +202,13 @@ func RunWithOptions(runOpts *Options) error { report.Metrics = opcs if runOpts.Metrics { - toolbox.Dump(opcs) + dumpTo(output, opcs) } if runOpts.ErrorHistory { tops := cli.ErrorHistory.TopK() for _, t := range tops { - fmt.Printf("%d %s\n", t.Count, string(t.Data)) + fmt.Fprintf(output, "%d %s\n", t.Count, string(t.Data)) } } @@ -199,7 +218,7 @@ func RunWithOptions(runOpts *Options) error { return fmt.Errorf("failed to gather prometheus metrics: %w", err) } - encoder := expfmt.NewEncoder(os.Stdout, expfmt.FmtText) + encoder := expfmt.NewEncoder(output, expfmt.FmtText) for _, mf := range mfs { if err := encoder.Encode(mf); err != nil { return fmt.Errorf("failed to encode metric family %s: %w", mf.GetName(), err) @@ -208,7 +227,7 @@ func RunWithOptions(runOpts *Options) error { } if runOpts.Report { - toolbox.Dump(report) + dumpTo(output, report) } return err 
@@ -254,7 +273,7 @@ func worker(worker int, echan chan error, fchan chan runContext, closed chan str } func makePayloadRunner(cli *client.Service, pl *CliPayload, runOpts *Options, - builder func(int) func() interface{}) func() (*client.Response, error) { + builder func(int) func() interface{}, output io.Writer) func() (*client.Response, error) { maker := builder(pl.Batch) @@ -284,7 +303,7 @@ func makePayloadRunner(cli *client.Service, pl *CliPayload, runOpts *Options, } if !runOpts.NoOutput { - toolbox.Dump(response) + dumpTo(output, response) } return response, nil From c5ad92dea7c856bcfa3a530f798570cbbded4a15 Mon Sep 17 00:00:00 2001 From: David Choi Date: Fri, 24 Apr 2026 15:40:10 -0700 Subject: [PATCH 37/50] fix(server handler): no longer swallow Marshal error --- service/handler.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/service/handler.go b/service/handler.go index 6ed53ac..c621815 100644 --- a/service/handler.go +++ b/service/handler.go @@ -189,13 +189,23 @@ func (h *Handler) handleAppRequest(ctx context.Context, writer io.Writer, reques } func (h *Handler) writeResponse(writer io.Writer, appResponse *Response) error { - appResponse.ServiceTimeMcs = int(time.Now().Sub(appResponse.started).Microseconds()) + appResponse.ServiceTimeMcs = int(time.Since(appResponse.started).Microseconds()) data, err := gojay.Marshal(appResponse) if h.service.config.Debug { log.Printf("[%v write] output:%s", h.service.config.ID, data) } + + if err != nil { + return fmt.Errorf("failed to marshal: %w", err) + } + _, err = writer.Write(data) - return err + + if err != nil { + return fmt.Errorf("failed to write: %w", err) + } + + return nil } func (h *Handler) trackIdle() { From 3d04cb322de570884c644827a0f9a105eac66e5e Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 27 Apr 2026 14:27:55 -0700 Subject: [PATCH 38/50] Surface io.ReadAll error on 200 OK in client httpPost Previously, a partial or aborted body read on a 200 OK response was 
silently discarded and httpPost returned (possibly-empty data, nil). Callers then unmarshaled empty bytes and the failure surfaced downstream as "Invalid JSON, wrong char ' ' found at position 0", which is indistinguishable from a server-side encoding bug. Return the read error wrapped with the partial byte count so the caller sees an honest transport failure instead of a silent empty-body success. Made-with: Cursor Co-authored-by: Claude --- shared/client/service.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/shared/client/service.go b/shared/client/service.go index 1b72425..99a7b88 100644 --- a/shared/client/service.go +++ b/shared/client/service.go @@ -664,6 +664,14 @@ func (s *Service) httpPost(ctx context.Context, data []byte, host *Host) ([]byte response.StatusCode, string(data), response.Body == nil, err) } + if err != nil { + // 200 OK with a partial / aborted body read is not a success. + // Surfacing this prevents callers from silently unmarshaling an empty body + // (observed downstream as "Invalid JSON, wrong char ' ' found at position 0"). + return nil, fmt.Errorf("HTTP Code:%d, partial body read: %w (got %d bytes)", + response.StatusCode, err, len(data)) + } + return data, nil }() From 96786634d71c35353eb917d5dc2f0dd03132158f Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 27 Apr 2026 16:59:13 -0700 Subject: [PATCH 39/50] refactor(service/handler): explicit-commit response writing with dedicated metrics Replace the implicit-commit pattern in writeResponse where the first writer.Write call auto-fired WriteHeader(200) at the same time as the body bytes. Under load this produced a wire trace of "200 OK + empty body" when the body Write failed (broken pipe due to client cancellation), and the server-side error was silently discarded by the surrounding code. writeResponse now: - Marshals first; on marshal failure returns a typed responseMarshalError so ServeHTTP can emit an explicit 5xx with a meaningful body. 
- Sets Content-Type and Content-Length explicitly so clients can detect a truncated body via io.ErrUnexpectedEOF instead of receiving a silent empty 200 OK. - Calls WriteHeader(200) explicitly so the status line is committed in a known order, not as a side effect of the first Write. - Returns a typed responseCommittedError when Write fails after the status was committed, so ServeHTTP knows to log unconditionally rather than calling http.Error (which would silently drop the new status code with a "superfluous WriteHeader" warning). Inline the single-caller handleAppRequest helper into ServeHTTP -- the io.Writer narrowing it provided is no longer useful now that writeResponse takes an http.ResponseWriter directly. Add a new gmetric counter provider in service/stat (NewHandler) that extends NewCtxErrOnly with two response-write failure classes (responseMarshalError, responseCommittedError) so they can be alerted independently of the generic error bucket. The first three keys (error, canceled, deadlineExceeded) preserve their indices so existing Prometheus dashboards and alerts continue to emit at the same labels. Add unit tests in service/handler_test.go covering the success path, total and partial post-commit Write failures, the Content-Length invariant across a range of Response shapes, and the typed-error chain participation in errors.As / errors.Is. Document the response wire contract in README under /v1/api/model/%s/eval (Content-Type and Content-Length explicit) and the new HTTPHandler metric keys under /v1/api/metric/operations. 
Made-with: Cursor Co-authored-by: Claude --- README.md | 3 +- service/handler.go | 108 ++++++++++++++++---- service/handler_test.go | 217 ++++++++++++++++++++++++++++++++++++++++ service/stat/handler.go | 106 ++++++++++++++++++++ 4 files changed, 413 insertions(+), 21 deletions(-) create mode 100644 service/handler_test.go create mode 100644 service/stat/handler.go diff --git a/README.md b/README.md index 94ae4c8..55d4b41 100644 --- a/README.md +++ b/README.md @@ -226,6 +226,7 @@ In all these, `%s` is `Model[].ID` (i.e. from `config.yaml`) - `/v1/api/metric/operation/%sDictMeta` - Records metrics to client dictionary fetch. - `/v1/api/metric/operation/%sCfgMeta` - Records metrics to client configuration fetch. - `/v1/api/metric/operation/%sMetaHandler` - Records server-side metrics to client set up. +- `/v1/api/metric/operation/%sHTTPHandler` - Records per-request HTTP handler metrics. Includes `responseMarshalError` (response struct could not be marshaled; server returned 500) and `responseCommittedError` (body write failed after status was committed; client sees `200 OK` with a body shorter than `Content-Length`). ## `/v1/api/debug` @@ -238,7 +239,7 @@ Model operations. In all these, `%s` is `Model[].ID` (i.e. from `config.yaml`) -- `/v1/api/model/%s/eval` - runs `GET` / `POST` model prediction. +- `/v1/api/model/%s/eval` - runs `GET` / `POST` model prediction. Successful responses set `Content-Type: application/json` and an explicit `Content-Length`; a short read against the declared length indicates a transport failure, not an empty payload. 
- `/v1/api/model/%s/meta/config` - provides configuration for client related to model - `/v1/api/model/%s/meta/dictionary` - provides current dictionary diff --git a/service/handler.go b/service/handler.go index 5713ea3..5a9baf3 100644 --- a/service/handler.go +++ b/service/handler.go @@ -5,10 +5,10 @@ import ( "encoding/json" "errors" "fmt" - "io" "log" "net/http" "reflect" + "strconv" "strings" "sync" "time" @@ -25,6 +25,29 @@ import ( "github.com/viant/mly/shared/stat" ) +// responseMarshalError signals that gojay.Marshal of the Response struct +// failed during writeResponse. The HTTP response is NOT yet committed, +// so ServeHTTP can still emit an explicit 5xx with a meaningful body. +// Surfaced as a typed error so it can be routed to its own metric bucket +// (sstat.ResponseMarshalError) and distinguished from upstream errors. +type responseMarshalError struct{ err error } + +func (e *responseMarshalError) Error() string { return e.err.Error() } +func (e *responseMarshalError) Unwrap() error { return e.err } + +// responseCommittedError signals that the HTTP response status line and +// headers have already been flushed to the client when the wrapped error +// occurred. The caller MUST NOT attempt to send a different status code: +// net/http will drop the second WriteHeader and emit a "superfluous +// response.WriteHeader call" warning, while the client still observes the +// original (200) status. Surfaced so ServeHTTP can log + exit instead of +// trying to overwrite the status line, and so the failure can be routed +// to its own metric bucket (sstat.ResponseCommittedError). +type responseCommittedError struct{ err error } + +func (e *responseCommittedError) Error() string { return e.err.Error() } +func (e *responseCommittedError) Unwrap() error { return e.err } + // Handler converts a model prediction HTTP request to its internal calls. 
type Handler struct { maxDuration time.Duration @@ -76,7 +99,7 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques defer func() { onDone(time.Now(), stats.Values()...) }() if err != nil { - stats.Append(sstat.ReadError{err}) + stats.Append(sstat.ReadError{Error: err}) if isDebug { log.Printf("[%v http] read error: %v\n", h.service.config.ID, err) } @@ -101,7 +124,7 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques err = gojay.Unmarshal(data[:size], request) if err != nil { werr := fmt.Errorf("unmarshal error: %w data: %s", err, string(data[:size])) - stats.Append(sstat.UnmarshalError{werr}) + stats.Append(sstat.UnmarshalError{Error: werr}) if isDebug { log.Printf("[%v http] unmarshal error: %v\n", h.service.config.ID, err) @@ -127,7 +150,13 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques return } - err := h.handleAppRequest(ctx, writer, request, response) + err := h.service.Do(ctx, request, response) + if err != nil { + response.SetError(err) + } else { + err = h.writeResponse(writer, response) + } + if isDebug { data, merr := json.Marshal(response.Data) @@ -144,6 +173,29 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques } if err != nil { + // If the response was already committed (status + headers flushed), + // the wire status code is fixed at 200 and cannot be changed. Calling + // http.Error here would log "superfluous WriteHeader" and silently + // drop the new status — the bidder still sees 200 + truncated body. + // Log unconditionally so this defect is visible in production, and + // emit a dedicated metric so it can be alerted independently of + // the generic ErrorKey bucket. 
+ var committed *responseCommittedError + if errors.As(err, &committed) { + hStats.Append(sstat.ResponseCommittedError{Error: err}) + log.Printf("[%v http] response committed but write failed: %v", h.service.config.ID, err) + return + } + + // Marshal failure: response NOT committed; we will emit an explicit + // 5xx below. Track it in its own metric bucket so the operator can + // distinguish "we never sent anything" from "we sent something we + // shouldn't have". + var marshal *responseMarshalError + if errors.As(err, &marshal) { + hStats.Append(sstat.ResponseMarshalError{Error: err}) + } + var status int if _, ok := err.(*clienterr.ClientError); ok { status = http.StatusBadRequest @@ -175,27 +227,43 @@ func (h *Handler) buildRequestFromQuery(httpRequest *http.Request, request *requ return nil } -func (h *Handler) handleAppRequest(ctx context.Context, writer io.Writer, request *request.Request, response *Response) error { - if err := h.service.Do(ctx, request, response); err != nil { - response.SetError(err) - return err - } +// writeResponse marshals appResponse and emits it with explicit-commit +// semantics: +// +// - Marshal first; on failure return a plain error — the response is NOT +// yet committed and ServeHTTP can still set a 5xx status. +// - Set Content-Length explicitly so a truncated body is detectable on +// the client side as io.ErrUnexpectedEOF (without it, the client cannot +// distinguish "done" from "connection broke mid-body" on a 200 OK). +// - Call WriteHeader(200) explicitly so the status line is committed in a +// known order, not as a side effect of the first Write. +// - On Write failure return responseCommittedError so the caller knows +// the status code can no longer be changed. +// +// This addresses the silent "200 OK + empty body" failure mode where a +// canceled connection caused the implicit auto-200 from Write to flush +// headers while the body bytes were lost. 
+func (h *Handler) writeResponse(writer http.ResponseWriter, appResponse *Response) error { + appResponse.ServiceTimeMcs = int(time.Since(appResponse.started).Microseconds()) - if err := h.writeResponse(writer, response); err != nil { - return err + data, err := gojay.Marshal(appResponse) + if err != nil { + return &responseMarshalError{err: fmt.Errorf("marshal response: %w", err)} } - return nil -} - -func (h *Handler) writeResponse(writer io.Writer, appResponse *Response) error { - appResponse.ServiceTimeMcs = int(time.Now().Sub(appResponse.started).Microseconds()) - data, err := gojay.Marshal(appResponse) if h.service.config.Debug { log.Printf("[%v write] output:%s", h.service.config.ID, data) } - _, err = writer.Write(data) - return err + + writer.Header().Set("Content-Type", "application/json") + writer.Header().Set("Content-Length", strconv.Itoa(len(data))) + writer.WriteHeader(http.StatusOK) + + if _, err := writer.Write(data); err != nil { + return &responseCommittedError{err: fmt.Errorf("write response body: %w", err)} + } + + return nil } func (h *Handler) trackIdle() { @@ -224,6 +292,6 @@ func NewHandler(service *Service, pool *buffer.Pool, maxDuration time.Duration, lrObserver: lrOV.With(prometheus.Labels{"model": modelID}), overheadMetrics: m.MultiOperationCounter(location, modelID+"HTTPOverhead", modelID+" server HTTP startup overhead", time.Microsecond, time.Minute, 2, sstat.NewHttp()), - httpContextMetrics: m.MultiOperationCounter(location, modelID+"HTTPHandler", modelID+" server HTTP handler", time.Microsecond, time.Minute, 2, stat.NewCtxErrOnly()), + httpContextMetrics: m.MultiOperationCounter(location, modelID+"HTTPHandler", modelID+" server HTTP handler", time.Microsecond, time.Minute, 2, sstat.NewHandler()), } } diff --git a/service/handler_test.go b/service/handler_test.go new file mode 100644 index 0000000..bf7a171 --- /dev/null +++ b/service/handler_test.go @@ -0,0 +1,217 @@ +package service + +import ( + "errors" + "net/http" + 
"net/http/httptest" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/mly/service/config" +) + +// writeFailingResponseWriter wraps httptest.ResponseRecorder so that +// Write returns a configurable error after a configurable number of +// bytes. Used to simulate a broken-pipe condition where the client has +// already closed the connection (e.g. the bidder's 40 ms client timeout +// fired while MLY was mid-response). httptest.ResponseRecorder by itself +// never errors on Write. +// +// failAfter == 0 → the very first Write call errors after writing zero +// body bytes (the headers have still been committed at that point by +// writer.WriteHeader, which is the precise condition that produced the +// observed 200-OK-with-empty-body wire trace). +type writeFailingResponseWriter struct { + *httptest.ResponseRecorder + failAfter int // bytes written successfully before Write starts erroring + written int // total successful body bytes so far + failErr error +} + +func newWriteFailingResponseWriter(failAfter int) *writeFailingResponseWriter { + return &writeFailingResponseWriter{ + ResponseRecorder: httptest.NewRecorder(), + failAfter: failAfter, + failErr: errors.New("simulated broken pipe"), + } +} + +func (w *writeFailingResponseWriter) Write(p []byte) (int, error) { + remaining := w.failAfter - w.written + if remaining <= 0 { + return 0, w.failErr + } + if len(p) <= remaining { + n, err := w.ResponseRecorder.Write(p) + w.written += n + return n, err + } + n, _ := w.ResponseRecorder.Write(p[:remaining]) + w.written += n + return n, w.failErr +} + +// newTestHandler constructs a Handler with the minimum scaffolding +// required to exercise writeResponse. The metric Operations are left +// nil — writeResponse does not touch them, only ServeHTTP does. 
Tests +// that exercise ServeHTTP need a different fixture (not provided here +// because Service.Do depends on a fully wired tfmodel.Service which is +// out of scope for a unit test). +func newTestHandler(modelID string, debug bool) *Handler { + return &Handler{ + service: &Service{ + config: &config.Model{ID: modelID, Debug: debug}, + }, + } +} + +// TestWriteResponse_Success verifies the happy path: a fully populated +// Response is marshaled, headers are set explicitly (Content-Type, +// Content-Length), the status is committed at 200, and the body bytes +// match the marshaled JSON. This locks in the explicit-commit contract +// that lets clients detect truncation via Content-Length mismatch. +func TestWriteResponse_Success(t *testing.T) { + h := newTestHandler("test", false) + resp := &Response{ + Status: "ok", + DictHash: 42, + started: time.Now().Add(-time.Millisecond), + } + + rec := httptest.NewRecorder() + err := h.writeResponse(rec, resp) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "application/json", rec.Header().Get("Content-Type"), + "Content-Type must be set explicitly") + + cl, atoiErr := strconv.Atoi(rec.Header().Get("Content-Length")) + require.NoError(t, atoiErr, "Content-Length must be a parseable integer") + assert.Equal(t, rec.Body.Len(), cl, + "Content-Length must match actual body length so clients can detect truncation") + + body := rec.Body.String() + assert.Contains(t, body, `"status":"ok"`) + assert.Contains(t, body, `"dictHash":42`) + assert.Contains(t, body, `"serviceTimeMcs":`, + "serviceTimeMcs must be present so the bidder can record mly_eval_duration_us") +} + +// TestWriteResponse_WriteFailureReturnsCommittedError simulates the +// failure mode that drives the bidder's invalid_json class on the wire: +// the body Write fails (broken pipe) AFTER WriteHeader has already +// committed the 200 status line. 
The post-condition is that: +// - writeResponse returns *responseCommittedError so the caller knows +// the status code can no longer be changed, +// - the status code on the wire is the originally-committed 200 (NOT +// the 500 we would otherwise want to send), +// - Content-Length was set, so a downstream client correctly checking +// it would observe an early-EOF / unexpected-EOF condition rather +// than silently treating the empty body as a valid response. +func TestWriteResponse_WriteFailureReturnsCommittedError(t *testing.T) { + h := newTestHandler("test", false) + resp := &Response{Status: "ok", started: time.Now()} + + rec := newWriteFailingResponseWriter(0) + err := h.writeResponse(rec, resp) + + require.Error(t, err) + + var committed *responseCommittedError + require.True(t, errors.As(err, &committed), + "expected *responseCommittedError, got %T: %v", err, err) + assert.ErrorIs(t, err, rec.failErr, + "wrapped error chain must reach the underlying broken-pipe error") + + assert.Equal(t, http.StatusOK, rec.Code, + "status was committed before Write failed; explicit-commit contract") + assert.NotEmpty(t, rec.Header().Get("Content-Length"), + "Content-Length must be set BEFORE Write so client can detect truncation") + assert.Equal(t, 0, rec.Body.Len(), + "no body bytes should have been written on the failAfter=0 case") +} + +// TestWriteResponse_PartialWriteReturnsCommittedError covers the +// truncated-body case: the headers + status flush, then a few body +// bytes succeed, then the connection breaks. The committed-error type +// must still surface so ServeHTTP's error branch knows not to call +// http.Error (which would emit "superfluous WriteHeader" log noise). 
+func TestWriteResponse_PartialWriteReturnsCommittedError(t *testing.T) { + h := newTestHandler("test", false) + resp := &Response{Status: "ok", started: time.Now()} + + rec := newWriteFailingResponseWriter(5) + err := h.writeResponse(rec, resp) + + require.Error(t, err) + var committed *responseCommittedError + require.True(t, errors.As(err, &committed), + "partial write must also yield *responseCommittedError") + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, 5, rec.Body.Len(), + "exactly 5 body bytes should have been flushed before failure") +} + +// TestWriteResponse_HasContentLengthMatchingBody locks in the invariant +// that Content-Length declared in the header equals the bytes the +// handler intends to write. Without this, a client cannot distinguish +// "done" from "connection broke mid-body" on a 200 OK response — which +// is the root mechanism that allowed the bidder-side io.ReadAll swallow +// (shared/client/service.go) to silently produce empty-body +// invalid_json events. +func TestWriteResponse_HasContentLengthMatchingBody(t *testing.T) { + h := newTestHandler("test", false) + + cases := []struct { + name string + resp *Response + }{ + {"empty", &Response{started: time.Now()}}, + {"with-status", &Response{Status: "ok", started: time.Now()}}, + {"with-error", &Response{Status: "error", Error: "something failed", started: time.Now()}}, + {"with-dict-hash", &Response{Status: "ok", DictHash: 0xdeadbeef, started: time.Now()}}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + rec := httptest.NewRecorder() + require.NoError(t, h.writeResponse(rec, tc.resp)) + + cl, atoiErr := strconv.Atoi(rec.Header().Get("Content-Length")) + require.NoError(t, atoiErr) + assert.Equal(t, rec.Body.Len(), cl, + "declared Content-Length must equal actual body bytes") + }) + } +} + +// TestResponseCommittedError_Unwrap verifies that the typed error +// participates correctly in errors.Is / errors.As chains. 
ServeHTTP +// relies on errors.As to detect committed-error from arbitrarily-deep +// wrappings. +func TestResponseCommittedError_Unwrap(t *testing.T) { + inner := errors.New("underlying broken pipe") + wrapped := &responseCommittedError{err: inner} + + var target *responseCommittedError + assert.True(t, errors.As(wrapped, &target)) + assert.True(t, errors.Is(wrapped, inner), + "errors.Is must traverse Unwrap to the underlying cause") +} + +// TestResponseMarshalError_Unwrap is the symmetric assertion for the +// marshal-failure sentinel, used by ServeHTTP to route marshal failures +// to their own metric bucket while still emitting an HTTP 5xx response. +func TestResponseMarshalError_Unwrap(t *testing.T) { + inner := errors.New("malformed Response struct") + wrapped := &responseMarshalError{err: inner} + + var target *responseMarshalError + assert.True(t, errors.As(wrapped, &target)) + assert.True(t, errors.Is(wrapped, inner)) +} diff --git a/service/stat/handler.go b/service/stat/handler.go new file mode 100644 index 0000000..e62e98a --- /dev/null +++ b/service/stat/handler.go @@ -0,0 +1,106 @@ +package stat + +import ( + "github.com/viant/gmetric/counter" + "github.com/viant/mly/shared/stat" +) + +const ( + // ResponseMarshalErrorKey counts requests where the prediction succeeded + // but gojay.Marshal of the Response struct failed. The HTTP response was + // NOT committed: ServeHTTP recovers by emitting an explicit 500. + ResponseMarshalErrorKey = "responseMarshalError" + + // ResponseCommittedErrorKey counts requests where status + headers were + // already flushed to the client (200 OK) when the body Write failed. + // This is the server-side counterpart to the bidder-observed + // "200 OK + empty/truncated body → invalid_json" failure mode. 
+ // A non-zero rate here indicates either: + // - clients are closing the connection mid-response (most common + // under load when client deadline < server response time), or + // - HTTP/1.1 keepalive desync producing broken pipes on reuse. + // Distinct from ErrorKey so it can be alerted independently. + ResponseCommittedErrorKey = "responseCommittedError" +) + +// ResponseMarshalError is a stat marker for the gmetric provider. The +// embedded error is retained for top-K error sampling; the struct itself +// is intentionally NOT an `error` so the type-switch in Map can route it +// to its own bucket without colliding with the generic error case. +type ResponseMarshalError struct{ Error error } + +// String implements fmt.Stringer (used by gmetric top-K error sampling). +func (r ResponseMarshalError) String() string { return r.Error.Error() } + +// Aggregate implements github.com/viant/gmetric/counter.CustomCounter. +func (r ResponseMarshalError) Aggregate(interface{}) {} + +// ResponseCommittedError is the analogous stat marker for post-commit +// write failures. See ResponseCommittedErrorKey for the operational +// significance. +type ResponseCommittedError struct{ Error error } + +func (r ResponseCommittedError) String() string { return r.Error.Error() } +func (r ResponseCommittedError) Aggregate(interface{}) {} + +// handler is the gmetric counter.Provider for service.Handler.ServeHTTP. +// It is a strict superset of shared/stat.NewCtxErrOnly(): the first three +// keys (ErrorKey, Canceled, DeadlineExceeded) preserve their indices so +// any existing Prometheus dashboards/alerts on those buckets continue to +// emit at the same labels. Two new keys are appended for the explicit +// response-write failure classes introduced by the explicit-commit +// refactor of writeResponse. +type handler struct{} + +// Keys returns the stat key labels in stable index order. 
Order matters: +// gmetric's counter buckets are addressed by index, and changing the +// order of the first three would silently re-label existing series. +func (h handler) Keys() []string { + return []string{ + stat.ErrorKey, // 0 + stat.Canceled, // 1 + stat.DeadlineExceeded, // 2 + ResponseMarshalErrorKey, // 3 + ResponseCommittedErrorKey, // 4 + } +} + +// Map routes a value to its key index. Concrete struct cases come BEFORE +// the generic `error` case to ensure typed stat markers route to their +// dedicated buckets even if a future change makes them satisfy `error`. +func (h handler) Map(value interface{}) int { + if value == nil { + return -1 + } + + if _, ok := value.(ResponseMarshalError); ok { + return 3 + } + if _, ok := value.(ResponseCommittedError); ok { + return 4 + } + + switch v := value.(type) { + case error: + return 0 + case string: + switch v { + case stat.Canceled: + return 1 + case stat.DeadlineExceeded: + return 2 + case ResponseMarshalErrorKey: + return 3 + case ResponseCommittedErrorKey: + return 4 + } + } + + return -1 +} + +// NewHandler returns the counter.Provider used by Handler.ServeHTTP's +// per-request httpContextMetrics. +func NewHandler() counter.Provider { + return handler{} +} From 7a77e4544ad1a185caf585a2b113a7bd098d331f Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 27 Apr 2026 17:08:36 -0700 Subject: [PATCH 40/50] docs(readme): restructure server endpoints with dedicated eval section The /v1/api/model/%s/eval endpoint is the only one with a structured request/response shape; collapse the previous flat bullet list under /v1/api/model and promote each model endpoint to a peer top-level section so eval can grow sub-sections (Methods, Request body, Successful response, Error response, Example) without distorting the surrounding density. Document the current contract: - Methods (GET / POST) with single-mode vs batch-mode semantics. - Request body shape (input keys + reserved batch_size / cache_key). 
- Successful response shape (status, dictHash, data, serviceTimeMcs) with Content-Type and Content-Length set explicitly so clients can detect transport failures via short-read against declared length. - Error response status-code table (400 / 413 / 429 / 500) with a note that the body format is plain text today and is expected to align with the success response shape in a future release. - Curl examples for both methods. Promote the meta endpoints to peer ## sections with one-line descriptions of their purpose for the mly client bootstrap path. Made-with: Cursor Co-authored-by: Claude --- README.md | 85 +++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 79 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 55d4b41..2eb3ab0 100644 --- a/README.md +++ b/README.md @@ -233,15 +233,88 @@ In all these, `%s` is `Model[].ID` (i.e. from `config.yaml`) Requires `EnableMemProf` and / or `EnableCPUProf` to be enabled. See [`service/endpoint/prof.go`](service/endpoint/prof.go) for details - otherwise, refer to `pprof` documentation. -## `/v1/api/model` +In the following sections, `%s` is `Model[].ID` (i.e. from `config.yaml`). -Model operations. +## `/v1/api/model/%s/eval` -In all these, `%s` is `Model[].ID` (i.e. from `config.yaml`) +Runs a model prediction. This is the primary data-plane endpoint and the only one with a structured request and response shape; the others are administrative or metadata. + +### Methods + +- `GET` - input values supplied as URL query parameters. Single-prediction mode only. +- `POST` - JSON body containing input values plus optional `batch_size` and `cache_key`. Supports both single-prediction and batch mode. + +### Request body (POST) + +A JSON object whose keys are model input names (as defined by the model's signature). Values are either scalars (single mode) or arrays of length `batch_size` or `1` (batch mode). 
Two reserved optional keys: + +- `batch_size` (integer, optional) - if present and `> 0`, switches the request to batch mode. Other input values must then be arrays. +- `cache_key` (string in single mode, string array of length `batch_size` in batch mode, optional) - explicit cache key(s) to use instead of letting the server derive one from the input values. + +A minimal single-mode request: + +```json +{"input1": "value1", "input2": 42} +``` + +A batch-mode request with two predictions and explicit cache keys: + +```json +{ + "batch_size": 2, + "cache_key": ["k1", "k2"], + "input1": ["value1", "value2"], + "input2": [42, 43] +} +``` + +### Successful response + +`200 OK` with `Content-Type: application/json` and an explicit `Content-Length`. The body is a JSON object: + +```json +{"status": "ok", "dictHash": 12345, "data": {...}, "serviceTimeMcs": 1100} +``` + +- `status` - always `"ok"` on success. +- `dictHash` - hash of the dictionary the prediction was made against. Clients use this to detect dictionary changes and trigger a reload (see [Dictionary hash code](#dictionary-hash-code)). +- `data` - the model output, shape determined by the model and any registered transformer. +- `serviceTimeMcs` - server-side time spent on this request in microseconds. + +A short read against the declared `Content-Length` indicates a transport failure (peer closed mid-response, broken pipe, etc.), not a successful empty response. Clients should surface short reads as errors rather than treating them as empty bodies. 
+ +### Error response + +Errors return a non-2xx HTTP status code: + +| status | cause | +| ------ | ----- | +| `400 Bad Request` | malformed query string, malformed JSON body, type mismatch on an input value, or any client-side input error | +| `413 Request Entity Too Large` | POST body exceeds the server's request buffer | +| `429 Too Many Requests` | server is overloaded (evaluator queue rejected the request) | +| `500 Internal Server Error` | prediction failure, server-side encoding failure, or any other server-side error | + +The error body is currently a plain-text message (Go `http.Error` format). Future versions are expected to align the error response with the success response shape; this section will be updated when that lands. + +### Example + +```bash +# GET, single prediction +curl 'http://localhost:8086/v1/api/model/ml0/eval?input1=value1&input2=42' + +# POST, batch prediction +curl -X POST 'http://localhost:8086/v1/api/model/ml0/eval' \ + -H 'Content-Type: application/json' \ + -d '{"batch_size":2,"input1":["a","b"],"input2":[1,2]}' +``` + +## `/v1/api/model/%s/meta/config` + +Returns the client configuration derived from the model (cache settings, input/output schema, etc.). Used by `mly` clients to bootstrap. + +## `/v1/api/model/%s/meta/dictionary` -- `/v1/api/model/%s/eval` - runs `GET` / `POST` model prediction. Successful responses set `Content-Type: application/json` and an explicit `Content-Length`; a short read against the declared length indicates a transport failure, not an empty payload. -- `/v1/api/model/%s/meta/config` - provides configuration for client related to model -- `/v1/api/model/%s/meta/dictionary` - provides current dictionary +Returns the current dictionary (categorical input vocabularies + dictionary hash). Used by `mly` clients to populate the local cache and detect dictionary changes. 
# Client Metrics (`gmetric`) From b6da5cec9986af9083ed41420c23547d8bfe8cea Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 27 Apr 2026 17:13:53 -0700 Subject: [PATCH 41/50] refactor(service/handler): emit JSON error responses with proper status codes Unify the wire-format contract across success and failure: every response from /v1/api/model/%s/eval is now a JSON-encoded Response object with Content-Type: application/json and explicit Content-Length, regardless of whether the prediction succeeded or any phase failed. Errors retain their proper HTTP status codes (400 / 413 / 429 / 500); only the body shape changes, from plain-text http.Error output to the same Response struct used on success with status="error" and the error field populated. writeResponse now takes an http.StatusXxx parameter so the same explicit-commit machinery (marshal first, declare Content-Length, WriteHeader, Write) can serve both 200 success and 4xx/5xx error responses. writeError is a new helper that: - Calls SetError(err) and clears response.Data so the marshal cannot fail on a Data value that may have triggered the original failure. - Delegates to writeResponse with the supplied status code. - On post-commit Write failure, appends ResponseCommittedError to hStats and logs unconditionally (matching the success-path contract). - On marshal failure of the cleared error response (essentially impossible -- the struct now contains only string + int fields), falls back to http.Error so the client at least receives a status code. All five http.Error call sites in ServeHTTP (GET-path query parse, POST body read, POST body unmarshal, no-request fall-through, and the post-Service.Do error block) now route through writeError. The post-Service.Do block keeps its existing responseCommittedError short-circuit since the success-response status is already on the wire and cannot be changed. 
Cross-repo audit (viant/mly, adelphic/mly, adelphic/mediator) found no consumer that parses the previous plain-text error body format. The one near-miss -- mediator's filter/ml/roas/service.go isTimeoutError -- substring-matches "context deadline exceeded" / "timeout" / "deadline exceeded" / "context canceled" against the error string, which keeps working because the substring is now embedded in the JSON error body wrapped by the client's httpPost error format. Update tests: - Existing writeResponse tests get the new status parameter. - New TestWriteResponse_HonorsStatusParam confirms the supplied status reaches the wire across the four error status codes plus 200. - New TestWriteError_EmitsJSONErrorWithStatus locks in the wire shape (status code, Content-Type, Content-Length, JSON body with status=error and populated error message; cleared Data). - New TestWriteError_FallsBackToHTTPErrorOnCommittedFailure confirms hStats records the post-commit failure when the body write fails after the error status was committed. Update README error response section to describe the new contract. Made-with: Cursor Co-authored-by: Claude --- README.md | 8 +++- service/handler.go | 83 +++++++++++++++++++++++++++------- service/handler_test.go | 99 +++++++++++++++++++++++++++++++++++++++-- 3 files changed, 168 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 2eb3ab0..b4252e9 100644 --- a/README.md +++ b/README.md @@ -285,7 +285,11 @@ A short read against the declared `Content-Length` indicates a transport failure ### Error response -Errors return a non-2xx HTTP status code: +Errors return a non-2xx HTTP status code with the same `Content-Type: application/json` and explicit `Content-Length` as the success response. 
The body is the same `Response` JSON object with `status` set to `"error"`, the `error` field populated, and `data` omitted: + +```json +{"status": "error", "error": "<message>", "serviceTimeMcs": 1100} +``` | status | cause | | ------ | ----- | @@ -294,7 +298,7 @@ Errors return a non-2xx HTTP status code: | `429 Too Many Requests` | server is overloaded (evaluator queue rejected the request) | | `500 Internal Server Error` | prediction failure, server-side encoding failure, or any other server-side error | -The error body is currently a plain-text message (Go `http.Error` format). Future versions are expected to align the error response with the success response shape; this section will be updated when that lands. +Clients can therefore parse the response body the same way regardless of HTTP status — the only differences are the status code and which fields are populated. As a fallback for the rare case where the server cannot encode an error response, a plain-text body may be returned with the same status code. 
### Example diff --git a/service/handler.go b/service/handler.go index 5a9baf3..45dadae 100644 --- a/service/handler.go +++ b/service/handler.go @@ -85,7 +85,7 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques if httpRequest.Method == http.MethodGet { request = h.service.NewRequest() if err := h.buildRequestFromQuery(httpRequest, request); err != nil { - http.Error(writer, err.Error(), http.StatusBadRequest) + h.writeError(writer, response, hStats, http.StatusBadRequest, err) return } } else { @@ -109,7 +109,7 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques code = http.StatusRequestEntityTooLarge } - http.Error(writer, err.Error(), code) + h.writeError(writer, response, hStats, code, err) return err } @@ -130,8 +130,8 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques log.Printf("[%v http] unmarshal error: %v\n", h.service.config.ID, err) } - rmsg := fmt.Sprintf("%s (are your input types correct?)", err.Error()) - http.Error(writer, rmsg, http.StatusBadRequest) + displayErr := fmt.Errorf("%s (are your input types correct?)", err.Error()) + h.writeError(writer, response, hStats, http.StatusBadRequest, displayErr) return err } @@ -146,7 +146,7 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques if request == nil { // This isn't a particularly helpful message. // Currently, the only case this handles is if the request is too large. 
- http.Error(writer, "no request", http.StatusBadRequest) + h.writeError(writer, response, hStats, http.StatusBadRequest, errors.New("no request")) return } @@ -154,7 +154,7 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques if err != nil { response.SetError(err) } else { - err = h.writeResponse(writer, response) + err = h.writeResponse(writer, response, http.StatusOK) } if isDebug { @@ -175,8 +175,8 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques if err != nil { // If the response was already committed (status + headers flushed), // the wire status code is fixed at 200 and cannot be changed. Calling - // http.Error here would log "superfluous WriteHeader" and silently - // drop the new status — the bidder still sees 200 + truncated body. + // writeError here would log "superfluous WriteHeader" and silently + // drop the new status — the client still sees 200 + truncated body. // Log unconditionally so this defect is visible in production, and // emit a dedicated metric so it can be alerted independently of // the generic ErrorKey bucket. @@ -190,7 +190,8 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques // Marshal failure: response NOT committed; we will emit an explicit // 5xx below. Track it in its own metric bucket so the operator can // distinguish "we never sent anything" from "we sent something we - // shouldn't have". + // shouldn't have". writeError clears response.Data before retrying, + // so the second marshal cannot fail for the same reason. 
var marshal *responseMarshalError if errors.As(err, &marshal) { hStats.Append(sstat.ResponseMarshalError{Error: err}) @@ -209,7 +210,7 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques log.Printf("[%v http] status:%d error:%v", h.service.config.ID, status, err) } - http.Error(writer, err.Error(), status) + h.writeError(writer, response, hStats, status, err) } } @@ -230,20 +231,25 @@ func (h *Handler) buildRequestFromQuery(httpRequest *http.Request, request *requ // writeResponse marshals appResponse and emits it with explicit-commit // semantics: // -// - Marshal first; on failure return a plain error — the response is NOT -// yet committed and ServeHTTP can still set a 5xx status. +// - Marshal first; on failure return a typed responseMarshalError -- the +// response is NOT yet committed and the caller can still set a different +// status (typically a 5xx). // - Set Content-Length explicitly so a truncated body is detectable on // the client side as io.ErrUnexpectedEOF (without it, the client cannot // distinguish "done" from "connection broke mid-body" on a 200 OK). -// - Call WriteHeader(200) explicitly so the status line is committed in a -// known order, not as a side effect of the first Write. +// - Call WriteHeader(status) explicitly so the status line is committed +// in a known order, not as a side effect of the first Write. // - On Write failure return responseCommittedError so the caller knows // the status code can no longer be changed. // +// status is typically http.StatusOK for success responses; the writeError +// helper passes the appropriate 4xx/5xx for error responses so the wire +// shape is uniform across success and failure paths. +// // This addresses the silent "200 OK + empty body" failure mode where a // canceled connection caused the implicit auto-200 from Write to flush // headers while the body bytes were lost. 
-func (h *Handler) writeResponse(writer http.ResponseWriter, appResponse *Response) error { +func (h *Handler) writeResponse(writer http.ResponseWriter, appResponse *Response, status int) error { appResponse.ServiceTimeMcs = int(time.Since(appResponse.started).Microseconds()) data, err := gojay.Marshal(appResponse) @@ -257,7 +263,7 @@ func (h *Handler) writeResponse(writer http.ResponseWriter, appResponse *Respons writer.Header().Set("Content-Type", "application/json") writer.Header().Set("Content-Length", strconv.Itoa(len(data))) - writer.WriteHeader(http.StatusOK) + writer.WriteHeader(status) if _, err := writer.Write(data); err != nil { return &responseCommittedError{err: fmt.Errorf("write response body: %w", err)} @@ -266,6 +272,51 @@ func (h *Handler) writeResponse(writer http.ResponseWriter, appResponse *Respons return nil } +// writeError emits an error response with the given HTTP status code as +// a JSON-encoded Response object (status="error", populated error +// message). It is the error-path counterpart to writeResponse and shares +// the same explicit-commit contract so clients always see Content-Length +// and a parseable JSON body regardless of success or failure. +// +// Side-effects on the response struct: +// +// - response.SetError(err) populates response.Error and sets +// response.Status = "error". +// - response.Data is cleared. This guarantees the marshal will succeed +// regardless of the prior state of Data, which matters when the +// original failure was itself a marshal error on a populated Data +// value. +// +// On a post-commit write failure (responseCommittedError) the status is +// already on the wire; we only log + emit the dedicated metric. +// +// On a marshal failure of the (cleared) error response (essentially +// impossible -- the struct now contains only string + int fields), we +// fall back to http.Error so the client at least receives a status code. 
+func (h *Handler) writeError(writer http.ResponseWriter, response *Response, hStats *stat.Values, status int, err error) { + response.SetError(err) + response.Data = nil + + werr := h.writeResponse(writer, response, status) + if werr == nil { + return + } + + var committed *responseCommittedError + if errors.As(werr, &committed) { + hStats.Append(sstat.ResponseCommittedError{Error: werr}) + log.Printf("[%v http] error response committed but write failed: %v (original error: %v)", h.service.config.ID, werr, err) + return + } + + var marshal *responseMarshalError + if errors.As(werr, &marshal) { + hStats.Append(sstat.ResponseMarshalError{Error: werr}) + } + log.Printf("[%v http] failed to write error response: %v (original error: %v)", h.service.config.ID, werr, err) + http.Error(writer, err.Error(), status) +} + func (h *Handler) trackIdle() { now := time.Now() h.lrLock.Lock() diff --git a/service/handler_test.go b/service/handler_test.go index bf7a171..a33e59d 100644 --- a/service/handler_test.go +++ b/service/handler_test.go @@ -11,6 +11,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/viant/mly/service/config" + sstat "github.com/viant/mly/service/stat" + "github.com/viant/mly/shared/stat" ) // writeFailingResponseWriter wraps httptest.ResponseRecorder so that @@ -82,7 +84,7 @@ func TestWriteResponse_Success(t *testing.T) { } rec := httptest.NewRecorder() - err := h.writeResponse(rec, resp) + err := h.writeResponse(rec, resp, http.StatusOK) require.NoError(t, err) assert.Equal(t, http.StatusOK, rec.Code) @@ -117,7 +119,7 @@ func TestWriteResponse_WriteFailureReturnsCommittedError(t *testing.T) { resp := &Response{Status: "ok", started: time.Now()} rec := newWriteFailingResponseWriter(0) - err := h.writeResponse(rec, resp) + err := h.writeResponse(rec, resp, http.StatusOK) require.Error(t, err) @@ -145,7 +147,7 @@ func TestWriteResponse_PartialWriteReturnsCommittedError(t *testing.T) { resp := &Response{Status: 
"ok", started: time.Now()} rec := newWriteFailingResponseWriter(5) - err := h.writeResponse(rec, resp) + err := h.writeResponse(rec, resp, http.StatusOK) require.Error(t, err) var committed *responseCommittedError @@ -180,7 +182,7 @@ func TestWriteResponse_HasContentLengthMatchingBody(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { rec := httptest.NewRecorder() - require.NoError(t, h.writeResponse(rec, tc.resp)) + require.NoError(t, h.writeResponse(rec, tc.resp, http.StatusOK)) cl, atoiErr := strconv.Atoi(rec.Header().Get("Content-Length")) require.NoError(t, atoiErr) @@ -215,3 +217,92 @@ func TestResponseMarshalError_Unwrap(t *testing.T) { assert.True(t, errors.As(wrapped, &target)) assert.True(t, errors.Is(wrapped, inner)) } + +// TestWriteResponse_HonorsStatusParam verifies that writeResponse +// commits the supplied status code rather than always 200. This is the +// foundation for the unified error-response wire format: writeError +// uses writeResponse with 4xx/5xx so success and error responses share +// shape (Content-Type, Content-Length, JSON body) and only differ in +// status line + populated fields. 
+func TestWriteResponse_HonorsStatusParam(t *testing.T) { + h := newTestHandler("test", false) + cases := []int{ + http.StatusOK, + http.StatusBadRequest, + http.StatusRequestEntityTooLarge, + http.StatusTooManyRequests, + http.StatusInternalServerError, + } + for _, status := range cases { + t.Run(http.StatusText(status), func(t *testing.T) { + resp := &Response{Status: "ok", started: time.Now()} + rec := httptest.NewRecorder() + require.NoError(t, h.writeResponse(rec, resp, status)) + assert.Equal(t, status, rec.Code) + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + assert.NotEmpty(t, rec.Header().Get("Content-Length")) + }) + } +} + +// TestWriteError_EmitsJSONErrorWithStatus locks in the wire shape +// promised to clients on the error path: 4xx/5xx + JSON Response body +// with status="error", populated error message, and serviceTimeMcs. +// This is what makes a defensive consumer-side check +// (e.g. mediator's `if response.Error != "" { ... }`) actually fire on +// real predict-time errors instead of silently no-op'ing. 
+func TestWriteError_EmitsJSONErrorWithStatus(t *testing.T) { + h := newTestHandler("test", false) + resp := &Response{Status: "ok", started: time.Now(), Data: "leftover-data"} + rec := httptest.NewRecorder() + hStats := stat.NewValues() + + h.writeError(rec, resp, hStats, http.StatusInternalServerError, errors.New("upstream blew up")) + + assert.Equal(t, http.StatusInternalServerError, rec.Code, + "error status must reach the wire") + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + + cl, atoiErr := strconv.Atoi(rec.Header().Get("Content-Length")) + require.NoError(t, atoiErr) + assert.Equal(t, rec.Body.Len(), cl) + + body := rec.Body.String() + assert.Contains(t, body, `"status":"error"`, + "writeError must populate response.Status as error") + assert.Contains(t, body, `"error":"upstream blew up"`, + "writeError must populate response.Error from the supplied error") + assert.NotContains(t, body, "leftover-data", + "writeError must clear response.Data so the original Data does not leak into the error body") +} + +// TestWriteError_FallsBackToHTTPErrorOnCommittedFailure verifies that +// when the error response's body Write fails after status commit, the +// fallback path appends a metric and returns without panic. The status +// is already on the wire so http.Error inside the fallback is a no-op, +// but the metric attribution and log line are what matter for +// diagnosing the cliff scenario where both the success and error +// responses fail to flush. 
+func TestWriteError_FallsBackToHTTPErrorOnCommittedFailure(t *testing.T) { + h := newTestHandler("test", false) + resp := &Response{Status: "ok", started: time.Now()} + rec := newWriteFailingResponseWriter(0) + hStats := stat.NewValues() + + h.writeError(rec, resp, hStats, http.StatusInternalServerError, errors.New("upstream blew up")) + + assert.Equal(t, http.StatusInternalServerError, rec.Code, + "status was committed before the body write failed") + assert.NotEmpty(t, hStats.Values(), + "hStats must record the post-commit failure for metric attribution") + + var sawCommitted bool + for _, v := range hStats.Values() { + if _, ok := v.(sstat.ResponseCommittedError); ok { + sawCommitted = true + break + } + } + assert.True(t, sawCommitted, + "hStats must include a sstat.ResponseCommittedError marker") +} From 90a8e1cd0c9dcb6f7615dbd3d1bab9a69308a7e0 Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 27 Apr 2026 17:17:56 -0700 Subject: [PATCH 42/50] feat(shared/client): parse JSON error body to populate response.Error Surface the server-side error message into the caller's Response struct on non-2xx responses, so callers that defensively check response.Error (e.g. mediator's ml/fraud and ml/fraudv3 services) finally observe predict-time errors instead of silently swallowing them. Two coordinated changes: 1. httpPost now returns the response body alongside the error on a non-2xx terminal status. Previously it returned (nil, err) on non-200, which discarded the body and prevented any structured parsing. The retry-loop logic is updated to capture the body in a postBody variable so terminal-error iterations preserve the bytes for the outer return. Successful iterations return data directly with err == nil as before. 2. Run() does a best-effort gojay.Unmarshal of the body into the caller's response struct before returning the err. For a v0.20.0+ server's JSON error body this populates response.Status = "error" and response.Error = "". 
For an older server's plain-text body the unmarshal silently fails and the response struct stays untouched. The non-nil err return is unchanged in either case and remains the source-of-truth signal; this is purely additive population of the response struct for the callers that want it. Backward compatibility is preserved both directions: - New v0.20.0 client against an older server: response.Error stays empty (best-effort unmarshal silently fails), err is returned as before, and any consumer that checks err keeps working unchanged. - Older client against a new v0.20.0 server: client never sees the populated response.Error (the older client's httpPost returns nil body on non-200), but err is still surfaced and consumers that check err keep working unchanged. Add TestService_Run_ParsesErrorBody covering four cases: 400+JSON, 500+JSON, 400+plain-text, 500+plain-text. Each asserts both the err return and the response.Error/Status population behavior so the contract is locked in for both server-version axes. Made-with: Cursor Co-authored-by: Claude --- shared/client/service.go | 33 +++++++++- shared/client/service_test.go | 116 ++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 3 deletions(-) diff --git a/shared/client/service.go b/shared/client/service.go index 99a7b88..0108a73 100644 --- a/shared/client/service.go +++ b/shared/client/service.go @@ -164,6 +164,17 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response }() if err != nil { + // Best-effort: parse the body as a Response struct so callers + // that check response.Error see the server-side error message + // (v0.20.0+ servers emit a structured JSON error body alongside + // the HTTP 4xx/5xx; older servers emit plain text and the + // unmarshal silently fails, leaving response untouched). + // The returned err remains the source-of-truth signal; this is + // purely additive population of the response struct. 
+ if len(body) > 0 { + _ = gojay.Unmarshal(body, response) + } + stats.AppendError(err) if ctx.Err() == nil && s.ErrorHistory != nil { go s.ErrorHistory.AddBytes([]byte(err.Error())) @@ -625,6 +636,13 @@ func (s *Service) httpPost(ctx context.Context, data []byte, host *Host) ([]byte evalUrl := host.evalURL(s.Model) var terminate bool var postErr error + // postBody captures the response body across retry iterations so that + // non-2xx terminal errors can return the JSON error body alongside the + // error. Run() does a best-effort unmarshal of this body to populate + // response.Error and response.Status from a v0.20.0+ server's structured + // error response. Older servers return plain-text bodies; the best-effort + // unmarshal silently fails on those, leaving response untouched. + var postBody []byte for i := 0; i < s.MaxRetry; i++ { data, err := func() ([]byte, error) { onDone := s.httpCliCounter.Begin(time.Now()) @@ -660,7 +678,11 @@ func (s *Service) httpPost(ctx context.Context, data []byte, host *Host) ([]byte // as long as this func is run synchronously, // this is safe terminate = true - return nil, fmt.Errorf("HTTP Code:%d, Body:\"%s\" (read nil:%v error:%v)", + // Return the body so the caller can parse the JSON error response + // (v0.20.0+ servers emit a Response struct here; older servers emit + // plain text). The error keeps the same wrapping format for backward + // compatibility with consumers that string-match on it. + return data, fmt.Errorf("HTTP Code:%d, Body:\"%s\" (read nil:%v error:%v)", response.StatusCode, string(data), response.Body == nil, err) } @@ -677,6 +699,11 @@ func (s *Service) httpPost(ctx context.Context, data []byte, host *Host) ([]byte if err != nil { postErr = err + // Capture body for terminal errors so the caller can parse it. + // On retryable errors data is nil, so this is a no-op there. 
+ if data != nil { + postBody = data + } } if terminate || ctx.Err() != nil { @@ -684,12 +711,12 @@ func (s *Service) httpPost(ctx context.Context, data []byte, host *Host) ([]byte break } - if data != nil { + if data != nil && err == nil { return data, nil } } - return nil, postErr + return postBody, postErr } func (s *Service) getHost() (*Host, error) { diff --git a/shared/client/service_test.go b/shared/client/service_test.go index 4d3810f..1b9be76 100644 --- a/shared/client/service_test.go +++ b/shared/client/service_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/viant/bintly" "github.com/viant/mly/shared" cconfig "github.com/viant/mly/shared/client/config" @@ -228,3 +229,118 @@ func TestService_Run(t *testing.T) { } } } + +// TestService_Run_ParsesErrorBody verifies that when the server returns +// a non-2xx response with a JSON-encoded Response body (the v0.20.0+ +// error-response contract), Run() does a best-effort unmarshal of the +// body so the caller's response struct has Status="error" and Error +// populated -- in addition to receiving a non-nil err return value. +// +// Backward-compatibility: when the server returns a plain-text body +// (older mly versions, or any non-JSON body), the unmarshal silently +// fails and the response struct stays untouched. The non-nil err +// return remains the source-of-truth signal in either case. 
+func TestService_Run_ParsesErrorBody(t *testing.T) { + baseURL := toolbox.CallerDirectory(3) + + selectPort := 8088 + server := faker.Server{URL: path.Join(baseURL, "testdata"), Port: selectPort, Debug: true} + server.Start() + defer server.Stop() + + metaInput := shared.MetaInput{ + Inputs: []*shared.Field{ + {Name: "i1"}, + {Name: "i2", Wildcard: true}, + }, + } + dictionary := NewDictionary(&common.Dictionary{ + Layers: []common.Layer{{Name: "i1", Strings: []string{"v1", "v2"}}}, + Hash: 123, + }, metaInput.Inputs) + hosts := []*Host{NewHost("localhost", selectPort)} + options := []Option{ + WithRemoteConfig(&cconfig.Remote{ + Datastore: config.Datastore{ + Cache: &scache.Config{SizeMb: 64, Shards: 10, EntrySize: 1024}, + }, + MetaInput: metaInput, + }), + WithCacheScope(CacheScopeLocal), + WithDictionary(dictionary), + WithDataStorer(mock.New()), + WithDebug(true), + } + + cases := []struct { + description string + bodyContentType string + body string + statusCode int + expectErrorMsg string // non-empty if response.Error should be populated + expectStatus string // non-empty if response.Status should be populated + }{ + { + description: "v0.20.0 server: 400 JSON error body populates response.Error", + bodyContentType: "application/json", + body: `{"status":"error","error":"invalid input shape","serviceTimeMcs":150}`, + statusCode: http.StatusBadRequest, + expectErrorMsg: "invalid input shape", + expectStatus: common.StatusError, + }, + { + description: "v0.20.0 server: 500 JSON error body populates response.Error", + bodyContentType: "application/json", + body: `{"status":"error","error":"upstream blew up","serviceTimeMcs":2200}`, + statusCode: http.StatusInternalServerError, + expectErrorMsg: "upstream blew up", + expectStatus: common.StatusError, + }, + { + description: "older server: 400 plain-text body leaves response.Error empty", + bodyContentType: "text/plain", + body: "bad request\n", + statusCode: http.StatusBadRequest, + expectErrorMsg: "", + 
expectStatus: "", // gojay.Unmarshal silently fails on non-JSON; struct untouched + }, + { + description: "older server: 500 plain-text body leaves response.Error empty", + bodyContentType: "text/plain", + body: "server error\n", + statusCode: http.StatusInternalServerError, + expectErrorMsg: "", + expectStatus: "", + }, + } + + for _, tc := range cases { + t.Run(tc.description, func(t *testing.T) { + body := tc.body + contentType := tc.bodyContentType + statusCode := tc.statusCode + server.Handler.Then(func(d []byte, w http.ResponseWriter) { + w.Header().Set("Content-Type", contentType) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body))) + w.WriteHeader(statusCode) + _, _ = w.Write([]byte(body)) + }) + + srv, err := New("error_body_case", hosts, options...) + require.NoError(t, err) + + msg := srv.NewMessage() + msg.StringKey("i1", "v1") + msg.StringKey("i2", "v10") + + response := &Response{Data: &TestOutput{}} + err = srv.Run(context.Background(), msg, response) + + assert.Error(t, err, "non-2xx must always surface as a non-nil err return") + assert.Equal(t, tc.expectErrorMsg, response.Error, + "response.Error population (best-effort JSON unmarshal of error body)") + assert.Equal(t, tc.expectStatus, response.Status, + "response.Status population (best-effort JSON unmarshal of error body)") + }) + } +} From b4db014e6012b9c9a50a6abed959cb92040bfcca Mon Sep 17 00:00:00 2001 From: David Choi Date: Mon, 27 Apr 2026 17:18:41 -0700 Subject: [PATCH 43/50] docs(readme): add v0.20.0 versioning note Record the v0.20.0 wire-format change: error responses from /v1/api/model/%s/eval now share the same shape as success responses (JSON-encoded Response object, Content-Type: application/json, explicit Content-Length) instead of plain text. HTTP status codes are unchanged. 
The client side picks this up automatically: response.Error is populated from the parsed body on non-2xx responses, so callers can rely on either the err return value or response.Error as the error signal. Older clients connecting to a v0.20.0+ server still work (the err return is unchanged); newer clients connecting to a pre- v0.20.0 server still work (the best-effort body parse silently fails on plain text and response.Error stays empty, with the err return remaining the source-of-truth signal). Made-with: Cursor Co-authored-by: Claude --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index b4252e9..3b75c82 100644 --- a/README.md +++ b/README.md @@ -340,6 +340,7 @@ all compatible with Apache License, Version 2. Please see individual files for d # Versioning Notes +- `v0.20.0` - error responses from `/v1/api/model/%s/eval` are now JSON-encoded `Response` objects (same `Content-Type` / `Content-Length` contract as success responses) instead of plain text; HTTP status codes are unchanged. The client populates `response.Error` from the parsed body, so consumers can rely on either the `err` return value or `response.Error` as the error signal. - `v0.14.1` last support for go 1.17 - `v0.8.0` - numeric features are supported. From 70da0ae2942caa95c2aab4a385c9e44ce8633aac Mon Sep 17 00:00:00 2001 From: David Choi Date: Thu, 30 Apr 2026 13:19:42 -0700 Subject: [PATCH 44/50] feat(shared/client): add ClientHTTP_shed marker for breaker-rejected requests Add stat.Shed key to the http counter provider and append it in postRequest when getHost returns an error wrapping common.ErrNodeDown. Existing _down (trip event) and _error counters are unchanged; the new _shed disambiguates breaker-rejected requests from requests that reached httpPost and failed there. 
Co-authored-by: Claude Opus 4.7 Made-with: Cursor --- shared/client/service.go | 9 ++++ shared/client/service_test.go | 81 +++++++++++++++++++++++++++++++++++ shared/stat/http.go | 16 ++++++- 3 files changed, 105 insertions(+), 1 deletion(-) diff --git a/shared/client/service.go b/shared/client/service.go index 0108a73..81a8916 100644 --- a/shared/client/service.go +++ b/shared/client/service.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "encoding/json" + "errors" "fmt" "io" "log" @@ -615,6 +616,14 @@ func (s *Service) postRequest(ctx context.Context, data []byte, mvt *stat.Values // TODO per-host counters host, err := s.getHost() if err != nil { + // getHost returns ErrNodeDown when the host's breaker IsUp() is + // false. Mark the request as shed so the operator can distinguish + // requests rejected pre-flight by the breaker from requests that + // reached httpPost and failed there. Without this, every shed + // request was conflated into the generic _error counter. + if errors.Is(err, common.ErrNodeDown) { + mvt.Append(stat.Shed) + } return nil, err } diff --git a/shared/client/service_test.go b/shared/client/service_test.go index 1b9be76..ed299ec 100644 --- a/shared/client/service_test.go +++ b/shared/client/service_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "errors" "fmt" "net/http" "path" @@ -11,12 +12,14 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/viant/bintly" + "github.com/viant/gmetric" "github.com/viant/mly/shared" cconfig "github.com/viant/mly/shared/client/config" "github.com/viant/mly/shared/client/faker" "github.com/viant/mly/shared/common" "github.com/viant/mly/shared/config" "github.com/viant/mly/shared/datastore/mock" + "github.com/viant/mly/shared/stat" "github.com/viant/scache" "github.com/viant/toolbox" ) @@ -230,6 +233,84 @@ func TestService_Run(t *testing.T) { } } +// TestService_Run_ShedIncrementsBreakerShedMetric verifies that when the +// host's circuit breaker 
is in the down state at request time, Run() +// increments the new ClientHTTP_shed marker on the http counter (and does +// NOT increment _down, which is reserved for the trip event itself). +// +// Before this fix, shed requests were conflated into the generic _error +// counter, leaving operators unable to distinguish "request rejected +// pre-flight by the breaker" from "request reached httpPost and failed +// there." See shared/client/service.go postRequest. +func TestService_Run_ShedIncrementsBreakerShedMetric(t *testing.T) { + baseURL := toolbox.CallerDirectory(3) + + selectPort := 8089 + server := faker.Server{URL: path.Join(baseURL, "testdata"), Port: selectPort, Debug: true} + server.Start() + defer server.Stop() + + metaInput := shared.MetaInput{ + Inputs: []*shared.Field{ + {Name: "i1"}, + {Name: "i2", Wildcard: true}, + }, + } + dictionary := NewDictionary(&common.Dictionary{ + Layers: []common.Layer{{Name: "i1", Strings: []string{"v1", "v2"}}}, + Hash: 123, + }, metaInput.Inputs) + hosts := []*Host{NewHost("localhost", selectPort)} + + gmetrics := gmetric.New() + const modelID = "shed_metric_case" + options := []Option{ + WithGmetrics(gmetrics), + WithRemoteConfig(&cconfig.Remote{ + Datastore: config.Datastore{ + Cache: &scache.Config{SizeMb: 64, Shards: 10, EntrySize: 1024}, + }, + MetaInput: metaInput, + }), + WithCacheScope(CacheScopeLocal), + WithDictionary(dictionary), + WithDataStorer(mock.New()), + WithDebug(true), + } + srv, err := New(modelID, hosts, options...) + require.NoError(t, err) + + // Force the host's breaker into the down state so getHost() will + // return ErrNodeDown without ever calling httpPost. 
+ hosts[0].FlagDown() + require.False(t, hosts[0].IsUp(), "host must be flagged down for the shed path") + + msg := srv.NewMessage() + msg.StringKey("i1", "v1") + msg.StringKey("i2", "v10") + + response := &Response{Data: &TestOutput{}} + err = srv.Run(context.Background(), msg, response) + + require.Error(t, err, "shed request must surface as a non-nil err") + assert.True(t, errors.Is(err, common.ErrNodeDown), "shed err must wrap ErrNodeDown, got %v", err) + + // Inspect the cumulative counter values for ClientHTTP. The + // new Shed marker must increment by 1; the existing Down marker must + // stay at 0 (no FlagDown was called by this request -- the breaker + // was already down before getHost was called). + shedCount := gmetrics.LookupOperationCumulativeMetric(modelID+"ClientHTTP", stat.Shed) + downCount := gmetrics.LookupOperationCumulativeMetric(modelID+"ClientHTTP", stat.Down) + errorCount := gmetrics.LookupOperationCumulativeMetric(modelID+"ClientHTTP", stat.ErrorKey) + + assert.EqualValues(t, 1, shedCount, "ClientHTTP_shed must increment on shed") + assert.EqualValues(t, 0, downCount, "ClientHTTP_down must NOT increment on shed (only on trip)") + // _error still increments because Run()'s AppendError fires for the + // non-context ErrNodeDown -- this is the historical behavior that + // the new _shed marker disambiguates without changing. 
+ assert.EqualValues(t, 1, errorCount, "ClientHTTP_error continues to increment as before") +} + // TestService_Run_ParsesErrorBody verifies that when the server returns // a non-2xx response with a JSON-encoded Response body (the v0.20.0+ // error-response contract), Run() does a best-effort unmarshal of the diff --git a/shared/stat/http.go b/shared/stat/http.go index a2e4320..1d199b4 100644 --- a/shared/stat/http.go +++ b/shared/stat/http.go @@ -5,15 +5,27 @@ import "github.com/viant/gmetric/counter" // TODO move to shared/client type http struct{} -const Pending = "pending" +const ( + Pending = "pending" + // Shed marks a request that the client did NOT send because the host's + // circuit breaker was already in the down state when getHost() was + // called. Distinct from Down, which marks the trip event itself + // (the request that observed the connection error and called + // FlagDown). Shed is the count of subsequent requests that the + // breaker rejected before recovery. + Shed = "shed" +) func (p http) Keys() []string { + // New keys must be appended at the end so existing column indices + // remain stable for downstream consumers (Mimir queries, dashboards). return []string{ ErrorKey, Pending, Down, Canceled, DeadlineExceeded, + Shed, } } @@ -35,6 +47,8 @@ func (p http) Map(value interface{}) int { return 3 case DeadlineExceeded: return 4 + case Shed: + return 5 } case Dir: return 1 From d7730a692a297e5e42dee68a2b97e433473f8d91 Mon Sep 17 00:00:00 2001 From: David Choi Date: Thu, 30 Apr 2026 13:38:59 -0700 Subject: [PATCH 45/50] fix(circut): atomic Down writes and lock-protected resetDuration reset FlagUp and FlagDown now use atomic.CompareAndSwapInt32 on Down so the flag is consistent with IsUp's atomic.LoadInt32, and FlagUp's resetDuration reset runs under the mutex so it cannot clobber a concurrent FlagDown's resetDuration *= 2 (lost-update bug that defeated exponential backoff under flapping). 
Adds breaker_test.go (first tests in the package) covering the data race under -race, backoff accumulation across trips, and CAS-based idempotency of FlagUp / FlagDown. Co-authored-by: Claude Opus 4.7 Made-with: Cursor --- shared/circut/breaker.go | 28 ++++--- shared/circut/breaker_test.go | 139 ++++++++++++++++++++++++++++++++++ 2 files changed, 156 insertions(+), 11 deletions(-) create mode 100644 shared/circut/breaker_test.go diff --git a/shared/circut/breaker.go b/shared/circut/breaker.go index 04b1f73..10baba5 100644 --- a/shared/circut/breaker.go +++ b/shared/circut/breaker.go @@ -26,11 +26,19 @@ func (b *Breaker) IsUp() bool { } // FlagUp is used to reset the backoff. +// +// Uses CompareAndSwap so the resetDuration reset only fires on an actual +// down->up transition (not on idempotent FlagUp calls), and so the write +// to b.Down is atomic with respect to IsUp's atomic.LoadInt32. The +// resetDuration write is performed under the mutex so it cannot race +// with FlagDown's resetDuration *= 2 (lost-update bug). func (b *Breaker) FlagUp() { + if !atomic.CompareAndSwapInt32(&b.Down, 1, 0) { + return + } b.mux.Lock() - b.Down = 0 - b.mux.Unlock() b.resetDuration = b.initialResetDuration + b.mux.Unlock() } // resetIfDue will spawn a goroutine to probe the resource if the backoff time @@ -57,21 +65,19 @@ func (b *Breaker) resetIfDue() { } // FlagDown is used to indicate the resource is down. +// +// CompareAndSwap atomically transitions the Down flag exactly once per +// up->down edge, so backoff state (resetTime, resetDuration) is updated +// once per trip even under concurrent FlagDown calls. The atomic write +// is also synchronized with IsUp's atomic.LoadInt32. 
func (b *Breaker) FlagDown() { - down := atomic.LoadInt32(&b.Down) - if down == 1 { + if !atomic.CompareAndSwapInt32(&b.Down, 0, 1) { return } - b.mux.Lock() - defer b.mux.Unlock() - if b.Down == 1 { - return - } - b.Down = 1 - b.resetTime = time.Now().Add(b.resetDuration) b.resetDuration *= 2 //double reset time each time service is Down + b.mux.Unlock() } // New creates a new circut breaker diff --git a/shared/circut/breaker_test.go b/shared/circut/breaker_test.go new file mode 100644 index 0000000..0ddcd4e --- /dev/null +++ b/shared/circut/breaker_test.go @@ -0,0 +1,139 @@ +package circut + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeProber records probe invocations and never calls FlagUp. +// Tests that need probe-driven recovery call b.FlagUp() directly. +type fakeProber struct { + probes int64 +} + +func (f *fakeProber) Probe() { + atomic.AddInt64(&f.probes, 1) +} + +// TestBreaker_Concurrent_NoDataRace exercises FlagDown / FlagUp / IsUp +// from many goroutines simultaneously. Run with `-race` to catch the +// previously-existing data race on b.Down (atomic read, non-atomic write). +// +// Without the fix this test reliably triggers a race-detector report: +// +// WARNING: DATA RACE +// Read at 0x... by goroutine N (atomic.LoadInt32): +// shared/circut.(*Breaker).IsUp +// Previous write at 0x... 
by goroutine M: +// shared/circut.(*Breaker).FlagUp / FlagDown +func TestBreaker_Concurrent_NoDataRace(t *testing.T) { + b := New(50*time.Millisecond, &fakeProber{}) + + const goroutines = 16 + const iterations = 5000 + + var wg sync.WaitGroup + wg.Add(goroutines) + for g := 0; g < goroutines; g++ { + go func(seed int) { + defer wg.Done() + for i := 0; i < iterations; i++ { + switch (seed + i) % 3 { + case 0: + b.FlagDown() + case 1: + b.FlagUp() + case 2: + _ = b.IsUp() + } + } + }(g) + } + wg.Wait() + + // Final state assertion is intentionally weak; the point of this test + // is the race detector, not the terminal flag value. + _ = b.IsUp() +} + +// TestBreaker_BackoffAccumulates verifies that resetDuration doubles on +// each successive trip and is NOT clobbered by an interleaved FlagUp. +// Catches the lost-update bug where FlagUp's resetDuration reset ran +// outside the mutex and could race with FlagDown's resetDuration *= 2. +func TestBreaker_BackoffAccumulates(t *testing.T) { + const initial = 50 * time.Millisecond + b := New(initial, &fakeProber{}) + + require.Equal(t, initial, b.resetDuration, "initial resetDuration") + + // First trip: doubles to 100ms. + b.FlagDown() + require.Equal(t, 2*initial, b.resetDuration, "after 1st trip") + + // Recover and trip again: FlagUp resets resetDuration to initial + // (50ms), so the second trip doubles it to 100ms again rather than + // accumulating to 200ms. + b.FlagUp() + require.Equal(t, initial, b.resetDuration, "FlagUp resets to initial") + + b.FlagDown() + require.Equal(t, 2*initial, b.resetDuration, "after 2nd trip from initial") + + // Multiple FlagDowns without intervening FlagUp must NOT + // re-double. Idempotency comes from the CAS.
+ b.FlagDown() + b.FlagDown() + b.FlagDown() + assert.Equal(t, 2*initial, b.resetDuration, "extra FlagDowns are no-ops while down") +} + +// TestBreaker_FlagUp_Idempotent verifies that FlagUp on an already-up +// breaker does NOT reset resetDuration (which would be wrong if the +// breaker is in the middle of a backoff sequence and a stale Probe +// callback fires FlagUp redundantly). +func TestBreaker_FlagUp_Idempotent(t *testing.T) { + const initial = 50 * time.Millisecond + b := New(initial, &fakeProber{}) + + // Trip and recover -- resetDuration is back to initial. + b.FlagDown() + b.FlagUp() + require.Equal(t, initial, b.resetDuration) + + // Trip again -- resetDuration doubles. + b.FlagDown() + require.Equal(t, 2*initial, b.resetDuration) + + // A spurious FlagUp is one that arrives while Down is already 0; + // the breaker was just tripped above, so Down is currently 1. + // The realistic spurious FlagUp scenario is Probe firing twice + // after recovery has already happened. Simulate that. + b.FlagUp() // legitimate recovery; resetDuration -> initial + require.Equal(t, initial, b.resetDuration) + b.FlagUp() // spurious; must NOT clobber any subsequent backoff state + require.Equal(t, initial, b.resetDuration, "spurious FlagUp is a no-op") +} + +// TestBreaker_FlagDown_Idempotent verifies that repeated FlagDown calls +// while the breaker is already down do not advance resetTime or grow +// resetDuration further. +func TestBreaker_FlagDown_Idempotent(t *testing.T) { + const initial = 50 * time.Millisecond + b := New(initial, &fakeProber{}) + + b.FlagDown() + firstResetTime := b.resetTime + require.Equal(t, 2*initial, b.resetDuration) + + // Subsequent FlagDowns while down must be no-ops.
+ for i := 0; i < 5; i++ { + b.FlagDown() + } + assert.Equal(t, 2*initial, b.resetDuration, "resetDuration unchanged across redundant FlagDowns") + assert.Equal(t, firstResetTime, b.resetTime, "resetTime unchanged across redundant FlagDowns") +} From 0c2dd61381378f31002ce92c741377b819755a9e Mon Sep 17 00:00:00 2001 From: David Choi Date: Thu, 30 Apr 2026 13:52:49 -0700 Subject: [PATCH 46/50] docs(stat): document _pending bugs and per-regime utility Pending Inc/Dec via metric.EnterThenExit captures the recent-bucket index at Enter time and decrements that bucket on Exit; rotation between Enter and Exit lands the decrement in the wrong bucket. Also: Dir is not a string, so MultiCounter takes a mutex per Inc/Dec which serializes hot paths. Both _pending and _pending_Max remain operationally useful in different regimes (high-QPS vs low-QPS operations); the doc comment explains which signal to query in each case. Co-authored-by: Claude Opus 4.7 Made-with: Cursor --- shared/stat/http.go | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/shared/stat/http.go b/shared/stat/http.go index 1d199b4..a70ad25 100644 --- a/shared/stat/http.go +++ b/shared/stat/http.go @@ -6,6 +6,42 @@ import "github.com/viant/gmetric/counter" type http struct{} const ( + // Pending is the column for the in-flight gauge maintained by the + // metric.EnterThenExit Inc/Dec pattern. The exporter publishes: + // - _pending -- the current-in-flight counter (defective; see below) + // - _pending_Max -- per-bucket peak from the Occupancy CustomCounter + // + // Known defects in the underlying mechanism: + // + // 1. Bucket-mismatch on Exit. EnterThenExit captures the recent-bucket + // index at Enter time and decrements that same bucket on Exit. If + // the bucket has rotated between Enter and Exit, the wrong bucket + // is decremented -- the previous bucket's value goes negative + // while the current bucket's value drifts high. + // + // 2. 
Mutex serialization on Inc/Dec. The Dir typed value is not a + // string, so MultiCounter.incrementValueBy takes c.locker.Lock() + // on every Enter and Exit. Under high QPS this is a real + // serialization point. + // + // Defect #1 inflates both _pending (current) and _pending_Max (per-bucket + // peak); the inflation is in the conservative direction (over-estimation), + // so the metrics are still operationally useful in different regimes: + // + // - _pending_Max grouped per reporting dimension (e.g. by + // availability_zone, environment, op): for high-QPS operations + // the per-group peak rises substantially above baseline noise + // during fleet-wide saturation events, making this the cleaner + // saturation signal for those operations. + // + // - _pending summed across reporting instances: exhibits dramatic + // spikes during saturation for any QPS profile, partially + // amplified by defect #1. For low-QPS operations where the + // per-group _pending_Max signal is lost in baseline noise, the + // fleet sum is the more visible saturation signal. + // + // _pending_Max is the cleaner peak-concurrency signal for capacity + // sizing; pick per-operation based on QPS profile. Pending = "pending" // Shed marks a request that the client did NOT send because the host's // circuit breaker was already in the down state when getHost() was From 1e958f98a394b46d0b62b6809998ebee0887a572 Mon Sep 17 00:00:00 2001 From: David Choi Date: Thu, 30 Apr 2026 16:02:34 -0700 Subject: [PATCH 47/50] feat(circut+client): tail-latency-aware breaker with probabilistic pass-through Adds LatencyBreaker parallel to the existing connection-failure Breaker. Trips on latest > LatestThreshold OR rolling > RollingThreshold; recovers after K consecutive observations satisfy both thresholds. While ON, allows a configurable fraction of requests through to drive recovery sensing without committing real load. Host.IsUp() now requires both breakers to be up. 
Service.init wires a LatencyBreaker into each host when at least one threshold is non-zero (zero -> disabled, backward-compatible). postRequest feeds elapsed time to LatencyBreaker.Observe after each httpPost. Caller opts in by setting Config.LatencyBreaker* fields or via WithLatencyBreaker. Pass-through fraction defaults to 0.01 (1%). Co-authored-by: Claude Opus 4.7 Made-with: Cursor --- shared/circut/latency_breaker.go | 246 ++++++++++++++++++++++++++ shared/circut/latency_breaker_test.go | 218 +++++++++++++++++++++++ shared/client/config.go | 35 ++++ shared/client/host.go | 20 +++ shared/client/option.go | 38 ++++ shared/client/service.go | 30 ++++ 6 files changed, 587 insertions(+) create mode 100644 shared/circut/latency_breaker.go create mode 100644 shared/circut/latency_breaker_test.go diff --git a/shared/circut/latency_breaker.go b/shared/circut/latency_breaker.go new file mode 100644 index 0000000..9d48053 --- /dev/null +++ b/shared/circut/latency_breaker.go @@ -0,0 +1,246 @@ +package circut + +import ( + "math/rand/v2" + "sync" + "sync/atomic" + "time" +) + +// LatencyBreaker is a state machine that sheds traffic when observed +// request latencies exceed configured thresholds. It is independent of +// (and parallel to) the connection-failure-based Breaker. +// +// Detection: +// +// - latest: the most recent observation. Compared against +// LatestThreshold. +// - rolling: average over a sliding window. Compared against +// RollingThreshold. +// +// State transitions: +// +// - OFF -> ON on every observation where +// latest > LatestThreshold OR rolling > RollingThreshold +// (zero-valued thresholds are skipped, so a single threshold +// can be used by leaving the other zero). +// - ON -> OFF after K consecutive observations satisfy +// latest < LatestThreshold AND rolling < RollingThreshold. 
+// +// While ON, IsUp() returns true with probability PassThroughFraction +// and false otherwise -- letting a small fraction of traffic through +// to drive recovery sensing without committing real load. +// +// Concurrency: +// +// The state flag is read with atomic.LoadInt32 in IsUp() and written +// via atomic.StoreInt32 from Observe() under the mutex. Compound state +// (latest / rolling buckets / consecutiveOK) is protected by the same +// mutex; only one Observe can mutate at a time. IsUp does not block. +// +// Random number source for pass-through is math/rand/v2 top-level, +// which is concurrent-safe and lock-free in Go 1.22+. A test seam +// (randFloat) allows deterministic tests. +type LatencyBreaker struct { + // Configuration. Set at construction; not mutated after. + LatestThreshold time.Duration + RollingThreshold time.Duration + RollingWindow time.Duration + KConsecutive int + PassThroughFraction float64 + + // state holds the OFF (0) / ON (1) flag. Atomic. + state int32 + + mu sync.Mutex // guards latest, rolling, consecutiveOK + latest time.Duration + rolling *rollingAverage + consecutiveOK int + + // randFloat returns a value in [0, 1). Defaults to math/rand/v2.Float64. + // Override for deterministic tests. + randFloat func() float64 +} + +// NewLatencyBreaker constructs a LatencyBreaker. Zero-valued thresholds +// disable that branch of the trip predicate. If both thresholds are +// zero, Observe is a no-op and IsUp always returns true (effectively +// disabled).
+func NewLatencyBreaker( + latestThreshold, rollingThreshold, rollingWindow time.Duration, + kConsecutive int, + passThroughFraction float64, +) *LatencyBreaker { + if rollingWindow <= 0 { + rollingWindow = time.Second + } + if kConsecutive < 1 { + kConsecutive = 1 + } + if passThroughFraction < 0 { + passThroughFraction = 0 + } + if passThroughFraction > 1 { + passThroughFraction = 1 + } + return &LatencyBreaker{ + LatestThreshold: latestThreshold, + RollingThreshold: rollingThreshold, + RollingWindow: rollingWindow, + KConsecutive: kConsecutive, + PassThroughFraction: passThroughFraction, + rolling: newRollingAverage(rollingWindow, 10), + randFloat: rand.Float64, + } +} + +// IsUp returns true if the breaker is OFF (allowing all traffic), or +// true with probability PassThroughFraction if ON (allowing a small +// fraction through for recovery sensing). +func (lb *LatencyBreaker) IsUp() bool { + if lb == nil { + return true + } + if atomic.LoadInt32(&lb.state) == 0 { + return true + } + return lb.randFloat() < lb.PassThroughFraction +} + +// Observe records the latency of a completed request and advances the +// state machine. Called from the client's postRequest after each httpPost +// attempt completes (success or failure -- timeouts and errors count +// as observations, using the elapsed time captured by the caller). +func (lb *LatencyBreaker) Observe(latency time.Duration) { + if lb == nil { + return + } + if lb.LatestThreshold == 0 && lb.RollingThreshold == 0 { + // Both thresholds disabled; no signal to act on. + return + } + + lb.mu.Lock() + now := time.Now() + lb.latest = latency + lb.rolling.add(latency, now) + rollingAvg := lb.rolling.average(now) + + state := atomic.LoadInt32(&lb.state) + + // triggerOn: ANY threshold breached. Zero-valued thresholds skip.
+ triggerOn := false + if lb.LatestThreshold > 0 && latency > lb.LatestThreshold { + triggerOn = true + } + if lb.RollingThreshold > 0 && rollingAvg > lb.RollingThreshold { + triggerOn = true + } + + // triggerOffReady: BOTH thresholds satisfied as below. Zero-valued + // thresholds count as satisfied. + triggerOffReady := true + if lb.LatestThreshold > 0 && latency >= lb.LatestThreshold { + triggerOffReady = false + } + if lb.RollingThreshold > 0 && rollingAvg >= lb.RollingThreshold { + triggerOffReady = false + } + + switch state { + case 0: // OFF + if triggerOn { + atomic.StoreInt32(&lb.state, 1) + lb.consecutiveOK = 0 + } + case 1: // ON + if triggerOffReady { + lb.consecutiveOK++ + if lb.consecutiveOK >= lb.KConsecutive { + atomic.StoreInt32(&lb.state, 0) + lb.consecutiveOK = 0 + } + } else { + lb.consecutiveOK = 0 + } + } + lb.mu.Unlock() +} + +// State returns 0 (OFF / up) or 1 (ON / shedding). Primarily for tests. +func (lb *LatencyBreaker) State() int32 { + if lb == nil { + return 0 + } + return atomic.LoadInt32(&lb.state) +} + +// rollingAverage keeps a sliding-window average of durations using a +// fixed number of time-aligned buckets. Buckets that fall outside the +// current window are reset on next access. All access is serialized +// by LatencyBreaker.mu; this struct is not goroutine-safe on its own. +type rollingAverage struct { + window time.Duration + bucketDur time.Duration + bucketDurN int64 // bucketDur in nanoseconds, cached + buckets []rollingBucket +} + +type rollingBucket struct { + sum time.Duration + count int64 + until int64 // exclusive end of bucket period, in nanoseconds since epoch +} + +func newRollingAverage(window time.Duration, n int) *rollingAverage { + if n < 1 { + n = 1 + } + bd := window / time.Duration(n) + if bd <= 0 { + bd = window + n = 1 + } + return &rollingAverage{ + window: window, + bucketDur: bd, + bucketDurN: int64(bd), + buckets: make([]rollingBucket, n), + } +} + +// add records a value with completion time t. 
+func (r *rollingAverage) add(v time.Duration, t time.Time) { + tn := t.UnixNano() + idx := int((tn / r.bucketDurN) % int64(len(r.buckets))) + until := ((tn / r.bucketDurN) + 1) * r.bucketDurN + if r.buckets[idx].until != until { + // Bucket belongs to a different period; reset and reuse. + r.buckets[idx].sum = 0 + r.buckets[idx].count = 0 + r.buckets[idx].until = until + } + r.buckets[idx].sum += v + r.buckets[idx].count++ +} + +// average returns the average across all buckets whose period overlaps +// the window ending at t. Returns 0 if no in-window samples. +func (r *rollingAverage) average(t time.Time) time.Duration { + cutoff := t.UnixNano() - int64(r.window) + var sum time.Duration + var count int64 + for i := range r.buckets { + b := &r.buckets[i] + // Bucket's period end (until) must be after cutoff to be in-window. + if b.until <= cutoff { + continue + } + sum += b.sum + count += b.count + } + if count == 0 { + return 0 + } + return sum / time.Duration(count) +} diff --git a/shared/circut/latency_breaker_test.go b/shared/circut/latency_breaker_test.go new file mode 100644 index 0000000..ef31e7a --- /dev/null +++ b/shared/circut/latency_breaker_test.go @@ -0,0 +1,218 @@ +package circut + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + tinyWindow = 100 * time.Millisecond + winK = 3 +) + +func newTestLB(latest, rolling time.Duration, kConsecutive int, fraction float64) *LatencyBreaker { + return NewLatencyBreaker(latest, rolling, tinyWindow, kConsecutive, fraction) +} + +// TestLatencyBreaker_DisabledWhenZero verifies that a LatencyBreaker +// constructed with both thresholds = 0 is a no-op: Observe doesn't +// transition state, IsUp always returns true. 
+func TestLatencyBreaker_DisabledWhenZero(t *testing.T) { + lb := newTestLB(0, 0, winK, 0.01) + for i := 0; i < 10; i++ { + lb.Observe(time.Hour) + } + assert.Equal(t, int32(0), lb.State()) + assert.True(t, lb.IsUp()) +} + +// TestLatencyBreaker_TripOnLatest verifies OFF -> ON transition when +// the latest observation alone exceeds LatestThreshold (rolling +// threshold disabled). +func TestLatencyBreaker_TripOnLatest(t *testing.T) { + lb := newTestLB(40*time.Millisecond, 0, winK, 0.01) + require.Equal(t, int32(0), lb.State()) + + lb.Observe(50 * time.Millisecond) + assert.Equal(t, int32(1), lb.State(), "single observation above latestThreshold trips ON") +} + +// TestLatencyBreaker_TripOnRolling verifies OFF -> ON transition when +// the rolling average crosses RollingThreshold even though no single +// observation hits the latest threshold. +func TestLatencyBreaker_TripOnRolling(t *testing.T) { + // LatestThreshold=0 (disabled), RollingThreshold=20ms. + lb := newTestLB(0, 20*time.Millisecond, winK, 0.01) + + // Stream of 25ms observations -- each below latestThreshold (which + // is disabled), but rolling crosses 20ms quickly. + for i := 0; i < 5; i++ { + lb.Observe(25 * time.Millisecond) + } + assert.Equal(t, int32(1), lb.State(), "sustained observations above rollingThreshold trip ON") +} + +// TestLatencyBreaker_RecoverViaKConsecutive verifies ON -> OFF takes K +// consecutive observations satisfying the configured thresholds. +// +// Uses RollingThreshold=0 (disabled) so the test isolates the K-consecutive +// state machine from rolling-window pollution -- a single slow observation +// in the rolling window keeps the rolling average elevated for the full +// window duration regardless of how many subsequent fast observations +// arrive, which is correct production behavior but obscures this test's +// intent. +func TestLatencyBreaker_RecoverViaKConsecutive(t *testing.T) { + lb := newTestLB(40*time.Millisecond, 0, winK, 1.0) + + // Trip via latest. 
+ lb.Observe(50 * time.Millisecond) + require.Equal(t, int32(1), lb.State()) + + // K-1 fast observations -- not enough to recover. + for i := 0; i < winK-1; i++ { + lb.Observe(5 * time.Millisecond) + require.Equal(t, int32(1), lb.State(), "still ON after %d observations", i+1) + } + + // Kth fast observation -- recovers. + lb.Observe(5 * time.Millisecond) + assert.Equal(t, int32(0), lb.State(), "ON -> OFF after K consecutive fast observations") +} + +// TestLatencyBreaker_RecoveryResetsOnSlow verifies that a single +// above-threshold observation resets the consecutive-OK counter. +// RollingThreshold=0 to isolate the consecutive-OK reset logic. +func TestLatencyBreaker_RecoveryResetsOnSlow(t *testing.T) { + lb := newTestLB(40*time.Millisecond, 0, winK, 1.0) + + lb.Observe(50 * time.Millisecond) // trip + require.Equal(t, int32(1), lb.State()) + + lb.Observe(5 * time.Millisecond) // 1 OK + lb.Observe(5 * time.Millisecond) // 2 OK + require.Equal(t, int32(1), lb.State(), "still ON before K consecutive") + + lb.Observe(50 * time.Millisecond) // bad observation -- resets consecutiveOK to 0 + require.Equal(t, int32(1), lb.State()) + + // Now need K consecutive again from scratch. + lb.Observe(5 * time.Millisecond) + lb.Observe(5 * time.Millisecond) + require.Equal(t, int32(1), lb.State(), "still ON after only 2 fast observations following reset") + lb.Observe(5 * time.Millisecond) + assert.Equal(t, int32(0), lb.State(), "ON -> OFF after K consecutive following reset") +} + +// TestLatencyBreaker_PassThroughFraction verifies the probabilistic +// pass-through behavior while ON, using an injected deterministic +// random source. +func TestLatencyBreaker_PassThroughFraction(t *testing.T) { + lb := newTestLB(40*time.Millisecond, 0, winK, 0.25) + + // Trip. + lb.Observe(50 * time.Millisecond) + require.Equal(t, int32(1), lb.State()) + + // Inject a deterministic counter generating values 0.0, 0.1, 0.2, + // 0.3, 0.4, ... -- pass-through fires when value < 0.25, i.e. 
for + // the first 3 (0.0, 0.1, 0.2) of every 10. + var counter int + lb.randFloat = func() float64 { + v := float64(counter%10) / 10.0 + counter++ + return v + } + + pass := 0 + const N = 1000 + for i := 0; i < N; i++ { + if lb.IsUp() { + pass++ + } + } + // Expected: 30% pass exactly with this generator (3/10 buckets). + // Allow ±2% drift for any rounding. + assert.InDelta(t, 0.30, float64(pass)/float64(N), 0.02, + "pass-through fraction should match injected random source") +} + +// TestLatencyBreaker_PassThroughZero verifies that PassThroughFraction=0 +// sheds 100% while ON. +func TestLatencyBreaker_PassThroughZero(t *testing.T) { + lb := newTestLB(40*time.Millisecond, 0, winK, 0.0) + lb.Observe(50 * time.Millisecond) + require.Equal(t, int32(1), lb.State()) + + for i := 0; i < 100; i++ { + assert.False(t, lb.IsUp(), "PassThroughFraction=0 must shed 100%% while ON") + } +} + +// TestLatencyBreaker_NilSafe verifies a nil receiver behaves as +// a permanently-up no-op breaker (lets the caller treat +// "no LatencyBreaker configured" identically to "configured but OFF"). +func TestLatencyBreaker_NilSafe(t *testing.T) { + var lb *LatencyBreaker + assert.True(t, lb.IsUp()) + lb.Observe(time.Second) // must not panic + assert.Equal(t, int32(0), lb.State()) +} + +// TestLatencyBreaker_Concurrent_NoDataRace exercises Observe / IsUp +// from many goroutines simultaneously. Run with `-race` to catch any +// data races introduced by future edits. 
+func TestLatencyBreaker_Concurrent_NoDataRace(t *testing.T) { + lb := newTestLB(40*time.Millisecond, 20*time.Millisecond, winK, 0.5) + + const goroutines = 16 + const iterations = 2000 + + var wg sync.WaitGroup + wg.Add(goroutines) + for g := 0; g < goroutines; g++ { + go func(seed int) { + defer wg.Done() + for i := 0; i < iterations; i++ { + switch (seed + i) % 3 { + case 0: + lb.Observe(10 * time.Millisecond) + case 1: + lb.Observe(50 * time.Millisecond) + case 2: + _ = lb.IsUp() + } + } + }(g) + } + wg.Wait() + + // Sanity: state must be 0 or 1. + st := lb.State() + assert.True(t, st == 0 || st == 1, "state must be 0 or 1, got %d", st) +} + +// TestRollingAverage_BucketRotation verifies that observations outside +// the rolling window are excluded from the average. +func TestRollingAverage_BucketRotation(t *testing.T) { + r := newRollingAverage(100*time.Millisecond, 10) + t0 := time.Unix(0, 1_000_000_000) // 1s exactly + + // Add 10 observations of 50ms within a single bucket period. + for i := 0; i < 10; i++ { + r.add(50*time.Millisecond, t0) + } + assert.Equal(t, 50*time.Millisecond, r.average(t0)) + + // 200ms later, all old buckets should be outside the 100ms window. + tLater := t0.Add(200 * time.Millisecond) + assert.Equal(t, time.Duration(0), r.average(tLater), + "window should expire all 10-bucket-old observations") +} + +// Sanity check: ensure the atomic types we use compile. +var _ = atomic.LoadInt32 diff --git a/shared/client/config.go b/shared/client/config.go index f103c15..a8f86b0 100644 --- a/shared/client/config.go +++ b/shared/client/config.go @@ -1,6 +1,8 @@ package client import ( + "time" + "github.com/viant/mly/shared/client/config" ) @@ -22,6 +24,39 @@ type Config struct { Debug bool DictHashValidation bool + + // LatencyBreaker fields are passed through to circut.LatencyBreaker + // at Service init time. 
When both LatencyBreakerLatestThreshold and + // LatencyBreakerRollingThreshold are zero, the LatencyBreaker is not + // constructed and the host's IsUp() reflects only the connection + // breaker -- backward-compatible default. + // + // LatestThreshold is the per-attempt latency above which a single + // observation is enough to trip into the shedding state. The caller + // is expected to size this near (or just below) its own request + // timeout so the breaker fires before requests would have failed. + LatencyBreakerLatestThreshold time.Duration + + // RollingThreshold is the rolling-average latency above which the + // breaker trips. Detects sustained slow-creep that no single + // observation crosses LatestThreshold for. + LatencyBreakerRollingThreshold time.Duration + + // RollingWindow is the duration over which the rolling average is + // computed. Default 1s if zero. + LatencyBreakerRollingWindow time.Duration + + // KConsecutive is the number of consecutive observations satisfying + // (latest < LatestThreshold AND rolling < RollingThreshold) needed + // to transition from ON back to OFF. Higher = more conservative + // recovery, prevents flap on outliers. Default 3 if zero. + LatencyBreakerKConsecutive int + + // PassThroughFraction is the probability that a request is allowed + // through while the breaker is ON, to drive recovery sensing. + // Default 0.01 (1%). Set higher for low-QPS models that need more + // observations to recover, or 0 to fully shed without recovery. + LatencyBreakerPassThroughFraction float64 } //CacheSize returns cache size diff --git a/shared/client/host.go b/shared/client/host.go index 5580db3..09432df 100644 --- a/shared/client/host.go +++ b/shared/client/host.go @@ -25,10 +25,30 @@ type Host struct { mux sync.RWMutex *circut.Breaker + // LatencyBreaker is an optional latency-driven shed mechanism that + // runs in parallel to the connection-failure-based Breaker. 
When + // configured (non-nil), getHost() requires both Breaker.IsUp() and + // LatencyBreaker.IsUp() to return true before letting a request + // through. nil = disabled (acts as permanently up). + LatencyBreaker *circut.LatencyBreaker + // memoization prefix string } +// IsUp combines the connection-failure Breaker and the (optional) +// LatencyBreaker. Both must say up for the host to be considered up. +// Shadows the embedded Breaker.IsUp(). +func (h *Host) IsUp() bool { + if h.Breaker != nil && !h.Breaker.IsUp() { + return false + } + if !h.LatencyBreaker.IsUp() { + return false + } + return true +} + func isSecurePort(port int) bool { return port == 443 || port == 1443 } diff --git a/shared/client/option.go b/shared/client/option.go index e626dfb..2e9b04b 100644 --- a/shared/client/option.go +++ b/shared/client/option.go @@ -1,6 +1,8 @@ package client import ( + "time" + "github.com/viant/gmetric" cconfig "github.com/viant/mly/shared/client/config" "github.com/viant/mly/shared/datastore" @@ -146,3 +148,39 @@ func (o *clientOptionsOption) Apply(c *Service) { func WithClientOptions(clientOptions ...dscli.Option) Option { return &clientOptionsOption{clientOptions: clientOptions} } + +type latencyBreakerOpt struct { + latest, rolling, window time.Duration + k int + fraction float64 +} + +func (o *latencyBreakerOpt) Apply(c *Service) { + c.Config.LatencyBreakerLatestThreshold = o.latest + c.Config.LatencyBreakerRollingThreshold = o.rolling + c.Config.LatencyBreakerRollingWindow = o.window + c.Config.LatencyBreakerKConsecutive = o.k + c.Config.LatencyBreakerPassThroughFraction = o.fraction +} + +// WithLatencyBreaker enables the latency-aware breaker on each host +// constructed for this Service. Pass-through fraction defaults to 0.01 +// when fraction <= 0; rolling window defaults to 1s; KConsecutive +// defaults to 3. +// +// Setting both latest and rolling to zero leaves the breaker disabled +// (backward compatible). 
Both thresholds are taken as raw durations; +// the caller is responsible for sizing them appropriately for the +// model's traffic profile and the caller's request timeout. +func WithLatencyBreaker(latest, rolling, window time.Duration, k int, fraction float64) Option { + if fraction <= 0 { + fraction = 0.01 + } + if window <= 0 { + window = time.Second + } + if k < 1 { + k = 3 + } + return &latencyBreakerOpt{latest: latest, rolling: rolling, window: window, k: k, fraction: fraction} +} diff --git a/shared/client/service.go b/shared/client/service.go index 81a8916..999620a 100644 --- a/shared/client/service.go +++ b/shared/client/service.go @@ -20,6 +20,7 @@ import ( "github.com/francoispqt/gojay" "github.com/viant/gmetric" + "github.com/viant/mly/shared/circut" "github.com/viant/mly/shared/client/config" "github.com/viant/mly/shared/common" "github.com/viant/mly/shared/common/storable" @@ -323,6 +324,8 @@ func (s *Service) init() error { s.Config.MaxRetry = 3 } + s.initLatencyBreakers() + err := s.initHTTPClient() if err != nil { return err @@ -629,7 +632,12 @@ func (s *Service) postRequest(ctx context.Context, data []byte, mvt *stat.Values var output []byte + start := time.Now() output, err = s.httpPost(ctx, data, host) + // Feed the latency observation to the latency breaker (if one is + // configured on this host). Observe is nil-safe. + host.LatencyBreaker.Observe(time.Since(start)) + if common.IsConnectionError(err) { if s.Config.Debug { log.Printf("[%s postRequest] connection error:%s", s.Config.Model, err) @@ -641,6 +649,28 @@ func (s *Service) postRequest(ctx context.Context, data []byte, mvt *stat.Values return output, err } +// initLatencyBreakers attaches a circut.LatencyBreaker to each +// configured host when the Config has at least one non-zero threshold. +// Both thresholds zero -> no breaker constructed (backward-compatible +// no-op). 
+func (s *Service) initLatencyBreakers() { + if s.Config.LatencyBreakerLatestThreshold == 0 && s.Config.LatencyBreakerRollingThreshold == 0 { + return + } + for _, h := range s.Config.Hosts { + if h == nil || h.LatencyBreaker != nil { + continue + } + h.LatencyBreaker = circut.NewLatencyBreaker( + s.Config.LatencyBreakerLatestThreshold, + s.Config.LatencyBreakerRollingThreshold, + s.Config.LatencyBreakerRollingWindow, + s.Config.LatencyBreakerKConsecutive, + s.Config.LatencyBreakerPassThroughFraction, + ) + } +} + func (s *Service) httpPost(ctx context.Context, data []byte, host *Host) ([]byte, error) { evalUrl := host.evalURL(s.Model) var terminate bool From 8104cd3668cfa86cfc6feeb06adc97942e3426fd Mon Sep 17 00:00:00 2001 From: David Choi Date: Thu, 30 Apr 2026 16:36:29 -0700 Subject: [PATCH 48/50] fix(self-test): avoid registering native client Prometheus metrics Add WithPrometheusMetrics(false) for short-lived helper clients that should not register long-lived native Prometheus series. Use it for server startup self-test clients so completed self-tests do not leave zero-valued mly_client_* series in the process-wide registry. 
Co-authored-by: Claude Opus 4.7 Made-with: Cursor --- service/endpoint/checker/self.go | 2 +- shared/client/option.go | 15 +++++++++++++++ shared/client/service.go | 9 +++++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/service/endpoint/checker/self.go b/service/endpoint/checker/self.go index e48dbdc..2e5ebd2 100644 --- a/service/endpoint/checker/self.go +++ b/service/endpoint/checker/self.go @@ -14,7 +14,7 @@ import ( ) func SelfTest(host []*client.Host, timeout time.Duration, modelID string, usesTransformer bool, tp config.TestPayload, debug bool) error { - cli, err := client.New(modelID, host, client.WithDebug(true)) + cli, err := client.New(modelID, host, client.WithDebug(true), client.WithPrometheusMetrics(false)) if err != nil { return fmt.Errorf("%s:%w", modelID, err) } diff --git a/shared/client/option.go b/shared/client/option.go index 2e9b04b..bd88146 100644 --- a/shared/client/option.go +++ b/shared/client/option.go @@ -41,6 +41,21 @@ func WithGmetrics(gmetrics *gmetric.Service) Option { return &gmetricsOpt{gmetrics: gmetrics} } +type prometheusMetricsOpt struct { + enable bool +} + +func (o *prometheusMetricsOpt) Apply(c *Service) { + c.noPrometheusMetrics = !o.enable +} + +// WithPrometheusMetrics enables or disables native Prometheus client +// metrics. Metrics are enabled by default; disable them for short-lived +// helper clients that should not register long-lived model series. +func WithPrometheusMetrics(enable bool) Option { + return &prometheusMetricsOpt{enable: enable} +} + type dictHashValidationOpt struct { enable bool } diff --git a/shared/client/service.go b/shared/client/service.go index 07f24fe..ccc8915 100644 --- a/shared/client/service.go +++ b/shared/client/service.go @@ -78,6 +78,12 @@ type Service struct { // See https://prometheus.io/docs/practices/histograms for guidance. noPrometheusSummaries bool + // noPrometheusMetrics disables native Prometheus metric registration. 
+ // This is useful for short-lived helper clients (for example server + // startup self-tests) that should not leave zero-valued model series in + // the process-wide registry after the helper has finished. + noPrometheusMetrics bool + prometheusMetrics prometheusMetrics ErrorHistory tracker.Tracker @@ -340,6 +346,9 @@ func (s *Service) dictionary() *Dictionary { } func (s *Service) registerPrometheusMetrics() error { + if s.noPrometheusMetrics { + return nil + } pr := prometheus.DefaultRegisterer if s.PrometheusRegisterer != nil { pr = s.PrometheusRegisterer From b494c86e249a3593ccc4708177d4a44d199077f4 Mon Sep 17 00:00:00 2001 From: David Choi Date: Thu, 30 Apr 2026 16:51:58 -0700 Subject: [PATCH 49/50] fix(shared/client): make disabled Prometheus metrics nil-safe WithPrometheusMetrics(false) skips native Prometheus registration, leaving observer and counter fields nil. Guard observation helpers and direct counter increments so helper clients can opt out without panicking. Exercise the opt-out path in the client Run test. 
Co-authored-by: GPT-5.5 Made-with: Cursor --- shared/client/prometheus.go | 16 ++++++++++++---- shared/client/service.go | 8 ++++++-- shared/client/service_test.go | 1 + 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/shared/client/prometheus.go b/shared/client/prometheus.go index 8048ece..4a3158c 100644 --- a/shared/client/prometheus.go +++ b/shared/client/prometheus.go @@ -76,28 +76,36 @@ type prometheusMetrics struct { } func (m prometheusMetrics) observeRunDuration(duration float64) { - m.runDurationHistogram.Observe(duration) + if m.runDurationHistogram != nil { + m.runDurationHistogram.Observe(duration) + } if m.runDurationSummary != nil { m.runDurationSummary.Observe(duration) } } func (m prometheusMetrics) observeBatchSize(batchSize float64) { - m.batchSizeHistogram.Observe(batchSize) + if m.batchSizeHistogram != nil { + m.batchSizeHistogram.Observe(batchSize) + } if m.batchSizeSummary != nil { m.batchSizeSummary.Observe(batchSize) } } func (m prometheusMetrics) observeHttpDuration(duration float64) { - m.httpDurationHistogram.Observe(duration) + if m.httpDurationHistogram != nil { + m.httpDurationHistogram.Observe(duration) + } if m.httpDurationSummary != nil { m.httpDurationSummary.Observe(duration) } } func (m prometheusMetrics) observeHttpClientDuration(duration float64) { - m.httpClientDurationHistogram.Observe(duration) + if m.httpClientDurationHistogram != nil { + m.httpClientDurationHistogram.Observe(duration) + } if m.httpClientDurationSummary != nil { m.httpClientDurationSummary.Observe(duration) } diff --git a/shared/client/service.go b/shared/client/service.go index ccc8915..3759fe2 100644 --- a/shared/client/service.go +++ b/shared/client/service.go @@ -114,7 +114,9 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response if ctx.Err() != nil { stats.Append(stat.EarlyCtxError) - s.prometheusMetrics.runErrorEarlyCtxCounter.Inc() + if s.prometheusMetrics.runErrorEarlyCtxCounter != nil { + 
s.prometheusMetrics.runErrorEarlyCtxCounter.Inc() + } } if response.Data == nil { @@ -701,7 +703,9 @@ func (s *Service) postRequest(ctx context.Context, data []byte, mvt *stat.Values } mvt.Append(stat.Down) - s.prometheusMetrics.httpDownCounter.Inc() + if s.prometheusMetrics.httpDownCounter != nil { + s.prometheusMetrics.httpDownCounter.Inc() + } host.FlagDown() } diff --git a/shared/client/service_test.go b/shared/client/service_test.go index ed299ec..ed6506b 100644 --- a/shared/client/service_test.go +++ b/shared/client/service_test.go @@ -89,6 +89,7 @@ func TestService_Run(t *testing.T) { WithDictionary(dictionary), WithDataStorer(mock.New()), WithDebug(true), + WithPrometheusMetrics(false), } } From 1ecba35db39afb102d3f07511685cac06952713a Mon Sep 17 00:00:00 2001 From: David Choi Date: Fri, 1 May 2026 15:32:05 -0700 Subject: [PATCH 50/50] fix(shared/client): validate latency breaker configuration Validate latency-breaker thresholds, rolling window, recovery count, and pass-through fraction during Service initialization. Normalize defaults in one path so direct Config usage and WithLatencyBreaker behave the same. Also wrap Service.init errors with context while preserving the underlying cause via %w. 
Co-authored-by: GPT-5.5 Co-authored-by: Cursor --- shared/client/config.go | 63 ++++++++++++++++++++++++++++-- shared/client/config_test.go | 76 ++++++++++++++++++++++++++++++++++++ shared/client/option.go | 14 +------ shared/client/service.go | 39 +++++++++++------- 4 files changed, 162 insertions(+), 30 deletions(-) create mode 100644 shared/client/config_test.go diff --git a/shared/client/config.go b/shared/client/config.go index a8f86b0..8a07fa6 100644 --- a/shared/client/config.go +++ b/shared/client/config.go @@ -1,12 +1,25 @@ package client import ( + "fmt" "time" "github.com/viant/mly/shared/client/config" ) -//Config represents a client config +const ( + defaultLatencyBreakerRollingWindow = time.Second + defaultLatencyBreakerKConsecutive = 3 + defaultLatencyBreakerPassThroughFraction = 0.01 +) + +type latencyBreakerSettings struct { + latest, rolling, window time.Duration + k int + fraction float64 +} + +// Config represents a client config type Config struct { Hosts []*Host Model string @@ -55,11 +68,55 @@ type Config struct { // PassThroughFraction is the probability that a request is allowed // through while the breaker is ON, to drive recovery sensing. // Default 0.01 (1%). Set higher for low-QPS models that need more - // observations to recover, or 0 to fully shed without recovery. + // observations to recover. Valid range: [0, 1]. A zero value means + // use the default. 
LatencyBreakerPassThroughFraction float64 } -//CacheSize returns cache size +func (c *Config) latencyBreakerSettings() (latencyBreakerSettings, bool, error) { + settings := latencyBreakerSettings{ + latest: c.LatencyBreakerLatestThreshold, + rolling: c.LatencyBreakerRollingThreshold, + window: c.LatencyBreakerRollingWindow, + k: c.LatencyBreakerKConsecutive, + fraction: c.LatencyBreakerPassThroughFraction, + } + + if settings.latest < 0 { + return settings, false, fmt.Errorf("LatencyBreakerLatestThreshold must be >= 0, got %s", settings.latest) + } + if settings.rolling < 0 { + return settings, false, fmt.Errorf("LatencyBreakerRollingThreshold must be >= 0, got %s", settings.rolling) + } + if settings.window < 0 { + return settings, false, fmt.Errorf("LatencyBreakerRollingWindow must be >= 0, got %s", settings.window) + } + if settings.k < 0 { + return settings, false, fmt.Errorf("LatencyBreakerKConsecutive must be >= 0, got %d", settings.k) + } + if settings.fraction < 0 || settings.fraction > 1 { + return settings, false, fmt.Errorf("LatencyBreakerPassThroughFraction must be in [0, 1], got %v", settings.fraction) + } + + enabled := settings.latest > 0 || settings.rolling > 0 + if !enabled { + return settings, false, nil + } + + if settings.window == 0 { + settings.window = defaultLatencyBreakerRollingWindow + } + if settings.k == 0 { + settings.k = defaultLatencyBreakerKConsecutive + } + if settings.fraction == 0 { + settings.fraction = defaultLatencyBreakerPassThroughFraction + } + + return settings, true, nil +} + +// CacheSize returns cache size func (c *Config) CacheSize() int { if c.CacheSizeMb == 0 { return 0 diff --git a/shared/client/config_test.go b/shared/client/config_test.go new file mode 100644 index 0000000..983e9d2 --- /dev/null +++ b/shared/client/config_test.go @@ -0,0 +1,76 @@ +package client + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func 
TestConfig_LatencyBreakerSettings_Disabled(t *testing.T) { + settings, enabled, err := (&Config{}).latencyBreakerSettings() + require.NoError(t, err) + assert.False(t, enabled) + assert.Zero(t, settings.latest) + assert.Zero(t, settings.rolling) +} + +func TestConfig_LatencyBreakerSettings_Defaults(t *testing.T) { + cfg := &Config{LatencyBreakerLatestThreshold: 40 * time.Millisecond} + + settings, enabled, err := cfg.latencyBreakerSettings() + require.NoError(t, err) + assert.True(t, enabled) + assert.Equal(t, 40*time.Millisecond, settings.latest) + assert.Equal(t, defaultLatencyBreakerRollingWindow, settings.window) + assert.Equal(t, defaultLatencyBreakerKConsecutive, settings.k) + assert.Equal(t, defaultLatencyBreakerPassThroughFraction, settings.fraction) +} + +func TestConfig_LatencyBreakerSettings_ExplicitValues(t *testing.T) { + cfg := &Config{ + LatencyBreakerRollingThreshold: 20 * time.Millisecond, + LatencyBreakerRollingWindow: 500 * time.Millisecond, + LatencyBreakerKConsecutive: 5, + LatencyBreakerPassThroughFraction: 0.25, + } + + settings, enabled, err := cfg.latencyBreakerSettings() + require.NoError(t, err) + assert.True(t, enabled) + assert.Equal(t, 20*time.Millisecond, settings.rolling) + assert.Equal(t, 500*time.Millisecond, settings.window) + assert.Equal(t, 5, settings.k) + assert.Equal(t, 0.25, settings.fraction) +} + +func TestConfig_LatencyBreakerSettings_Invalid(t *testing.T) { + cases := []struct { + name string + config Config + }{ + {name: "negative latest threshold", config: Config{LatencyBreakerLatestThreshold: -time.Millisecond}}, + {name: "negative rolling threshold", config: Config{LatencyBreakerRollingThreshold: -time.Millisecond}}, + {name: "negative rolling window", config: Config{LatencyBreakerRollingWindow: -time.Millisecond}}, + {name: "negative consecutive count", config: Config{LatencyBreakerKConsecutive: -1}}, + {name: "negative pass-through fraction", config: Config{LatencyBreakerPassThroughFraction: -0.1}}, + {name: 
"pass-through fraction above one", config: Config{LatencyBreakerPassThroughFraction: 1.1}}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, _, err := tc.config.latencyBreakerSettings() + assert.Error(t, err) + }) + } +} + +func TestNew_InvalidLatencyBreakerOptionReturnsError(t *testing.T) { + _, err := New( + "invalid-latency-breaker", + []*Host{NewHost("localhost", 8080)}, + WithLatencyBreaker(-time.Millisecond, 0, 0, 0, 0), + ) + assert.ErrorContains(t, err, "LatencyBreakerLatestThreshold") +} diff --git a/shared/client/option.go b/shared/client/option.go index 2e9b04b..6a2caee 100644 --- a/shared/client/option.go +++ b/shared/client/option.go @@ -164,23 +164,13 @@ func (o *latencyBreakerOpt) Apply(c *Service) { } // WithLatencyBreaker enables the latency-aware breaker on each host -// constructed for this Service. Pass-through fraction defaults to 0.01 -// when fraction <= 0; rolling window defaults to 1s; KConsecutive -// defaults to 3. +// constructed for this Service. Defaults are applied during Service +// init: pass-through fraction 0.01, rolling window 1s, KConsecutive 3. // // Setting both latest and rolling to zero leaves the breaker disabled // (backward compatible). Both thresholds are taken as raw durations; // the caller is responsible for sizing them appropriately for the // model's traffic profile and the caller's request timeout. 
func WithLatencyBreaker(latest, rolling, window time.Duration, k int, fraction float64) Option { - if fraction <= 0 { - fraction = 0.01 - } - if window <= 0 { - window = time.Second - } - if k < 1 { - k = 3 - } return &latencyBreakerOpt{latest: latest, rolling: rolling, window: window, k: k, fraction: fraction} } diff --git a/shared/client/service.go b/shared/client/service.go index 999620a..076e056 100644 --- a/shared/client/service.go +++ b/shared/client/service.go @@ -324,35 +324,39 @@ func (s *Service) init() error { s.Config.MaxRetry = 3 } - s.initLatencyBreakers() + if err := s.initLatencyBreakers(); err != nil { + return fmt.Errorf("failed to initialize latency breakers: %w", err) + } err := s.initHTTPClient() if err != nil { - return err + return fmt.Errorf("failed to initialize HTTP client: %w", err) } if s.Config.Datastore == nil { if err := s.loadModelConfig(); err != nil { - return err + return fmt.Errorf("failed to load model config: %w", err) } } if s.dict == nil { if err := s.loadModelDictionary(); err != nil { - return err + return fmt.Errorf("failed to load model dictionary: %w", err) } } if ds := s.Config.Datastore; ds != nil { ds.Init() if err = ds.Validate(); err != nil { - return err + return fmt.Errorf("failed to validate datastore config: %w", err) } } if s.datastore == nil { err := s.initDatastore() - return err + if err != nil { + return fmt.Errorf("failed to initialize datastore: %w", err) + } } s.messages = NewMessages(s.dictionary) @@ -365,7 +369,7 @@ func (s *Service) initHTTPClient() error { if host != nil && host.IsSecurePort() { cert, err := getCertPool() if err != nil { - return fmt.Errorf("failed to create certificate: %v", err) + return fmt.Errorf("failed to create certificate: %w", err) } tslConfig = &tls.Config{ @@ -653,22 +657,27 @@ func (s *Service) postRequest(ctx context.Context, data []byte, mvt *stat.Values // configured host when the Config has at least one non-zero threshold. 
// Both thresholds zero -> no breaker constructed (backward-compatible // no-op). -func (s *Service) initLatencyBreakers() { - if s.Config.LatencyBreakerLatestThreshold == 0 && s.Config.LatencyBreakerRollingThreshold == 0 { - return +func (s *Service) initLatencyBreakers() error { + settings, enabled, err := s.Config.latencyBreakerSettings() + if err != nil { + return err + } + if !enabled { + return nil } for _, h := range s.Config.Hosts { if h == nil || h.LatencyBreaker != nil { continue } h.LatencyBreaker = circut.NewLatencyBreaker( - s.Config.LatencyBreakerLatestThreshold, - s.Config.LatencyBreakerRollingThreshold, - s.Config.LatencyBreakerRollingWindow, - s.Config.LatencyBreakerKConsecutive, - s.Config.LatencyBreakerPassThroughFraction, + settings.latest, + settings.rolling, + settings.window, + settings.k, + settings.fraction, ) } + return nil } func (s *Service) httpPost(ctx context.Context, data []byte, host *Host) ([]byte, error) {