Skip to content

Commit b811626

Browse files
authored
fix: native tool type mapping for freeform apply_patch and web_search
fix: native tool type mapping for freeform apply_patch and web_search - Maps original Responses tool types through streaming and non-streaming paths - Emit custom_tool_call for apply_patch (freeform) instead of generic function_call - Emit web_search_call for web_search tools so Codex Desktop handles them correctly - Add server-side DuckDuckGo web search for non-streaming BYOK responses - Add comprehensive tests for all tool type mapping scenarios
1 parent 08f12b0 commit b811626

3 files changed

Lines changed: 403 additions & 22 deletions

File tree

codex_shim/server.py

Lines changed: 230 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,8 @@ async def _cursor_passthrough(
620620

621621
response = _sse_response()
622622
await response.prepare(request)
623-
state = ResponsesStreamState(slug)
623+
tool_types = _build_tool_types(body)
624+
state = ResponsesStreamState(slug, tool_types)
624625
try:
625626
await state.start(response)
626627
async for event in iter_cursor_agent_events(prompt, upstream):
@@ -757,10 +758,13 @@ async def _post_openai_chat(
757758
if upstream.status >= 400:
758759
return await _error_response(upstream, slug=route.slug)
759760
if body.get("stream"):
760-
return await self._stream_openai_chat(request, upstream, route, as_responses)
761+
return await self._stream_openai_chat(request, upstream, route, as_responses, body)
761762
payload = await upstream.json(content_type=None)
762763
if as_responses:
763-
return web.json_response(chat_completion_to_response(payload, route.slug))
764+
tool_types = _build_tool_types(body)
765+
payload = chat_completion_to_response(payload, route.slug, tool_types)
766+
intercepted = _maybe_intercept_web_search(payload)
767+
return web.json_response(intercepted or payload)
764768
return web.json_response(payload)
765769

766770
async def _post_openai_chat_as_anthropic(
@@ -788,10 +792,13 @@ async def _post_anthropic(
788792
if upstream.status >= 400:
789793
return await _error_response(upstream)
790794
if body.get("stream"):
791-
return await self._stream_anthropic(request, upstream, route, as_responses)
795+
return await self._stream_anthropic(request, upstream, route, as_responses, body)
792796
payload = await upstream.json(content_type=None)
793797
if as_responses:
794-
return web.json_response(anthropic_to_response(payload, route.slug))
798+
tool_types = _build_tool_types(body)
799+
payload = anthropic_to_response(payload, route.slug, tool_types)
800+
intercepted = _maybe_intercept_web_search(payload)
801+
return web.json_response(intercepted or payload)
795802
return web.json_response(anthropic_to_chat_response(payload, route.slug))
796803

797804
async def _post_anthropic_messages(
@@ -811,12 +818,13 @@ async def _post_anthropic_messages(
811818
return web.json_response(payload)
812819

813820
async def _stream_openai_chat(
814-
self, request: web.Request, upstream, route: ShimModel, as_responses: bool
821+
self, request: web.Request, upstream, route: ShimModel, as_responses: bool, body: dict[str, Any] | None = None
815822
) -> web.StreamResponse:
816823
response = _sse_response()
817824
await response.prepare(request)
818825
if as_responses:
819-
state = ResponsesStreamState(route.slug)
826+
tool_types = _build_tool_types(body) if body else {}
827+
state = ResponsesStreamState(route.slug, tool_types)
820828
try:
821829
if as_responses:
822830
await state.start(response)
@@ -873,12 +881,13 @@ async def _stream_openai_chat_as_anthropic(
873881
return response
874882

875883
async def _stream_anthropic(
876-
self, request: web.Request, upstream, route: ShimModel, as_responses: bool
884+
self, request: web.Request, upstream, route: ShimModel, as_responses: bool, body: dict[str, Any] | None = None
877885
) -> web.StreamResponse:
878886
response = _sse_response()
879887
await response.prepare(request)
880888
if as_responses:
881-
state = ResponsesStreamState(route.slug)
889+
tool_types = _build_tool_types(body) if body else {}
890+
state = ResponsesStreamState(route.slug, tool_types)
882891
try:
883892
if as_responses:
884893
await state.start(response)
@@ -1207,22 +1216,22 @@ class ResponsesStreamState:
12071216
proper .added / .delta / .done / .completed events plus a final
12081217
`response.completed` with the full reconciled `output` array."""
12091218

1210-
def __init__(self, model: str):
1219+
def __init__(self, model: str, tool_types: dict[str, str] | None = None):
12111220
self.response_id = f"resp_{int(time.time() * 1000)}"
12121221
self.message_item_id = f"msg_{int(time.time() * 1000)}"
12131222
self.model = model
1214-
self.message_index: int | None = None # output_index for the assistant message
1223+
self.message_index: int | None = None
12151224
self.message_text = ""
12161225
self.message_opened = False
12171226
self.message_closed = False
12181227
self.usage: dict[str, Any] | None = None
1219-
# Tool call state, keyed by upstream "index" (chat-completions) or
1220-
# anthropic content_block_index. Each entry tracks its assigned
1221-
# output_index, accumulated arguments, name, etc.
12221228
self.tool_calls: dict[int, dict[str, Any]] = {}
1223-
# Reasoning (extended thinking) blocks, keyed by upstream index.
12241229
self.reasoning_blocks: dict[Any, dict[str, Any]] = {}
12251230
self.next_output_index = 0
1231+
# Map sanitized tool name -> original Responses tool type so we can
1232+
# emit the correct output item type (e.g. custom_tool_call for freeform
1233+
# apply_patch instead of generic function_call).
1234+
self.tool_types = tool_types or {}
12261235

12271236
# ------------------------------------------------------------------
12281237
# Lifecycle
@@ -1493,13 +1502,23 @@ async def _open_tool(self, response: web.StreamResponse, *, key: Any, call_id: s
14931502
await self._close_message(response)
14941503
output_index = self.next_output_index
14951504
self.next_output_index += 1
1505+
# Determine output item type based on original tool type.
1506+
# Freeform tools (apply_patch with no schema) emit custom_tool_call
1507+
# so Codex Desktop knows not to validate against a fixed enum.
1508+
original_type = self.tool_types.get(name, "")
1509+
output_type = "function_call"
1510+
if original_type == "apply_patch":
1511+
output_type = "custom_tool_call"
1512+
elif original_type.startswith("web_search"):
1513+
output_type = "web_search_call"
14961514
state: dict[str, Any] = {
14971515
"id": call_id,
14981516
"call_id": call_id,
14991517
"name": name,
15001518
"arguments": "",
15011519
"output_index": output_index,
15021520
"closed": False,
1521+
"output_type": output_type,
15031522
}
15041523
self.tool_calls[key] = state
15051524
await _write_sse(
@@ -1509,7 +1528,7 @@ async def _open_tool(self, response: web.StreamResponse, *, key: Any, call_id: s
15091528
"output_index": output_index,
15101529
"item": {
15111530
"id": call_id,
1512-
"type": "function_call",
1531+
"type": output_type,
15131532
"status": "in_progress",
15141533
"call_id": call_id,
15151534
"name": name,
@@ -1654,7 +1673,7 @@ def _message_item(self, status: str) -> dict[str, Any]:
16541673
def _tool_item(self, state: dict[str, Any], status: str) -> dict[str, Any]:
16551674
return {
16561675
"id": state["id"],
1657-
"type": "function_call",
1676+
"type": state.get("output_type", "function_call"),
16581677
"status": status,
16591678
"call_id": state["call_id"],
16601679
"name": state["name"],
@@ -1712,6 +1731,200 @@ def _decode_thinking_payload(encoded: str) -> dict[str, Any] | None:
17121731
return data if isinstance(data, dict) else None
17131732

17141733

1734+
def _build_tool_types(body: dict[str, Any]) -> dict[str, str]:
1735+
"""Build a map sanitized tool name -> original tool type from the request tools array.
1736+
1737+
Codex Desktop emits native tools like `{"type": "apply_patch"}` and MCP tools
1738+
like `{"type": "mcp__node_repl", "function": {"name": "js"}}`. When we translate
1739+
those into chat-completions `function` tools, the original type is lost. We
1740+
preserve it here so the Responses streaming translator can emit the correct
1741+
output item type (e.g. `custom_tool_call` for freeform apply_patch instead of
1742+
generic `function_call`).
1743+
"""
1744+
tool_types: dict[str, str] = {}
1745+
for tool in body.get("tools") or []:
1746+
if not isinstance(tool, dict):
1747+
continue
1748+
tool_type = str(tool.get("type") or "").strip().lower()
1749+
fn = tool.get("function")
1750+
if isinstance(fn, dict) and fn.get("name"):
1751+
name = str(fn["name"]).strip()
1752+
elif tool.get("name"):
1753+
name = str(tool["name"]).strip()
1754+
else:
1755+
name = tool_type
1756+
clean = re.sub(r"[^a-zA-Z0-9_-]+", "_", name.strip())[:64].strip("_")
1757+
if clean:
1758+
tool_types[clean] = tool_type
1759+
return tool_types
1760+
1761+
async def _perform_web_search(query: str) -> str:
1762+
"""Execute a web search via DuckDuckGo and return text results.
1763+
1764+
This is a server-side fallback for custom models whose provider does not
1765+
have a native web-search capability. Codex Desktop expects the shim to
1766+
return results as a `function_call_output` (or `web_search_call`) item;
1767+
when the model is BYOK, the Desktop app does not execute the search itself,
1768+
so the shim must do it and feed the results back into the conversation.
1769+
"""
1770+
import urllib.parse
1771+
import urllib.request
1772+
1773+
if not query or not query.strip():
1774+
return "No search query provided."
1775+
1776+
# DuckDuckGo lite HTML endpoint (no API key required)
1777+
url = (
1778+
"https://html.duckduckgo.com/html/"
1779+
+ "?q="
1780+
+ urllib.parse.quote_plus(query.strip())
1781+
)
1782+
req = urllib.request.Request(
1783+
url,
1784+
headers={
1785+
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
1786+
"AppleWebKit/537.36 (KHTML, like Gecko) "
1787+
"Chrome/120.0.0.0 Safari/537.36"
1788+
},
1789+
)
1790+
try:
1791+
with urllib.request.urlopen(req, timeout=10) as resp:
1792+
html = resp.read().decode("utf-8", errors="replace")
1793+
except Exception as exc:
1794+
return f"Web search failed: {exc}"
1795+
1796+
# Extract title + snippet from result links
1797+
results: list[str] = []
1798+
# Each result is in a `.result` div with `.result__a` (title/link) and `.result__snippet`
1799+
from html.parser import HTMLParser
1800+
1801+
class _ResultParser(HTMLParser):
1802+
def __init__(self) -> None:
1803+
super().__init__()
1804+
self.in_result = False
1805+
self.in_a = False
1806+
self.in_snippet = False
1807+
self.current_title = ""
1808+
self.current_snippet = ""
1809+
self.results: list[dict[str, str]] = []
1810+
self._tag_stack: list[str] = []
1811+
self._class_stack: list[str] = []
1812+
1813+
def _current_class(self) -> str:
1814+
return self._class_stack[-1] if self._class_stack else ""
1815+
1816+
def handle_starttag(self, tag: str, attrs_list: list[tuple[str, str | None]]) -> None:
1817+
attrs = dict(attrs_list)
1818+
cls = (attrs.get("class") or "").lower()
1819+
self._tag_stack.append(tag)
1820+
self._class_stack.append(cls)
1821+
if "result" in cls and tag == "div":
1822+
self.in_result = True
1823+
self.current_title = ""
1824+
self.current_snippet = ""
1825+
if self.in_result and tag == "a" and "result__a" in cls:
1826+
self.in_a = True
1827+
if self.in_result and ("result__snippet" in cls or "result__body" in cls):
1828+
self.in_snippet = True
1829+
1830+
def handle_endtag(self, tag: str) -> None:
1831+
if self._tag_stack and self._tag_stack[-1] == tag:
1832+
self._tag_stack.pop()
1833+
self._class_stack.pop()
1834+
if tag == "div" and self.in_result:
1835+
if self.current_title or self.current_snippet:
1836+
self.results.append(
1837+
{
1838+
"title": self.current_title.strip(),
1839+
"snippet": self.current_snippet.strip(),
1840+
}
1841+
)
1842+
self.in_result = False
1843+
if tag == "a":
1844+
self.in_a = False
1845+
if tag in {"div", "span", "p"}:
1846+
self.in_snippet = False
1847+
1848+
def handle_data(self, data: str) -> None:
1849+
if self.in_a:
1850+
self.current_title += data
1851+
if self.in_snippet:
1852+
self.current_snippet += data
1853+
1854+
parser = _ResultParser()
1855+
parser.feed(html)
1856+
for r in parser.results[:5]:
1857+
title = r["title"].replace("\n", " ")
1858+
snippet = r["snippet"].replace("\n", " ")
1859+
if title and snippet:
1860+
results.append(f"{title}\n{snippet}")
1861+
elif title:
1862+
results.append(title)
1863+
elif snippet:
1864+
results.append(snippet)
1865+
1866+
if not results:
1867+
return "No web search results found."
1868+
return "\n\n".join(results)
1869+
1870+
def _maybe_intercept_web_search(payload: dict[str, Any]) -> dict[str, Any] | None:
1871+
"""If the response payload contains a web_search_call, execute it server-side
1872+
and return a new payload with the results embedded as a function_call_output.
1873+
1874+
Returns None if no web_search_call is present (pass through unchanged).
1875+
"""
1876+
output = payload.get("output") or []
1877+
if not isinstance(output, list):
1878+
return None
1879+
search_calls: list[tuple[int, dict[str, Any]]] = []
1880+
for i, item in enumerate(output):
1881+
if isinstance(item, dict) and item.get("type") == "web_search_call":
1882+
search_calls.append((i, item))
1883+
if not search_calls:
1884+
return None
1885+
1886+
# Build synthetic search results
1887+
results: list[dict[str, Any]] = []
1888+
for idx, call in search_calls:
1889+
try:
1890+
args = json.loads(call.get("arguments") or "{}")
1891+
except json.JSONDecodeError:
1892+
args = {}
1893+
query = args.get("query") or ""
1894+
# Run the search synchronously (non-streaming path only)
1895+
import asyncio
1896+
try:
1897+
loop = asyncio.get_running_loop()
1898+
result_text = loop.run_until_complete(_perform_web_search(query))
1899+
except RuntimeError:
1900+
result_text = "Web search unavailable in this context."
1901+
results.append({
1902+
"id": f"wso_{call.get('call_id', '0')}",
1903+
"type": "function_call_output",
1904+
"status": "completed",
1905+
"call_id": call.get("call_id"),
1906+
"output": result_text,
1907+
})
1908+
1909+
# Replace web_search_call items with their results
1910+
new_output: list[dict[str, Any]] = []
1911+
for i, item in enumerate(output):
1912+
if isinstance(item, dict) and item.get("type") == "web_search_call":
1913+
# Find matching result
1914+
for r in results:
1915+
if r.get("call_id") == item.get("call_id"):
1916+
new_output.append(r)
1917+
break
1918+
else:
1919+
new_output.append(item)
1920+
else:
1921+
new_output.append(item)
1922+
1923+
new_payload = dict(payload)
1924+
new_payload["output"] = new_output
1925+
return new_payload
1926+
1927+
17151928
_VERSIONED_BASE_RE = re.compile(r"(?:^|/)v\d+$")
17161929

17171930

codex_shim/translate.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def anthropic_to_chat_response(payload: dict[str, Any], requested_model: str) ->
315315
}
316316

317317

318-
def chat_completion_to_response(payload: dict[str, Any], requested_model: str) -> dict[str, Any]:
318+
def chat_completion_to_response(payload: dict[str, Any], requested_model: str, tool_types: dict[str, str] | None = None) -> dict[str, Any]:
319319
choice = (payload.get("choices") or [{}])[0]
320320
message = choice.get("message") or {}
321321
output: list[dict[str, Any]] = []
@@ -340,15 +340,23 @@ def chat_completion_to_response(payload: dict[str, Any], requested_model: str) -
340340
"content": [{"type": "output_text", "text": text, "annotations": []}],
341341
}
342342
)
343+
tool_types = tool_types or {}
343344
for call in message.get("tool_calls") or []:
344345
fn = call.get("function") or {}
346+
name = fn.get("name", "")
347+
original_type = tool_types.get(name, "")
348+
item_type = "function_call"
349+
if original_type == "apply_patch":
350+
item_type = "custom_tool_call"
351+
elif original_type.startswith("web_search"):
352+
item_type = "web_search_call"
345353
output.append(
346354
{
347355
"id": call.get("id", "call_0"),
348-
"type": "function_call",
356+
"type": item_type,
349357
"status": "completed",
350358
"call_id": call.get("id", "call_0"),
351-
"name": fn.get("name", ""),
359+
"name": name,
352360
"arguments": fn.get("arguments", ""),
353361
}
354362
)
@@ -363,8 +371,8 @@ def chat_completion_to_response(payload: dict[str, Any], requested_model: str) -
363371
}
364372

365373

366-
def anthropic_to_response(payload: dict[str, Any], requested_model: str) -> dict[str, Any]:
367-
response = chat_completion_to_response(anthropic_to_chat_response(payload, requested_model), requested_model)
374+
def anthropic_to_response(payload: dict[str, Any], requested_model: str, tool_types: dict[str, str] | None = None) -> dict[str, Any]:
375+
response = chat_completion_to_response(anthropic_to_chat_response(payload, requested_model), requested_model, tool_types)
368376
response["usage"] = normalize_responses_usage(payload.get("usage"))
369377
return response
370378

0 commit comments

Comments
 (0)