Skip to content

Commit cc0801f

Browse files
committed
tests: add --mcp-mode to run the IDA suite end-to-end over MCP
1 parent 307c8be commit cc0801f

4 files changed

Lines changed: 744 additions & 7 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ dependencies = [
2727
dev = [
2828
"coverage>=7.13.4",
2929
"jsonschema>=4.0",
30+
"mcp>=1.0",
3031
]
3132

3233
[project.urls]
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
"""End-to-end MCP mode for the test framework.
2+
3+
When enabled, every @tool-decorated function is monkeypatched to route
4+
through a real MCP client/server HTTP round-trip and the response is
5+
validated against its advertised outputSchema. Lets the existing test
6+
suite double as a runtime-conformance check without rewriting anything.
7+
"""
8+
9+
import asyncio
10+
import inspect
11+
import socket
12+
import sys
13+
import threading
14+
from contextlib import contextmanager
15+
from functools import wraps
16+
from typing import Any, Callable, Iterator
17+
18+
from jsonschema import Draft202012Validator
19+
20+
from ..rpc import MCP_SERVER
21+
22+
23+
class _AsyncHarness:
24+
"""Event loop on a background thread; sync callers schedule coros."""
25+
26+
def __init__(self):
27+
self.loop: asyncio.AbstractEventLoop | None = None
28+
self.thread: threading.Thread | None = None
29+
self._ready = threading.Event()
30+
31+
def start(self):
32+
def loop_target():
33+
self.loop = asyncio.new_event_loop()
34+
asyncio.set_event_loop(self.loop)
35+
self._ready.set()
36+
self.loop.run_forever()
37+
38+
self.thread = threading.Thread(target=loop_target, daemon=True)
39+
self.thread.start()
40+
self._ready.wait()
41+
42+
def stop(self):
43+
if self.loop is not None:
44+
self.loop.call_soon_threadsafe(self.loop.stop)
45+
if self.thread is not None:
46+
self.thread.join(timeout=2)
47+
48+
def run(self, coro, timeout: float = 20.0):
49+
fut = asyncio.run_coroutine_threadsafe(coro, self.loop)
50+
return fut.result(timeout=timeout)
51+
52+
53+
class _McpMode:
54+
def __init__(self):
55+
self.harness = _AsyncHarness()
56+
self.host = "127.0.0.1"
57+
self.port: int | None = None
58+
self.session = None
59+
self._http_ctx = None
60+
self._session_ctx = None
61+
self.output_schemas: dict[str, dict] = {}
62+
self._patched: list[tuple[Any, str, Callable]] = []
63+
64+
def enable(self):
65+
self.driver_thread_id = threading.get_ident()
66+
self.harness.start()
67+
self.port = _pick_free_port(self.host)
68+
MCP_SERVER.serve(self.host, self.port, background=True)
69+
_wait_until_ready(f"http://{self.host}:{self.port}/mcp")
70+
self.harness.run(self._connect())
71+
self._patch_tools()
72+
73+
def disable(self):
74+
self._unpatch_tools()
75+
try:
76+
self.harness.run(self._disconnect())
77+
except Exception:
78+
pass
79+
try:
80+
MCP_SERVER.stop()
81+
except Exception:
82+
pass
83+
self.harness.stop()
84+
85+
async def _connect(self):
86+
from mcp import ClientSession
87+
from mcp.client.streamable_http import streamablehttp_client
88+
89+
self._http_ctx = streamablehttp_client(f"http://{self.host}:{self.port}/mcp")
90+
read, write, _ = await self._http_ctx.__aenter__()
91+
self._session_ctx = ClientSession(read, write)
92+
self.session = await self._session_ctx.__aenter__()
93+
await self.session.initialize()
94+
tools = await self.session.list_tools()
95+
for t in tools.tools:
96+
if getattr(t, "outputSchema", None):
97+
self.output_schemas[t.name] = t.outputSchema
98+
99+
async def _disconnect(self):
100+
if self._session_ctx is not None:
101+
await self._session_ctx.__aexit__(None, None, None)
102+
if self._http_ctx is not None:
103+
await self._http_ctx.__aexit__(None, None, None)
104+
105+
def _patch_tools(self):
106+
for name, original in MCP_SERVER.tools.methods.items():
107+
mod = sys.modules.get(original.__module__)
108+
if mod is None:
109+
continue
110+
if getattr(mod, name, None) is not original:
111+
continue
112+
proxy = self._make_proxy(name, original)
113+
self._patched.append((mod, name, original))
114+
setattr(mod, name, proxy)
115+
116+
def _unpatch_tools(self):
117+
for mod, name, original in self._patched:
118+
setattr(mod, name, original)
119+
self._patched.clear()
120+
121+
def _make_proxy(self, name: str, original: Callable) -> Callable:
122+
sig = inspect.signature(original)
123+
schema = self.output_schemas.get(name)
124+
harness = self.harness
125+
state = self
126+
127+
@wraps(original)
128+
def proxy(*args, **kwargs):
129+
# Only route through MCP for calls originating on the test-driver
130+
# thread. Nested tool calls run on the server's HTTP handler thread
131+
# and must execute directly, otherwise we'd re-enter the event
132+
# loop the outer call is blocking on.
133+
if threading.get_ident() != state.driver_thread_id:
134+
return original(*args, **kwargs)
135+
136+
bound = sig.bind(*args, **kwargs)
137+
bound.apply_defaults()
138+
arguments = dict(bound.arguments)
139+
140+
async def _call():
141+
return await state.session.call_tool(name, arguments=arguments)
142+
143+
result = harness.run(_call())
144+
145+
if getattr(result, "isError", False):
146+
msg = ""
147+
if result.content:
148+
msg = getattr(result.content[0], "text", str(result.content[0]))
149+
raise RuntimeError(f"MCP tool {name!r} returned isError: {msg}")
150+
151+
structured = getattr(result, "structuredContent", None)
152+
if schema is not None and structured is not None:
153+
Draft202012Validator(schema).validate(structured)
154+
155+
if (
156+
isinstance(structured, dict)
157+
and set(structured.keys()) == {"result"}
158+
):
159+
return structured["result"]
160+
return structured
161+
162+
return proxy
163+
164+
165+
_instance: _McpMode | None = None
166+
167+
168+
def enable_mcp_mode() -> None:
169+
global _instance
170+
assert _instance is None, "MCP mode already enabled"
171+
_instance = _McpMode()
172+
_instance.enable()
173+
174+
175+
def disable_mcp_mode() -> None:
176+
global _instance
177+
if _instance is not None:
178+
_instance.disable()
179+
_instance = None
180+
181+
182+
@contextmanager
183+
def mcp_mode() -> Iterator[None]:
184+
enable_mcp_mode()
185+
try:
186+
yield
187+
finally:
188+
disable_mcp_mode()
189+
190+
191+
def _pick_free_port(host: str) -> int:
192+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
193+
s.bind((host, 0))
194+
port = s.getsockname()[1]
195+
s.close()
196+
return port
197+
198+
199+
def _wait_until_ready(url: str, timeout: float = 2.0) -> None:
200+
import time
201+
import urllib.error
202+
import urllib.request
203+
204+
deadline = time.time() + timeout
205+
while time.time() < deadline:
206+
try:
207+
req = urllib.request.Request(url, method="OPTIONS")
208+
with urllib.request.urlopen(req, timeout=0.2):
209+
return
210+
except urllib.error.HTTPError:
211+
return
212+
except (urllib.error.URLError, ConnectionRefusedError, socket.timeout):
213+
time.sleep(0.02)
214+
raise RuntimeError(f"server at {url} not ready within {timeout}s")

src/ida_pro_mcp/test.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ def main() -> int:
7676
action="store_true",
7777
help="Show IDA console messages",
7878
)
79+
parser.add_argument(
80+
"--mcp-mode",
81+
action="store_true",
82+
help="Route every @tool call through a real MCP client/server "
83+
"round-trip and validate responses against outputSchema",
84+
)
7985
args = parser.parse_args()
8086

