Skip to content

Commit 6b7e845

Browse files
feat: add PSI baseline collector tool for GB10 calibration
- Samples /proc/pressure/memory and /proc/pressure/io every second - Saves JSON + human-readable log to ~/sparkview_logs/psi_baseline/ - Summary stats: min, max, mean, p90, p99 per channel - System info: driver, GPU, hostname, kernel, memory - Three labels: idle, vllm_loaded, inference_running
1 parent 4e53963 commit 6b7e845

1 file changed

Lines changed: 227 additions & 0 deletions

File tree

tools/collect_psi_baseline.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
#!/usr/bin/env python3
2+
"""
3+
collect_psi_baseline.py — sparkview PSI baseline collector for GB10 calibration.
4+
5+
Samples /proc/pressure/memory and /proc/pressure/io every second for a set
6+
duration and saves timestamped JSON output for community calibration.
7+
8+
Usage:
9+
python3 collect_psi_baseline.py --duration 120 --label idle
10+
python3 collect_psi_baseline.py --duration 120 --label vllm_loaded
11+
python3 collect_psi_baseline.py --duration 120 --label inference_running
12+
13+
Output:
14+
~/sparkview_logs/psi_baseline/sparkview_psi_baseline_<label>_<timestamp>.json
15+
"""
16+
17+
from __future__ import annotations
18+
19+
import argparse
20+
import json
21+
import os
22+
import pathlib
23+
import platform
24+
import subprocess
25+
import time
26+
from datetime import datetime
27+
28+
29+
# ── PSI paths ────────────────────────────────────────────────────────────────
30+
PSI_MEM = pathlib.Path("/proc/pressure/memory")
31+
PSI_IO = pathlib.Path("/proc/pressure/io")
32+
LOG_DIR = pathlib.Path.home() / "sparkview_logs" / "psi_baseline"
33+
34+
35+
def _parse_psi(path: pathlib.Path) -> dict:
36+
try:
37+
lines = path.read_text().strip().splitlines()
38+
result = {}
39+
for line in lines:
40+
parts = line.split()
41+
kind = parts[0] # "some" or "full"
42+
kv = {p.split("=")[0]: float(p.split("=")[1]) for p in parts[1:]}
43+
result[kind] = kv
44+
return result
45+
except (OSError, ValueError, IndexError):
46+
return {}
47+
48+
49+
def _system_info() -> dict:
50+
info = {
51+
"hostname": platform.node(),
52+
"kernel": platform.release(),
53+
"collected": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
54+
}
55+
try:
56+
out = subprocess.check_output(
57+
["nvidia-smi", "--query-gpu=driver_version,name",
58+
"--format=csv,noheader"],
59+
text=True, timeout=5
60+
).strip().splitlines()[0]
61+
driver, gpu = [x.strip() for x in out.split(",")]
62+
info["driver"] = driver
63+
info["gpu"] = gpu
64+
except Exception:
65+
info["driver"] = "unknown"
66+
info["gpu"] = "unknown"
67+
try:
68+
mem = pathlib.Path("/proc/meminfo").read_text()
69+
for line in mem.splitlines():
70+
if line.startswith("MemTotal:"):
71+
info["mem_total_gb"] = round(int(line.split()[1]) / (1024**2), 1)
72+
if line.startswith("MemAvailable:"):
73+
info["mem_available_gb"] = round(int(line.split()[1]) / (1024**2), 1)
74+
except OSError:
75+
pass
76+
return info
77+
78+
79+
def _stats(vals: list) -> dict:
80+
if not vals:
81+
return {}
82+
return {
83+
"min": round(min(vals), 4),
84+
"max": round(max(vals), 4),
85+
"mean": round(sum(vals) / len(vals), 4),
86+
"p90": round(sorted(vals)[int(len(vals) * 0.90)], 4),
87+
"p99": round(sorted(vals)[int(len(vals) * 0.99)], 4),
88+
}
89+
90+
91+
def collect(duration: int, label: str, interval: float = 1.0) -> str:
92+
LOG_DIR.mkdir(parents=True, exist_ok=True)
93+
94+
print(f"sparkview PSI baseline collector")
95+
print(f" label: {label}")
96+
print(f" duration: {duration}s")
97+
print(f" output: {LOG_DIR}")
98+
print(f" started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
99+
print()
100+
101+
if not PSI_MEM.exists():
102+
print("ERROR: /proc/pressure/memory not found — PSI not supported on this kernel")
103+
return ""
104+
if not PSI_IO.exists():
105+
print("ERROR: /proc/pressure/io not found — IO PSI not supported on this kernel")
106+
return ""
107+
108+
samples = []
109+
log_lines = []
110+
start = time.monotonic()
111+
n = 0
112+
113+
try:
114+
while time.monotonic() - start < duration:
115+
ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
116+
mem = _parse_psi(PSI_MEM)
117+
io = _parse_psi(PSI_IO)
118+
t = round(time.monotonic() - start, 1)
119+
120+
sample = {"t": t, "ts": ts, "mem": mem, "io": io}
121+
samples.append(sample)
122+
123+
mem_some = mem.get("some", {}).get("avg10", 0.0)
124+
mem_full = mem.get("full", {}).get("avg10", 0.0)
125+
io_some = io.get("some", {}).get("avg10", 0.0)
126+
io_full = io.get("full", {}).get("avg10", 0.0)
127+
128+
line = (
129+
f"{ts} t={t:6.1f}s "
130+
f"mem some={mem_some:.4f} full={mem_full:.4f} "
131+
f"io some={io_some:.4f} full={io_full:.4f}"
132+
)
133+
log_lines.append(line)
134+
n += 1
135+
136+
print(f" [{t:6.1f}s] mem some={mem_some:.4f} io some={io_some:.4f}", end="\r")
137+
time.sleep(interval)
138+
139+
except KeyboardInterrupt:
140+
print("\nInterrupted — saving collected samples...")
141+
142+
print(f"\n collected {n} samples")
143+
144+
# ── Stats ─────────────────────────────────────────────────────────────────
145+
mem_some_vals = [s["mem"].get("some", {}).get("avg10", 0) for s in samples]
146+
mem_full_vals = [s["mem"].get("full", {}).get("avg10", 0) for s in samples]
147+
io_some_vals = [s["io"].get("some", {}).get("avg10", 0) for s in samples]
148+
io_full_vals = [s["io"].get("full", {}).get("avg10", 0) for s in samples]
149+
150+
summary = {
151+
"mem_some": _stats(mem_some_vals),
152+
"mem_full": _stats(mem_full_vals),
153+
"io_some": _stats(io_some_vals),
154+
"io_full": _stats(io_full_vals),
155+
}
156+
157+
# ── Write JSON ────────────────────────────────────────────────────────────
158+
ts_file = datetime.now().strftime("%Y%m%d_%H%M%S")
159+
basename = f"sparkview_psi_baseline_{label}_{ts_file}"
160+
json_path = LOG_DIR / f"{basename}.json"
161+
log_path = LOG_DIR / f"{basename}.log"
162+
163+
output = {
164+
"tool": "sparkview_psi_baseline_collector",
165+
"version": "1.0.0",
166+
"label": label,
167+
"duration": duration,
168+
"samples": n,
169+
"system": _system_info(),
170+
"summary": summary,
171+
"data": samples,
172+
}
173+
174+
with open(json_path, "w") as f:
175+
json.dump(output, f, indent=2)
176+
177+
# ── Write human-readable log ──────────────────────────────────────────────
178+
with open(log_path, "w") as f:
179+
f.write(f"sparkview PSI baseline log\n")
180+
f.write(f"label: {label}\n")
181+
f.write(f"duration: {duration}s\n")
182+
f.write(f"samples: {n}\n")
183+
f.write(f"system: {platform.node()} / {platform.release()}\n")
184+
f.write(f"\n")
185+
f.write(f"{'timestamp':<22} {'t':>7} "
186+
f"{'mem_some':>10} {'mem_full':>10} "
187+
f"{'io_some':>10} {'io_full':>10}\n")
188+
f.write("-" * 80 + "\n")
189+
for line in log_lines:
190+
f.write(line + "\n")
191+
f.write("\n")
192+
f.write("Summary:\n")
193+
for key, st in summary.items():
194+
f.write(f" {key:<12} min={st.get('min','?')} max={st.get('max','?')} "
195+
f"mean={st.get('mean','?')} p90={st.get('p90','?')} "
196+
f"p99={st.get('p99','?')}\n")
197+
198+
print(f"\n json: {json_path}")
199+
print(f" log: {log_path}")
200+
print()
201+
print(" Summary:")
202+
for key, st in summary.items():
203+
print(f" {key:<12} min={st.get('min','?')} max={st.get('max','?')} "
204+
f"mean={st.get('mean','?')} p90={st.get('p90','?')}")
205+
206+
return str(json_path)
207+
208+
209+
if __name__ == "__main__":
210+
parser = argparse.ArgumentParser(
211+
description="sparkview PSI baseline collector — GB10 calibration"
212+
)
213+
parser.add_argument(
214+
"--duration", type=int, default=120,
215+
help="Collection duration in seconds (default: 120)"
216+
)
217+
parser.add_argument(
218+
"--label", type=str, default="idle",
219+
choices=["idle", "vllm_loaded", "inference_running", "post_inference", "custom"],
220+
help="Workload label for this collection run"
221+
)
222+
parser.add_argument(
223+
"--interval", type=float, default=1.0,
224+
help="Sample interval in seconds (default: 1.0)"
225+
)
226+
args = parser.parse_args()
227+
collect(args.duration, args.label, args.interval)

0 commit comments

Comments
 (0)