Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 75 additions & 48 deletions ostool-server/src/serial/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async fn run_serial_ws_inner(
.as_ref()
.ok_or_else(|| anyhow::anyhow!("board has no serial configuration"))?;
let resolved_serial = resolve_serial_config(serial)?;
let port = tokio_serial::new(&resolved_serial.current_device_path, serial.baud_rate)
let mut port = tokio_serial::new(&resolved_serial.current_device_path, serial.baud_rate)
.timeout(SERIAL_READ_TIMEOUT)
.open_native_async()
.with_context(|| {
Expand All @@ -61,6 +61,7 @@ async fn run_serial_ws_inner(
resolved_serial.current_device_path
)
})?;
clear_serial_input_after_open(&session_id, &mut port);

let (mut ws_sender, mut ws_receiver) = socket.split();
let (mut serial_rx, mut serial_tx) = tokio::io::split(port);
Expand Down Expand Up @@ -107,48 +108,6 @@ async fn run_serial_ws_inner(
break;
}
}
maybe_message = ws_receiver.next() => {
let Some(message) = maybe_message else {
break;
};
match message {
Ok(Message::Binary(bytes)) => {
write_serial_payload(&mut serial_tx, &bytes).await?;
}
Ok(Message::Text(text)) => {
let control: ClientControlMessage = serde_json::from_str(&text)?;
match control.kind.as_str() {
"close" => {
let _ = ws_sender
.send(Message::Text(r#"{"type":"closed"}"#.to_string().into()))
.await;
break;
}
"tx" => {
let Some(data) = control.data.as_deref() else {
anyhow::bail!("missing tx data");
};
let payload = match control.encoding.as_deref() {
Some("base64") => base64::engine::general_purpose::STANDARD
.decode(data)
.context("invalid base64 payload")?,
Some("utf8") | None => data.as_bytes().to_vec(),
Some(other) => anyhow::bail!("unsupported encoding `{other}`"),
};
write_serial_payload(&mut serial_tx, &payload).await?;
}
other => anyhow::bail!("unsupported websocket control type `{other}`"),
}
}
Ok(Message::Close(_)) => break,
Ok(Message::Ping(payload)) => {
ws_sender.send(Message::Pong(payload)).await.ok();
}
Ok(Message::Pong(_)) => {}
Err(err) => return Err(err.into()),
}
let _ = session.heartbeat().await;
}
read = serial_rx.read(&mut serial_buffer) => {
let read = read.context("serial read failed")?;
if read == 0 {
Expand Down Expand Up @@ -301,6 +260,25 @@ where
let _ = ws_sender.send(Message::Close(None)).await;
}

trait SerialOpenCleanup {
fn clear_input_buffer(&mut self) -> std::io::Result<()>;
}

impl SerialOpenCleanup for tokio_serial::SerialStream {
fn clear_input_buffer(&mut self) -> std::io::Result<()> {
self.clear(ClearBuffer::Input).map_err(std::io::Error::from)
}
}

fn clear_serial_input_after_open<T>(session_id: &str, port: &mut T)
where
T: SerialOpenCleanup + ?Sized,
{
if let Err(err) = port.clear_input_buffer() {
log::warn!("session `{session_id}` failed to clear serial input after open: {err}");
}
}

async fn write_serial_payload(
port: &mut tokio::io::WriteHalf<tokio_serial::SerialStream>,
payload: &[u8],
Expand Down Expand Up @@ -368,9 +346,10 @@ mod tests {
use tempfile::tempdir;

use super::{
ClientControlMessage, SerialQueueCleanup, cleanup_power_link,
cleanup_serial_queue_before_close, finalize_power_linked_session,
preserve_result_after_serial_cleanup, send_power_on_failure_and_close,
ClientControlMessage, SerialOpenCleanup, SerialQueueCleanup, cleanup_power_link,
cleanup_serial_queue_before_close, clear_serial_input_after_open,
finalize_power_linked_session, preserve_result_after_serial_cleanup,
send_power_on_failure_and_close,
};
use crate::{
build_app_state,
Expand All @@ -389,15 +368,31 @@ mod tests {

#[derive(Debug, Clone, PartialEq, Eq)]
enum CleanupEvent {
ClearInput,
Flush,
ClearAll,
}

struct RecordingSerialOpenCleanup {
events: Arc<Mutex<Vec<CleanupEvent>>>,
clear_result: io::Result<()>,
}

struct RecordingSerialCleanup {
events: Arc<Mutex<Vec<CleanupEvent>>>,
clear_result: io::Result<()>,
}

impl SerialOpenCleanup for RecordingSerialOpenCleanup {
fn clear_input_buffer(&mut self) -> io::Result<()> {
self.events.lock().unwrap().push(CleanupEvent::ClearInput);
self.clear_result
.as_ref()
.map(|_| ())
.map_err(|err| io::Error::new(err.kind(), err.to_string()))
}
}

#[async_trait::async_trait]
impl SerialQueueCleanup for RecordingSerialCleanup {
async fn flush_output(&mut self) -> io::Result<()> {
Expand Down Expand Up @@ -450,6 +445,38 @@ mod tests {
assert_eq!(message.kind, "close");
}

#[test]
fn serial_open_cleanup_clears_only_input_buffer() {
let events = Arc::new(Mutex::new(Vec::new()));
let mut cleanup = RecordingSerialOpenCleanup {
events: events.clone(),
clear_result: Ok(()),
};

clear_serial_input_after_open("session-1", &mut cleanup);

assert_eq!(
events.lock().unwrap().as_slice(),
&[CleanupEvent::ClearInput]
);
}

#[test]
fn serial_open_cleanup_does_not_fail_session_on_clear_error() {
let events = Arc::new(Mutex::new(Vec::new()));
let mut cleanup = RecordingSerialOpenCleanup {
events: events.clone(),
clear_result: Err(io::Error::other("clear failed")),
};

clear_serial_input_after_open("session-1", &mut cleanup);

assert_eq!(
events.lock().unwrap().as_slice(),
&[CleanupEvent::ClearInput]
);
}

#[tokio::test]
async fn serial_cleanup_flushes_before_clearing_all_buffers() {
let events = Arc::new(Mutex::new(Vec::new()));
Expand All @@ -473,7 +500,7 @@ mod tests {
let events = Arc::new(Mutex::new(Vec::new()));
let mut cleanup = RecordingSerialCleanup {
events: events.clone(),
clear_result: Err(io::Error::new(io::ErrorKind::Other, "clear failed")),
clear_result: Err(io::Error::other("clear failed")),
};

let err = cleanup_serial_queue_before_close(&mut cleanup)
Expand All @@ -492,7 +519,7 @@ mod tests {
let events = Arc::new(Mutex::new(Vec::new()));
let mut cleanup = RecordingSerialCleanup {
events,
clear_result: Err(io::Error::new(io::ErrorKind::Other, "clear failed")),
clear_result: Err(io::Error::other("clear failed")),
};

let err = preserve_result_after_serial_cleanup::<(), _>(
Expand Down
3 changes: 2 additions & 1 deletion ostool-server/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use tokio::sync::{RwLock, mpsc, watch};

use crate::{config::BoardConfig, state::AppState};

pub const SESSION_TTL: Duration = Duration::seconds(2);
pub const SESSION_TTL: Duration = Duration::seconds(10);

const SESSION_STATE_ACTIVE: u8 = 0;
const SESSION_STATE_RELEASING: u8 = 1;
Expand Down Expand Up @@ -305,6 +305,7 @@ mod tests {
fn session_new_uses_fixed_ttl() {
let session = Session::new("demo".into(), Some("client".into()));
assert_eq!(session.expires_at - session.created_at, SESSION_TTL);
assert!(SESSION_TTL >= chrono::Duration::seconds(10));
assert_eq!(session.last_heartbeat_at, session.created_at);
assert_eq!(session.state, SessionLifecycleState::Active);
}
Expand Down
124 changes: 120 additions & 4 deletions ostool-server/tests/session_ws_lifecycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#![cfg(unix)]

use std::{
io::Write,
io::{Read, Write},
net::SocketAddr,
path::Path,
sync::{Arc, mpsc},
Expand Down Expand Up @@ -60,7 +60,7 @@ struct SessionCreatedResponse {
ws_url: Option<String>,
}

fn sample_board(serial_port: String) -> BoardConfig {
fn sample_board_with_power_on(serial_port: String, power_on_cmd: String) -> BoardConfig {
BoardConfig {
id: TEST_BOARD_ID.into(),
board_type: TEST_BOARD_TYPE.into(),
Expand All @@ -75,7 +75,7 @@ fn sample_board(serial_port: String) -> BoardConfig {
resolved_usb_path: None,
}),
power_management: PowerManagementConfig::Custom(CustomPowerManagement {
power_on_cmd: "true".into(),
power_on_cmd,
power_off_cmd: "true".into(),
}),
boot: BootConfig::Uboot(UbootProfile {
Expand All @@ -90,6 +90,14 @@ fn sample_board(serial_port: String) -> BoardConfig {

/// Starts an in-process ostool-server with one board and PTY serial port.
fn spawn_test_server(root: &Path, serial_port: String) -> Result<TestServerHandle> {
spawn_test_server_with_power_on(root, serial_port, "true".into())
}

fn spawn_test_server_with_power_on(
root: &Path,
serial_port: String,
power_on_cmd: String,
) -> Result<TestServerHandle> {
let config_path = root.join("config.toml");
let data_dir = root.join("data");
let board_dir = root.join("boards");
Expand Down Expand Up @@ -117,7 +125,7 @@ fn spawn_test_server(root: &Path, serial_port: String) -> Result<TestServerHandl
std::fs::write(&config_path, toml::to_string_pretty(&config)?)
.with_context(|| format!("failed to write {}", config_path.display()))?;

let board = sample_board(serial_port);
let board = sample_board_with_power_on(serial_port, power_on_cmd);
let board_path = board_dir.join(format!("{}.toml", board.id));
std::fs::write(&board_path, toml::to_string_pretty(&board)?)
.with_context(|| format!("failed to write {}", board_path.display()))?;
Expand Down Expand Up @@ -193,6 +201,109 @@ fn spawn_test_server(root: &Path, serial_port: String) -> Result<TestServerHandl
})
}

fn run_delayed_client_write_case() -> Result<()> {
let temp = tempfile::tempdir().context("failed to create tempdir")?;
let gate_path = temp.path().join("power-on-ready");
let power_on_cmd = format!(
"while [ ! -f '{}' ]; do sleep 0.05; done",
gate_path.display()
);
let (mut serial_master, mut serial_handle) =
TTYPort::pair().context("failed to create PTY pair")?;
serial_handle
.set_exclusive(false)
.context("failed to disable PTY exclusivity")?;
serial_master
.set_timeout(POLL_INTERVAL)
.context("failed to configure PTY timeout")?;
let serial_port = serial_handle.name().context("failed to get PTY path")?;
drop(serial_handle);

let server = spawn_test_server_with_power_on(temp.path(), serial_port, power_on_cmd)?;
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.context("failed to build client runtime")?;

let client = reqwest::Client::new();
let (created, mut websocket) = runtime.block_on(async {
wait_for_server_ready(&client, &server.base_url).await?;
let created = create_session(&client, &server.base_url).await?;
let ws_url = resolve_ws_url(
&server.base_url,
created.ws_url.as_deref().context("missing websocket URL")?,
)?;
let (mut websocket, _) = tokio_tungstenite::connect_async(ws_url.as_str())
.await
.with_context(|| format!("failed to connect websocket {ws_url}"))?;
wait_for_opened(&mut websocket).await?;
websocket
.send(Message::Binary(b"early-input".to_vec().into()))
.await
.context("failed to send early websocket input")?;
Ok::<_, anyhow::Error>((created, websocket))
})?;

assert_no_serial_payload(&mut serial_master, Duration::from_millis(300))?;

std::fs::write(&gate_path, b"ready").context("failed to release power-on gate")?;
let payload = read_serial_master_payload(&mut serial_master, b"early-input")?;
assert_eq!(payload, b"early-input");

runtime.block_on(async {
websocket
.send(Message::Text(r#"{"type":"close"}"#.to_string().into()))
.await
.context("failed to send websocket close control message")?;
wait_for_closed(&mut websocket).await?;
wait_for_session_release(&client, &server.base_url, &created.session_id).await
})?;
server.shutdown()
}

fn assert_no_serial_payload(port: &mut TTYPort, duration: Duration) -> Result<()> {
let deadline = Instant::now() + duration;
let mut buffer = [0u8; 64];
loop {
match port.read(&mut buffer) {
Ok(read) if read > 0 => bail!(
"serial received payload before power-on completed: {:?}",
&buffer[..read]
),
Ok(_) => {}
Err(err) if err.kind() == std::io::ErrorKind::TimedOut => {}
Err(err) => return Err(err).context("failed to read PTY while checking early input"),
}
if Instant::now() >= deadline {
return Ok(());
}
}
}

fn read_serial_master_payload(port: &mut TTYPort, expected: &[u8]) -> Result<Vec<u8>> {
let deadline = Instant::now() + Duration::from_secs(2);
let mut payload = Vec::new();
let mut buffer = [0u8; 64];
while Instant::now() < deadline {
match port.read(&mut buffer) {
Ok(read) if read > 0 => {
payload.extend_from_slice(&buffer[..read]);
if payload.len() >= expected.len() {
return Ok(payload);
}
}
Ok(_) => {}
Err(err) if err.kind() == std::io::ErrorKind::TimedOut => {}
Err(err) => return Err(err).context("failed to read PTY serial payload"),
}
}
bail!(
"timed out waiting for PTY serial payload `{}`; got `{}`",
String::from_utf8_lossy(expected),
String::from_utf8_lossy(&payload)
)
}

fn run_ws_lifecycle_case(mode: ClientShutdownMode) -> Result<()> {
let temp = tempfile::tempdir().context("failed to create tempdir")?;
let (mut serial_master, mut serial_handle) =
Expand Down Expand Up @@ -463,3 +574,8 @@ fn graceful_ws_close_powers_off_and_releases_session() -> Result<()> {
fn abrupt_ws_drop_powers_off_and_releases_session() -> Result<()> {
run_ws_lifecycle_case(ClientShutdownMode::AbruptDrop)
}

#[test]
fn websocket_buffers_client_serial_input_until_power_on_finishes() -> Result<()> {
run_delayed_client_write_case()
}
Loading