8187
# Check binary exists
@@ -144,13 +150,24 @@ def main() -> int:
144150
in_ci = os.environ.get("CI", "").lower() not in ("", "0", "false", "no")
145151
interactive_output = sys.stdout.isatty()
146152
show_all_test_output = (not args.quiet) and (interactive_output or in_ci)
147-
results = run_tests(
148-
pattern=args.pattern,
149-
category=args.category,
150-
verbose=show_all_test_output,
151-
stop_on_failure=args.stop_on_failure,
152-
failures_only=(not args.quiet) and not show_all_test_output,
153-
)
153+
154+
def _run():
155+
return run_tests(
156+
pattern=args.pattern,
157+
category=args.category,
158+
verbose=show_all_test_output,
159+
stop_on_failure=args.stop_on_failure,
160+
failures_only=(not args.quiet) and not show_all_test_output,
161+
)
162+
163+
if args.mcp_mode:
164+
from ida_pro_mcp.ida_mcp.tests.mcp_mode import mcp_mode
165+
166+
print("[MCP] Running tests in end-to-end MCP mode.")
167+
with mcp_mode():
168+
results = _run()
169+
else:
170+
results = _run()
154171

155172
# No matched tests is likely a configuration/test-selection mistake
156173
if not results.results:

0 commit comments

Comments
 (0)