diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts index fef8c438366c..e89ee84f99c7 100644 --- a/packages/opencode/src/session/prompt.ts +++ b/packages/opencode/src/session/prompt.ts @@ -621,7 +621,11 @@ NOTE: At any point in time through this workflow you should feel free to ask the sessionID, abort: taskAbort.signal, callID: part.callID, - extra: { bypassAgentCheck: true, promptOps }, + extra: { + bypassAgentCheck: true, + promptOps, + ...(task.command ? { taskModel: { providerID: taskModel.providerID, modelID: taskModel.id } } : {}), + }, messages: msgs, metadata: (val: { title?: string; metadata?: Record }) => Effect.gen(function* () { diff --git a/packages/opencode/src/tool/task.ts b/packages/opencode/src/tool/task.ts index 22e4e5671c89..cade561fd4af 100644 --- a/packages/opencode/src/tool/task.ts +++ b/packages/opencode/src/tool/task.ts @@ -8,6 +8,7 @@ import type { SessionPrompt } from "../session/prompt" import { Config } from "@/config/config" import { Effect, Exit, Schema } from "effect" import { EffectBridge } from "@/effect/bridge" +import { ModelID, ProviderID } from "@/provider/schema" export interface TaskPromptOps { cancel(sessionID: SessionID): Effect.Effect @@ -15,6 +16,22 @@ export interface TaskPromptOps { prompt(input: SessionPrompt.PromptInput): Effect.Effect } +type TaskModel = { + providerID: ProviderID + modelID: ModelID +} + +function internalTaskModel(value: unknown): TaskModel | undefined { + if (typeof value !== "object" || value === null) return undefined + if (!("providerID" in value) || !("modelID" in value)) return undefined + if (typeof value.providerID !== "string" || typeof value.modelID !== "string") return undefined + return { providerID: ProviderID.make(value.providerID), modelID: ModelID.make(value.modelID) } +} + +function resolveTaskModel(input: { internal?: TaskModel; agent?: TaskModel; parent: TaskModel }) { + return input.internal ?? input.agent ?? input.parent +} + const id = "task" export const Parameters = Schema.Struct({ @@ -61,11 +78,23 @@ export const TaskTool = Tool.define( const canTask = next.permission.some((rule) => rule.permission === id) const canTodo = next.permission.some((rule) => rule.permission === "todowrite") + const parent = yield* sessions.get(ctx.sessionID) + const msg = yield* Effect.sync(() => MessageV2.get({ sessionID: ctx.sessionID, messageID: ctx.messageID })) + if (msg.info.role !== "assistant") return yield* Effect.fail(new Error("Not an assistant message")) + + const model = resolveTaskModel({ + internal: internalTaskModel(ctx.extra?.taskModel), + agent: next.model, + parent: { + modelID: msg.info.modelID, + providerID: msg.info.providerID, + }, + }) + const taskID = params.task_id const session = taskID ? yield* sessions.get(SessionID.make(taskID)).pipe(Effect.catchCause(() => Effect.succeed(undefined))) : undefined - const parent = yield* sessions.get(ctx.sessionID) const nextSession = session ?? (yield* sessions.create({ @@ -101,14 +130,6 @@ export const TaskTool = Tool.define( ], })) - const msg = yield* Effect.sync(() => MessageV2.get({ sessionID: ctx.sessionID, messageID: ctx.messageID })) - if (msg.info.role !== "assistant") return yield* Effect.fail(new Error("Not an assistant message")) - - const model = next.model ?? { - modelID: msg.info.modelID, - providerID: msg.info.providerID, - } - yield* ctx.metadata({ title: params.description, metadata: { diff --git a/packages/opencode/test/session/prompt.test.ts b/packages/opencode/test/session/prompt.test.ts index c5170f346492..9291e0a0839b 100644 --- a/packages/opencode/test/session/prompt.test.ts +++ b/packages/opencode/test/session/prompt.test.ts @@ -64,6 +64,11 @@ const ref = { modelID: ModelID.make("test-model"), } +const commandModel = { + providerID: ProviderID.make("test"), + modelID: ModelID.make("command-model"), +} + function defer() { let resolve!: (value: T | PromiseLike) => void const promise = new Promise((done) => { @@ -895,6 +900,107 @@ it.live( 10_000, ) +it.live( + "command subtask model overrides target subagent configured model", + () => + provideTmpdirServer( + Effect.fnUntraced(function* ({ llm }) { + const prompt = yield* SessionPrompt.Service + const sessions = yield* Session.Service + const chat = yield* sessions.create({ title: "Pinned" }) + yield* llm.text("done") + + yield* prompt.command({ + sessionID: chat.id, + command: "cheap-review", + arguments: "check routing", + }) + + const msgs = yield* MessageV2.filterCompactedEffect(chat.id) + const taskMsg = msgs.find((item) => item.info.role === "assistant" && item.info.agent === "general") + expect(taskMsg?.info.role).toBe("assistant") + if (!taskMsg || taskMsg.info.role !== "assistant") return + + expect(taskMsg.info.modelID).toBe(commandModel.modelID) + expect(taskMsg.info.providerID).toBe(commandModel.providerID) + + const tool = completedTool(taskMsg.parts) + expect(tool?.state.metadata?.model).toEqual(commandModel) + const childSessionID = tool?.state.metadata?.sessionId + expect(typeof childSessionID).toBe("string") + if (typeof childSessionID !== "string") return + + const childMsgs = yield* sessions.messages({ sessionID: SessionID.make(childSessionID) }) + const childUser = childMsgs.find((item) => item.info.role === "user")?.info + expect(childUser?.role).toBe("user") + if (childUser?.role !== "user") return + expect(childUser.model).toEqual(commandModel) + }), + { + git: true, + config: (url) => ({ + ...providerCfg(url), + provider: { + test: { + ...providerCfg(url).provider.test, + models: { + ...providerCfg(url).provider.test.models, + "command-model": { + ...providerCfg(url).provider.test.models["test-model"], + id: "command-model", + name: "Command Model", + }, + "agent-model": { + ...providerCfg(url).provider.test.models["test-model"], + id: "agent-model", + name: "Agent Model", + }, + }, + }, + }, + agent: { + general: { + model: "test/agent-model", + }, + }, + command: { + "cheap-review": { + agent: "general", + model: "test/command-model", + template: "Review: $ARGUMENTS", + }, + }, + }), + }, + ), + 10_000, +) + +it.live("session create model metadata does not control prompt model", () => + provideTmpdirServer( + Effect.fnUntraced(function* () { + const prompt = yield* SessionPrompt.Service + const sessions = yield* Session.Service + const chat = yield* sessions.create({ + title: "Pinned", + model: { id: ModelID.make("session-model"), providerID: ProviderID.make("test") }, + }) + + const result = yield* prompt.prompt({ + sessionID: chat.id, + agent: "build", + noReply: true, + parts: [{ type: "text", text: "hello" }], + }) + + expect(result.info.role).toBe("user") + if (result.info.role !== "user") return + expect(result.info.model).toEqual(ref) + }), + { git: true, config: providerCfg }, + ), +) + it.live( "cancel with queued callers resolves all cleanly", () => diff --git a/packages/opencode/test/tool/task.test.ts b/packages/opencode/test/tool/task.test.ts index f75fcf84b8a9..c73bc37456ea 100644 --- a/packages/opencode/test/tool/task.test.ts +++ b/packages/opencode/test/tool/task.test.ts @@ -23,6 +23,16 @@ const ref = { modelID: ModelID.make("test-model"), } +const commandModel = { + providerID: ProviderID.make("test"), + modelID: ModelID.make("command-model"), +} + +const agentModel = { + providerID: ProviderID.make("test"), + modelID: ModelID.make("agent-model"), +} + const it = testEffect( Layer.mergeAll( Agent.defaultLayer, @@ -362,6 +372,96 @@ describe("tool.task", () => { }), ) + it.instance( + "execute treats internal task model as authoritative for subtasks", + () => + Effect.gen(function* () { + const sessions = yield* Session.Service + const { chat, assistant } = yield* seed() + const tool = yield* TaskTool + const def = yield* tool.init() + let seen: SessionPrompt.PromptInput | undefined + const promptOps = stubOps({ onPrompt: (input) => (seen = input) }) + + const result = yield* def.execute( + { + description: "inspect bug", + prompt: "look into the cache key path", + subagent_type: "general", + }, + { + sessionID: chat.id, + messageID: assistant.id, + agent: "build", + abort: new AbortController().signal, + extra: { promptOps, taskModel: commandModel }, + messages: [], + metadata: () => Effect.void, + ask: () => Effect.void, + }, + ) + + const child = yield* sessions.get(result.metadata.sessionId) + expect(seen?.agent).toBe("general") + expect(seen?.model).toEqual(commandModel) + expect(child.agent).toBeUndefined() + expect(child.model).toBeUndefined() + }), + { + config: { + agent: { + general: { + model: "test/agent-model", + }, + }, + }, + }, + ) + + it.instance( + "execute falls back to agent model then parent model without stale internal task model", + () => + Effect.gen(function* () { + const { chat, assistant } = yield* seed() + const tool = yield* TaskTool + const def = yield* tool.init() + const seen: SessionPrompt.PromptInput[] = [] + const promptOps = stubOps({ onPrompt: (input) => seen.push(input) }) + + const ctx = (extra?: Record) => ({ + sessionID: chat.id, + messageID: assistant.id, + agent: "build", + abort: new AbortController().signal, + extra: { promptOps, ...extra }, + messages: [], + metadata: () => Effect.void, + ask: () => Effect.void, + }) + + yield* def.execute( + { description: "first", prompt: "use command model", subagent_type: "reviewer" }, + ctx({ taskModel: commandModel }), + ) + yield* def.execute({ description: "second", prompt: "use agent model", subagent_type: "reviewer" }, ctx()) + yield* def.execute({ description: "third", prompt: "use parent model", subagent_type: "general" }, ctx()) + + expect(seen[0]?.model).toEqual(commandModel) + expect(seen[1]?.model).toEqual(agentModel) + expect(seen[2]?.model).toEqual(ref) + }), + { + config: { + agent: { + reviewer: { + mode: "subagent", + model: "test/agent-model", + }, + }, + }, + }, + ) + it.instance( "execute shapes child permissions for task, todowrite, and primary tools", () =>