Skip to content

Commit 65c4a1c

Browse files
Remove need to call connect explicitly (StructuredLabs#653)
* Clean up managers/data.py for linter * Fixed connect idempotency issue - no longer recreating sources upon rerun * Remove need to call connect explicitly * Remove connect() from template and tutorial
1 parent 3bb68a4 commit 65c4a1c

11 files changed

Lines changed: 157 additions & 94 deletions

File tree

examples/earthquakes/hello.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,12 @@
11
import pandas as pd
22
import plotly.express as px
33

4-
from preswald import connect, get_df, plotly, slider, table, text
4+
from preswald import get_df, plotly, slider, table, text
55

66

77
# Title
88
text("# Earthquake Analytics Dashboard 🌍")
99

10-
# Load and connect data
11-
connect()
12-
13-
# ---
14-
1510
# Clickhouse section
1611

1712
query_string = """

examples/fires/hello.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
import pandas as pd
22
import plotly.express as px
33

4-
from preswald import connect, get_df, plotly, text
4+
from preswald import get_df, plotly, text
55

66

77
# Display the dashboard title
88
text("# Fire Incident Analytics Dashboard 🔥")
99

10-
# Connect to the data
11-
connect()
12-
1310
# Load and preprocess the data
1411
data = get_df("csv")
1512
data["incident_acres_burned"] = pd.to_numeric(

examples/iris/hello.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from preswald import (
77
chat,
8-
connect,
98
# fastplotlib,
109
get_df,
1110
plotly,
@@ -28,7 +27,6 @@
2827
)
2928

3029
# Load the CSV
31-
connect() # Load in all sources, which by default is the iris_csv
3230
df = get_df("iris_csv")
3331

3432
# 1. Scatter plot - Sepal Length vs Sepal Width

examples/user_event/hello.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,30 @@
1-
from preswald import text, plotly, connect, get_df, table
2-
import pandas as pd
31
import plotly.express as px
42

3+
from preswald import get_df, plotly, table, text
4+
5+
56
text("# Welcome to Preswald!")
67
text("This is your first app. 🎉")
78

89
# Load the JSON source defined as "user_events" in preswald.toml
9-
connect() # This loads all data sources, including our nested JSON source.
10-
df = get_df('user_events')
10+
df = get_df("user_events")
1111

1212
# Create a scatter plot using the flattened data.
1313
# Assuming the JSON file has been flattened to include "user" and "details.clicks"
1414
fig = px.scatter(
1515
df,
16-
x='user',
17-
y='details.clicks',
18-
text='user',
19-
title='User Events: Clicks per User',
20-
labels={'user': 'User', 'details.clicks': 'Clicks'}
16+
x="user",
17+
y="details.clicks",
18+
text="user",
19+
title="User Events: Clicks per User",
20+
labels={"user": "User", "details.clicks": "Clicks"},
2121
)
2222

2323
# Add labels for each point
24-
fig.update_traces(textposition='top center', marker=dict(size=12, color='lightblue'))
24+
fig.update_traces(textposition="top center", marker={"size": 12, "color": "lightblue"})
2525

2626
# Style the plot
27-
fig.update_layout(template='plotly_white')
27+
fig.update_layout(template="plotly_white")
2828

2929
# Display the plot
3030
plotly(fig)

preswald/browser/virtual_service.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import asyncio
66
import logging
77
import sys
8-
from typing import Any, Dict, Optional
8+
from typing import Any
99

1010
from preswald.engine.base_service import BasePreswaldService
1111
from preswald.engine.utils import RenderBuffer
@@ -37,7 +37,7 @@ def __init__(self, client_id: str):
3737
)
3838
console.log(f"[Communication] is browser mode: {self.is_browser_mode}")
3939

40-
async def send_json(self, data: Dict[str, Any]):
40+
async def send_json(self, data: dict[str, Any]):
4141
"""Send JSON data to JavaScript frontend"""
4242
if not self.is_connected:
4343
logger.error(f"Cannot send message, connection closed for {self.client_id}")
@@ -185,11 +185,14 @@ def handle_message_from_js(client_id, message_type, data):
185185
if logger.isEnabledFor(logging.DEBUG):
186186
logger.debug(f"Handling JS message from {client_id}: {message}")
187187

188-
asyncio.create_task(self.handle_client_message(client_id, message)) # noqa: RUF006
188+
asyncio.create_task(self.handle_client_message(client_id, message)) # noqa: RUF006
189189
return True
190190
except Exception:
191191
import traceback
192-
logger.error("Error in handle_message_from_js: %s", traceback.format_exc())
192+
193+
logger.error(
194+
"Error in handle_message_from_js: %s", traceback.format_exc()
195+
)
193196
return False
194197

195198
# Export the function to JavaScript

preswald/engine/base_service.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import logging
22
import os
33
import time
4+
from collections.abc import Callable
45
from threading import Lock
5-
from typing import Any, Callable, Dict, Optional
6+
from typing import Any
67

78
from preswald.engine.runner import ScriptRunner
89
from preswald.engine.utils import (
@@ -28,19 +29,19 @@ class BasePreswaldService:
2829
_not_initialized_msg = "Base service not initialized."
2930

3031
def __init__(self):
31-
self._component_states: Dict[str, Any] = {}
32+
self._component_states: dict[str, Any] = {}
3233
self._lock = Lock()
3334

3435
# Data management
35-
self.data_manager: Optional[DataManager] = None # set during server creation
36+
self.data_manager: DataManager | None = None # set during server creation
3637

3738
# Initialize service state
38-
self._script_path: Optional[str] = None
39+
self._script_path: str | None = None
3940
self._is_shutting_down: bool = False
4041
self._render_buffer = RenderBuffer()
4142

4243
# Initialize session tracking
43-
self.script_runners: Dict[str, ScriptRunner] = {}
44+
self.script_runners: dict[str, ScriptRunner] = {}
4445

4546
# Layout management
4647
self._layout_manager = LayoutManager()
@@ -61,7 +62,7 @@ def initialize(cls, script_path=None):
6162
return cls._instance
6263

6364
@property
64-
def script_path(self) -> Optional[str]:
65+
def script_path(self) -> str | None:
6566
return self._script_path
6667

6768
@script_path.setter
@@ -76,9 +77,10 @@ def script_path(self, path: str):
7677
def append_component(self, component):
7778
"""Add a component to the layout manager"""
7879
try:
79-
8080
if isinstance(component, dict):
81-
logger.info(f"[APPEND] Appending component: {component.get('id')}, type: {component.get('type')}")
81+
logger.info(
82+
f"[APPEND] Appending component: {component.get('id')}, type: {component.get('type')}"
83+
)
8284
# Clean any NaN values in the component
8385
clean_start = time.time()
8486
cleaned_component = clean_nan_values(component)
@@ -101,7 +103,9 @@ def append_component(self, component):
101103
)
102104
self._layout_manager.add_component(cleaned_component)
103105
if logger.isEnabledFor(logging.DEBUG):
104-
logger.debug(f"Added component with state: {cleaned_component}")
106+
logger.debug(
107+
f"Added component with state: {cleaned_component}"
108+
)
105109
else:
106110
# Components without IDs are added as-is
107111
self._layout_manager.add_component(cleaned_component)
@@ -137,7 +141,7 @@ def get_rendered_components(self):
137141
rows = self._layout_manager.get_layout()
138142
return {"rows": rows}
139143

140-
async def handle_client_message(self, client_id: str, message: Dict[str, Any]):
144+
async def handle_client_message(self, client_id: str, message: dict[str, Any]):
141145
"""Process incoming messages from clients"""
142146
start_time = time.time()
143147
try:
@@ -196,7 +200,7 @@ async def unregister_client(self, client_id: str):
196200
def _create_send_callback(self, websocket: Any) -> Callable:
197201
"""Create a message sending callback for a specific websocket"""
198202

199-
async def send_message(msg: Dict[str, Any]):
203+
async def send_message(msg: dict[str, Any]):
200204
if not self._is_shutting_down:
201205
try:
202206
await websocket.send_json(msg)
@@ -206,7 +210,7 @@ async def send_message(msg: Dict[str, Any]):
206210
return send_message
207211

208212
async def _broadcast_state_updates(
209-
self, states: Dict[str, Any], exclude_client: Optional[str] = None
213+
self, states: dict[str, Any], exclude_client: str | None = None
210214
):
211215
"""Broadcast state updates to all clients except the sender"""
212216

@@ -231,18 +235,15 @@ async def _broadcast_state_updates(
231235
except Exception as e:
232236
logger.error(f"Error broadcasting to {client_id}: {e}")
233237

234-
async def _handle_component_update(self, client_id: str, message: Dict[str, Any]):
238+
async def _handle_component_update(self, client_id: str, message: dict[str, Any]):
235239
"""Handle component state update messages"""
236240
states = message.get("states", {})
237241
if not states:
238242
await self._send_error(client_id, "Component update missing states")
239243
raise ValueError("Component update missing states")
240244

241245
# Only rerun if any state actually changed
242-
changed_states = {
243-
k: v for k, v in states.items()
244-
if self.should_render(k, v)
245-
}
246+
changed_states = {k: v for k, v in states.items() if self.should_render(k, v)}
246247

247248
if not changed_states:
248249
logger.debug("[STATE] No actual state changes detected. Skipping rerun.")
@@ -260,6 +261,10 @@ async def _handle_component_update(self, client_id: str, message: Dict[str, Any]
260261
# Broadcast updates to other clients
261262
await self._broadcast_state_updates(changed_states, exclude_client=client_id)
262263

264+
def connect_data_manager(self):
265+
"""Connect the data manager"""
266+
self.data_manager.connect()
267+
263268
def _initialize_data_manager(self, script_path: str) -> None:
264269
script_dir = os.path.dirname(script_path)
265270
preswald_path = os.path.join(script_dir, "preswald.toml")
@@ -269,7 +274,9 @@ def _initialize_data_manager(self, script_path: str) -> None:
269274
preswald_path=preswald_path, secrets_path=secrets_path
270275
)
271276

272-
async def _register_common_client_setup(self, client_id: str, websocket: Any) -> ScriptRunner:
277+
async def _register_common_client_setup(
278+
self, client_id: str, websocket: Any
279+
) -> ScriptRunner:
273280
logger.info(f"Registering client: {client_id}")
274281

275282
self.websocket_connections[client_id] = websocket
@@ -309,7 +316,7 @@ async def _send_initial_states(self, websocket: Any):
309316
except Exception as e:
310317
logger.error(f"Error sending initial states: {e}")
311318

312-
def _update_component_states(self, states: Dict[str, Any]):
319+
def _update_component_states(self, states: dict[str, Any]):
313320
"""Update internal state dictionary with cleaned component values."""
314321
with self._lock:
315322
logger.debug("[STATE] Updating states")

0 commit comments

Comments
 (0)