diff --git a/ostool-server/src/serial/ws.rs b/ostool-server/src/serial/ws.rs index f461dd18..04940b64 100644 --- a/ostool-server/src/serial/ws.rs +++ b/ostool-server/src/serial/ws.rs @@ -52,7 +52,7 @@ async fn run_serial_ws_inner( .as_ref() .ok_or_else(|| anyhow::anyhow!("board has no serial configuration"))?; let resolved_serial = resolve_serial_config(serial)?; - let port = tokio_serial::new(&resolved_serial.current_device_path, serial.baud_rate) + let mut port = tokio_serial::new(&resolved_serial.current_device_path, serial.baud_rate) .timeout(SERIAL_READ_TIMEOUT) .open_native_async() .with_context(|| { @@ -61,6 +61,7 @@ async fn run_serial_ws_inner( resolved_serial.current_device_path ) })?; + clear_serial_input_after_open(&session_id, &mut port); let (mut ws_sender, mut ws_receiver) = socket.split(); let (mut serial_rx, mut serial_tx) = tokio::io::split(port); @@ -107,48 +108,6 @@ async fn run_serial_ws_inner( break; } } - maybe_message = ws_receiver.next() => { - let Some(message) = maybe_message else { - break; - }; - match message { - Ok(Message::Binary(bytes)) => { - write_serial_payload(&mut serial_tx, &bytes).await?; - } - Ok(Message::Text(text)) => { - let control: ClientControlMessage = serde_json::from_str(&text)?; - match control.kind.as_str() { - "close" => { - let _ = ws_sender - .send(Message::Text(r#"{"type":"closed"}"#.to_string().into())) - .await; - break; - } - "tx" => { - let Some(data) = control.data.as_deref() else { - anyhow::bail!("missing tx data"); - }; - let payload = match control.encoding.as_deref() { - Some("base64") => base64::engine::general_purpose::STANDARD - .decode(data) - .context("invalid base64 payload")?, - Some("utf8") | None => data.as_bytes().to_vec(), - Some(other) => anyhow::bail!("unsupported encoding `{other}`"), - }; - write_serial_payload(&mut serial_tx, &payload).await?; - } - other => anyhow::bail!("unsupported websocket control type `{other}`"), - } - } - Ok(Message::Close(_)) => break, - Ok(Message::Ping(payload)) => { - ws_sender.send(Message::Pong(payload)).await.ok(); - } - Ok(Message::Pong(_)) => {} - Err(err) => return Err(err.into()), - } - let _ = session.heartbeat().await; - } read = serial_rx.read(&mut serial_buffer) => { let read = read.context("serial read failed")?; if read == 0 { @@ -301,6 +260,25 @@ where let _ = ws_sender.send(Message::Close(None)).await; } +trait SerialOpenCleanup { + fn clear_input_buffer(&mut self) -> std::io::Result<()>; +} + +impl SerialOpenCleanup for tokio_serial::SerialStream { + fn clear_input_buffer(&mut self) -> std::io::Result<()> { + self.clear(ClearBuffer::Input).map_err(std::io::Error::from) + } +} + +fn clear_serial_input_after_open(session_id: &str, port: &mut T) +where + T: SerialOpenCleanup + ?Sized, +{ + if let Err(err) = port.clear_input_buffer() { + log::warn!("session `{session_id}` failed to clear serial input after open: {err}"); + } +} + async fn write_serial_payload( port: &mut tokio::io::WriteHalf, payload: &[u8], @@ -368,9 +346,10 @@ mod tests { use tempfile::tempdir; use super::{ - ClientControlMessage, SerialQueueCleanup, cleanup_power_link, - cleanup_serial_queue_before_close, finalize_power_linked_session, - preserve_result_after_serial_cleanup, send_power_on_failure_and_close, + ClientControlMessage, SerialOpenCleanup, SerialQueueCleanup, cleanup_power_link, + cleanup_serial_queue_before_close, clear_serial_input_after_open, + finalize_power_linked_session, preserve_result_after_serial_cleanup, + send_power_on_failure_and_close, }; use crate::{ build_app_state, @@ -389,15 +368,31 @@ mod tests { #[derive(Debug, Clone, PartialEq, Eq)] enum CleanupEvent { + ClearInput, Flush, ClearAll, } + struct RecordingSerialOpenCleanup { + events: Arc>>, + clear_result: io::Result<()>, + } + struct RecordingSerialCleanup { events: Arc>>, clear_result: io::Result<()>, } + impl SerialOpenCleanup for RecordingSerialOpenCleanup { + fn clear_input_buffer(&mut self) -> io::Result<()> { + self.events.lock().unwrap().push(CleanupEvent::ClearInput); + self.clear_result + .as_ref() + .map(|_| ()) + .map_err(|err| io::Error::new(err.kind(), err.to_string())) + } + } + #[async_trait::async_trait] impl SerialQueueCleanup for RecordingSerialCleanup { async fn flush_output(&mut self) -> io::Result<()> { @@ -450,6 +445,38 @@ mod tests { assert_eq!(message.kind, "close"); } + #[test] + fn serial_open_cleanup_clears_only_input_buffer() { + let events = Arc::new(Mutex::new(Vec::new())); + let mut cleanup = RecordingSerialOpenCleanup { + events: events.clone(), + clear_result: Ok(()), + }; + + clear_serial_input_after_open("session-1", &mut cleanup); + + assert_eq!( + events.lock().unwrap().as_slice(), + &[CleanupEvent::ClearInput] + ); + } + + #[test] + fn serial_open_cleanup_does_not_fail_session_on_clear_error() { + let events = Arc::new(Mutex::new(Vec::new())); + let mut cleanup = RecordingSerialOpenCleanup { + events: events.clone(), + clear_result: Err(io::Error::other("clear failed")), + }; + + clear_serial_input_after_open("session-1", &mut cleanup); + + assert_eq!( + events.lock().unwrap().as_slice(), + &[CleanupEvent::ClearInput] + ); + } + #[tokio::test] async fn serial_cleanup_flushes_before_clearing_all_buffers() { let events = Arc::new(Mutex::new(Vec::new())); @@ -473,7 +500,7 @@ mod tests { let events = Arc::new(Mutex::new(Vec::new())); let mut cleanup = RecordingSerialCleanup { events: events.clone(), - clear_result: Err(io::Error::new(io::ErrorKind::Other, "clear failed")), + clear_result: Err(io::Error::other("clear failed")), }; let err = cleanup_serial_queue_before_close(&mut cleanup) @@ -492,7 +519,7 @@ mod tests { let events = Arc::new(Mutex::new(Vec::new())); let mut cleanup = RecordingSerialCleanup { events, - clear_result: Err(io::Error::new(io::ErrorKind::Other, "clear failed")), + clear_result: Err(io::Error::other("clear failed")), }; let err = preserve_result_after_serial_cleanup::<(), _>( diff --git a/ostool-server/src/session.rs b/ostool-server/src/session.rs index ec3d3a8f..e07287a9 100644 --- a/ostool-server/src/session.rs +++ b/ostool-server/src/session.rs @@ -9,7 +9,7 @@ use tokio::sync::{RwLock, mpsc, watch}; use crate::{config::BoardConfig, state::AppState}; -pub const SESSION_TTL: Duration = Duration::seconds(2); +pub const SESSION_TTL: Duration = Duration::seconds(10); const SESSION_STATE_ACTIVE: u8 = 0; const SESSION_STATE_RELEASING: u8 = 1; @@ -305,6 +305,7 @@ mod tests { fn session_new_uses_fixed_ttl() { let session = Session::new("demo".into(), Some("client".into())); assert_eq!(session.expires_at - session.created_at, SESSION_TTL); + assert!(SESSION_TTL >= chrono::Duration::seconds(10)); assert_eq!(session.last_heartbeat_at, session.created_at); assert_eq!(session.state, SessionLifecycleState::Active); } diff --git a/ostool-server/tests/session_ws_lifecycle.rs b/ostool-server/tests/session_ws_lifecycle.rs index 721a4a33..03f0f406 100644 --- a/ostool-server/tests/session_ws_lifecycle.rs +++ b/ostool-server/tests/session_ws_lifecycle.rs @@ -2,7 +2,7 @@ #![cfg(unix)] use std::{ - io::Write, + io::{Read, Write}, net::SocketAddr, path::Path, sync::{Arc, mpsc}, @@ -60,7 +60,7 @@ struct SessionCreatedResponse { ws_url: Option, } -fn sample_board(serial_port: String) -> BoardConfig { +fn sample_board_with_power_on(serial_port: String, power_on_cmd: String) -> BoardConfig { BoardConfig { id: TEST_BOARD_ID.into(), board_type: TEST_BOARD_TYPE.into(), @@ -75,7 +75,7 @@ fn sample_board(serial_port: String) -> BoardConfig { resolved_usb_path: None, }), power_management: PowerManagementConfig::Custom(CustomPowerManagement { - power_on_cmd: "true".into(), + power_on_cmd, power_off_cmd: "true".into(), }), boot: BootConfig::Uboot(UbootProfile { @@ -90,6 +90,14 @@ fn sample_board(serial_port: String) -> BoardConfig { /// Starts an in-process ostool-server with one board and PTY serial port. fn spawn_test_server(root: &Path, serial_port: String) -> Result { + spawn_test_server_with_power_on(root, serial_port, "true".into()) +} + +fn spawn_test_server_with_power_on( + root: &Path, + serial_port: String, + power_on_cmd: String, +) -> Result { let config_path = root.join("config.toml"); let data_dir = root.join("data"); let board_dir = root.join("boards"); @@ -117,7 +125,7 @@ fn spawn_test_server(root: &Path, serial_port: String) -> Result Result Result<()> { + let temp = tempfile::tempdir().context("failed to create tempdir")?; + let gate_path = temp.path().join("power-on-ready"); + let power_on_cmd = format!( + "while [ ! -f '{}' ]; do sleep 0.05; done", + gate_path.display() + ); + let (mut serial_master, mut serial_handle) = + TTYPort::pair().context("failed to create PTY pair")?; + serial_handle + .set_exclusive(false) + .context("failed to disable PTY exclusivity")?; + serial_master + .set_timeout(POLL_INTERVAL) + .context("failed to configure PTY timeout")?; + let serial_port = serial_handle.name().context("failed to get PTY path")?; + drop(serial_handle); + + let server = spawn_test_server_with_power_on(temp.path(), serial_port, power_on_cmd)?; + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .context("failed to build client runtime")?; + + let client = reqwest::Client::new(); + let (created, mut websocket) = runtime.block_on(async { + wait_for_server_ready(&client, &server.base_url).await?; + let created = create_session(&client, &server.base_url).await?; + let ws_url = resolve_ws_url( + &server.base_url, + created.ws_url.as_deref().context("missing websocket URL")?, + )?; + let (mut websocket, _) = tokio_tungstenite::connect_async(ws_url.as_str()) + .await + .with_context(|| format!("failed to connect websocket {ws_url}"))?; + wait_for_opened(&mut websocket).await?; + websocket + .send(Message::Binary(b"early-input".to_vec().into())) + .await + .context("failed to send early websocket input")?; + Ok::<_, anyhow::Error>((created, websocket)) + })?; + + assert_no_serial_payload(&mut serial_master, Duration::from_millis(300))?; + + std::fs::write(&gate_path, b"ready").context("failed to release power-on gate")?; + let payload = read_serial_master_payload(&mut serial_master, b"early-input")?; + assert_eq!(payload, b"early-input"); + + runtime.block_on(async { + websocket + .send(Message::Text(r#"{"type":"close"}"#.to_string().into())) + .await + .context("failed to send websocket close control message")?; + wait_for_closed(&mut websocket).await?; + wait_for_session_release(&client, &server.base_url, &created.session_id).await + })?; + server.shutdown() +} + +fn assert_no_serial_payload(port: &mut TTYPort, duration: Duration) -> Result<()> { + let deadline = Instant::now() + duration; + let mut buffer = [0u8; 64]; + loop { + match port.read(&mut buffer) { + Ok(read) if read > 0 => bail!( + "serial received payload before power-on completed: {:?}", + &buffer[..read] + ), + Ok(_) => {} + Err(err) if err.kind() == std::io::ErrorKind::TimedOut => {} + Err(err) => return Err(err).context("failed to read PTY while checking early input"), + } + if Instant::now() >= deadline { + return Ok(()); + } + } +} + +fn read_serial_master_payload(port: &mut TTYPort, expected: &[u8]) -> Result> { + let deadline = Instant::now() + Duration::from_secs(2); + let mut payload = Vec::new(); + let mut buffer = [0u8; 64]; + while Instant::now() < deadline { + match port.read(&mut buffer) { + Ok(read) if read > 0 => { + payload.extend_from_slice(&buffer[..read]); + if payload.len() >= expected.len() { + return Ok(payload); + } + } + Ok(_) => {} + Err(err) if err.kind() == std::io::ErrorKind::TimedOut => {} + Err(err) => return Err(err).context("failed to read PTY serial payload"), + } + } + bail!( + "timed out waiting for PTY serial payload `{}`; got `{}`", + String::from_utf8_lossy(expected), + String::from_utf8_lossy(&payload) + ) +} + fn run_ws_lifecycle_case(mode: ClientShutdownMode) -> Result<()> { let temp = tempfile::tempdir().context("failed to create tempdir")?; let (mut serial_master, mut serial_handle) = @@ -463,3 +574,8 @@ fn graceful_ws_close_powers_off_and_releases_session() -> Result<()> { fn abrupt_ws_drop_powers_off_and_releases_session() -> Result<()> { run_ws_lifecycle_case(ClientShutdownMode::AbruptDrop) } + +#[test] +fn websocket_buffers_client_serial_input_until_power_on_finishes() -> Result<()> { + run_delayed_client_write_case() +} diff --git a/uboot-shell/src/lib.rs b/uboot-shell/src/lib.rs index 4d816067..27181000 100644 --- a/uboot-shell/src/lib.rs +++ b/uboot-shell/src/lib.rs @@ -5,7 +5,7 @@ extern crate log; use std::{ io::{Error, ErrorKind, Result, stdout}, - path::PathBuf, + path::{Path, PathBuf}, pin::Pin, task::{Context, Poll}, time::Duration, @@ -34,6 +34,8 @@ macro_rules! dbg { const CTRL_C: u8 = 0x03; const INT_STR: &str = ""; const INT: &[u8] = INT_STR.as_bytes(); +const LOADY_MAX_ATTEMPTS: usize = 3; +const LOADY_RETRY_DELAY: Duration = Duration::from_millis(300); type Tx = Box; type Rx = Box; @@ -279,18 +281,52 @@ impl UbootShell { file: impl Into, on_progress: impl Fn(usize, usize), ) -> Result { + let file = file.into(); + + for attempt in 1..=LOADY_MAX_ATTEMPTS { + match self.loady_once(addr, &file, &on_progress).await { + Ok(reply) => return Ok(reply), + Err(err) if attempt < LOADY_MAX_ATTEMPTS => { + warn!( + "loady attempt {attempt}/{LOADY_MAX_ATTEMPTS} failed: {err}; retrying..." + ); + self.wait_for_shell().await.map_err(|recover_err| { + Error::other(format!( + "loady attempt {attempt} failed and shell recovery failed: {recover_err}", + )) + })?; + Delay::new(LOADY_RETRY_DELAY).await; + } + Err(err) => { + return Err(Error::other(format!( + "loady failed after {LOADY_MAX_ATTEMPTS} attempts: {err}" + ))); + } + } + } + + unreachable!("LOADY_MAX_ATTEMPTS must be greater than zero") + } + + async fn loady_once( + &mut self, + addr: usize, + file: &Path, + on_progress: &impl Fn(usize, usize), + ) -> Result { + self.clear_shell().await?; self.cmd_without_reply(&format!("loady {addr:#x}")).await?; let crc = self.wait_for_load_crc().await?; let mut protocol = ymodem::Ymodem::new(crc); - let file = file.into(); let name = file .file_name() .and_then(|name| name.to_str()) .ok_or_else(|| Error::new(ErrorKind::InvalidInput, "file name must be valid UTF-8"))?; - let size = std::fs::metadata(&file)?.len() as usize; - let mut file = AllowStdIo::new(std::fs::File::open(&file)?); + let size = std::fs::metadata(file)?.len() as usize; + let mut file = AllowStdIo::new(std::fs::File::open(file)?); + on_progress(0, size); protocol .send(self, &mut file, name, size, |sent| on_progress(sent, size)) .await?; @@ -386,3 +422,160 @@ fn print_raw_win(buff: &[u8]) { g.clear(); } } + +#[cfg(test)] +mod tests { + use super::*; + use std::{ + collections::VecDeque, + fs, + sync::{Arc, Mutex}, + }; + + #[derive(Default)] + struct LoadyScript { + reads: VecDeque, + writes: Vec, + command: Vec, + loady_count: usize, + interrupted: bool, + accepting_commands: bool, + } + + impl LoadyScript { + fn queue_read(&mut self, bytes: impl AsRef<[u8]>) { + self.reads.extend(bytes.as_ref()); + } + + fn handle_write(&mut self, bytes: &[u8]) { + self.writes.extend_from_slice(bytes); + + if bytes == [CTRL_C] { + self.command.clear(); + self.accepting_commands = true; + if !self.interrupted { + self.interrupted = true; + self.queue_read(b"=> \n"); + } + return; + } + + if !self.accepting_commands { + return; + } + + for &byte in bytes { + self.command.push(byte); + if byte == b'\n' { + let command = std::mem::take(&mut self.command); + if command.starts_with(b"loady ") { + self.loady_count += 1; + self.accepting_commands = false; + self.queue_loady_response(); + } + } else if self.command.len() > 256 { + self.command.clear(); + } + } + } + + fn queue_loady_response(&mut self) { + match self.loady_count { + 1 => { + self.queue_read(*b"C"); + self.queue_read([ymodem::CRC; ymodem::DEFAULT_BLOCK_RETRIES]); + } + 2 => { + self.queue_read(*b"C"); + self.queue_read([ymodem::ACK, ymodem::ACK, ymodem::ACK, ymodem::ACK, b'C']); + self.queue_read(b"done\n=> "); + } + _ => {} + } + } + } + + #[derive(Clone)] + struct ScriptedTx { + script: Arc>, + } + + #[derive(Clone)] + struct ScriptedRx { + script: Arc>, + } + + impl AsyncWrite for ScriptedTx { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.script.lock().unwrap().handle_write(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl AsyncRead for ScriptedRx { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let mut script = self.script.lock().unwrap(); + if script.reads.is_empty() { + return Poll::Pending; + } + + let n = buf.len().min(script.reads.len()); + for slot in &mut buf[..n] { + *slot = script.reads.pop_front().unwrap(); + } + Poll::Ready(Ok(n)) + } + } + + #[tokio::test] + async fn loady_restarts_transfer_after_receiver_rejects_first_attempt() -> Result<()> { + let script = Arc::new(Mutex::new(LoadyScript::default())); + script.lock().unwrap().accepting_commands = true; + let mut shell = UbootShell { + tx: Some(Box::new(ScriptedTx { + script: script.clone(), + })), + rx: Some(Box::new(ScriptedRx { + script: script.clone(), + })), + perfix: "=> ".to_string(), + }; + + let file = + std::env::temp_dir().join(format!("uboot-shell-loady-retry-{}", std::process::id())); + fs::write(&file, b"payload")?; + + let progress = Arc::new(Mutex::new(Vec::new())); + let reply = shell + .loady(0x80200000, file.clone(), { + let progress = progress.clone(); + move |sent, size| progress.lock().unwrap().push((sent, size)) + }) + .await; + let _ = fs::remove_file(&file); + + assert!(reply?.contains("done")); + let script = script.lock().unwrap(); + let writes = String::from_utf8_lossy(&script.writes); + assert_eq!(writes.matches("loady 0x80200000").count(), 2); + assert!(script.writes.contains(&CTRL_C)); + assert_eq!(*progress.lock().unwrap(), vec![(0, 7), (0, 7), (7, 7)]); + Ok(()) + } +} diff --git a/uboot-shell/src/ymodem.rs b/uboot-shell/src/ymodem.rs index 1b05f371..b1ff0cf6 100644 --- a/uboot-shell/src/ymodem.rs +++ b/uboot-shell/src/ymodem.rs @@ -12,15 +12,16 @@ use crate::crc::crc16_ccitt; const SOH: u8 = 0x01; const STX: u8 = 0x02; const EOT: u8 = 0x04; -const ACK: u8 = 0x06; -const NAK: u8 = 0x15; +pub(crate) const ACK: u8 = 0x06; +pub(crate) const NAK: u8 = 0x15; const EOF: u8 = 0x1A; -const CRC: u8 = 0x43; +pub(crate) const CRC: u8 = 0x43; +pub(crate) const DEFAULT_BLOCK_RETRIES: usize = 10; pub struct Ymodem { crc_mode: bool, blk: u8, - retries: usize, + max_block_retries: usize, } impl Ymodem { @@ -28,7 +29,7 @@ impl Ymodem { Self { crc_mode, blk: 0, - retries: 10, + max_block_retries: DEFAULT_BLOCK_RETRIES, } } @@ -141,9 +142,10 @@ impl Ymodem { }; let blk = if last { 0 } else { self.blk }; let mut err = None; + let mut retries = self.max_block_retries; loop { - if self.retries == 0 { + if retries == 0 { return Err(err.unwrap_or(Error::new(ErrorKind::BrokenPipe, "retry too much"))); } @@ -165,7 +167,7 @@ impl Ymodem { Ok(_) => break, Err(e) => { err = Some(e); - self.retries -= 1; + retries -= 1; } } } @@ -179,3 +181,82 @@ impl Ymodem { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use std::{ + collections::VecDeque, + pin::Pin, + sync::Mutex, + task::{Context, Poll}, + }; + + use futures::io::Cursor; + + struct ScriptedDevice { + reads: VecDeque, + writes: Vec, + } + + impl AsyncRead for ScriptedDevice { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + if self.reads.is_empty() { + return Poll::Ready(Ok(0)); + } + + let n = buf.len().min(self.reads.len()); + for slot in &mut buf[..n] { + *slot = self.reads.pop_front().unwrap(); + } + Poll::Ready(Ok(n)) + } + } + + impl AsyncWrite for ScriptedDevice { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.writes.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + } + + #[tokio::test] + async fn acked_block_resets_retry_budget() -> Result<()> { + let mut reads = VecDeque::from([CRC, ACK]); + reads.extend(std::iter::repeat_n(CRC, DEFAULT_BLOCK_RETRIES - 1)); + reads.extend([ACK, ACK, ACK, CRC]); + + let mut dev = ScriptedDevice { + reads, + writes: Vec::new(), + }; + let mut file = Cursor::new(b"payload".to_vec()); + let progress = Mutex::new(Vec::new()); + + Ymodem::new(true) + .send(&mut dev, &mut file, "kernel", 7, |sent| { + progress.lock().unwrap().push(sent); + }) + .await?; + + assert_eq!(*progress.lock().unwrap(), vec![7]); + assert!(dev.writes.contains(&EOT)); + Ok(()) + } +}