Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
target
7 changes: 7 additions & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod matrix_layout;
pub mod matrix;
8 changes: 6 additions & 2 deletions rust/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
use NeuralNetwork::matrix;
use NeuralNetwork::matrix_layout::{ColumnMajor, RowMajor};
Comment thread
SanaeProject marked this conversation as resolved.
Comment thread
SanaeProject marked this conversation as resolved.
Comment thread
SanaeProject marked this conversation as resolved.

fn main() {
println!("Hello, world!");
}
let mtx = matrix::Matrix::<i32, RowMajor>::with_size(3, 4);
println!("{}", mtx);
}
55 changes: 55 additions & 0 deletions rust/src/matrix.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use core::fmt;
use crate::matrix_layout::{ RowMajor, MatrixLayout};

pub struct Matrix<T, L: MatrixLayout = RowMajor> {
data: Vec<T>,
row: usize,
col: usize,
_marker: std::marker::PhantomData<L>,
}

impl<T, L: MatrixLayout> Matrix<T, L> {
pub fn with_size(row: usize, col: usize) -> Self
where
T: Default + Clone,
{
Self {
data: vec![T::default(); row * col],
Comment thread
SanaeProject marked this conversation as resolved.
Comment thread
SanaeProject marked this conversation as resolved.
Comment thread
SanaeProject marked this conversation as resolved.
Comment thread
SanaeProject marked this conversation as resolved.
row,
col,
_marker: std::marker::PhantomData,
}
}

pub fn get(&self, row: usize, col: usize) -> Option<&T> {
if row >= self.row || col >= self.col { return None; }

let idx = L::get_index(row, col, self.row, self.col)?;
self.data.get(idx)
Comment thread
SanaeProject marked this conversation as resolved.
}
Comment on lines +24 to +29

pub fn get_row(&self, row: usize) -> Option<impl Iterator<Item = &T>> {
let (start, step) = L::row_stride(row, self.row, self.col)?;
Some(self.data.iter().skip(start).step_by(step).take(self.col))
}

pub fn get_column(&self, col: usize) -> Option<impl Iterator<Item = &T>> {
let (start, step) = L::col_stride(col, self.row, self.col)?;
Some(self.data.iter().skip(start).step_by(step).take(self.row))
}
}

impl<T: std::fmt::Display, L: MatrixLayout> std::fmt::Display for Matrix<T, L> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(0..self.row).for_each(|r| {
if let Some(row_iter) = self.get_row(r) {
row_iter.for_each(|c| {
_ = write!(f, "{}\t", c);
});
}
_ = write!(f, "\n");
});
Comment thread
SanaeProject marked this conversation as resolved.
Comment thread
SanaeProject marked this conversation as resolved.
Comment thread
SanaeProject marked this conversation as resolved.

Ok(())
Comment thread
SanaeProject marked this conversation as resolved.
}
}
75 changes: 75 additions & 0 deletions rust/src/matrix_layout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
pub trait MatrixLayout {
/// row, colから一次元配列のindexを取得する
/// # 引数
/// - `row`: 行番号
/// - `col`: 列番号
/// - `matrix_row`: 行数
/// - `matrix_col`: 列数
/// # 戻り値
/// - 一次元配列のindex
fn get_index(row: usize, col: usize, matrix_row: usize, matrix_col: usize) -> Option<usize>;

/// row, colから行・列のstrideを取得する
/// # 引数
/// - `row`: 行番号
/// - `matrix_row`: 行数
/// - `matrix_col`: 列数
/// # 戻り値
/// - (開始位置, ステップ)のタプル
fn row_stride(row: usize, matrix_row: usize, matrix_col: usize) -> Option<(usize, usize)>;

/// row, colから行・列のstrideを取得する
/// # 引数
/// - `col`: 列番号
/// - `matrix_row`: 行数
/// - `matrix_col`: 列数
/// # 戻り値
/// - (開始位置, ステップ)のタプル
fn col_stride(col: usize, matrix_row: usize, matrix_col: usize) -> Option<(usize, usize)>;
}

pub struct RowMajor;
impl MatrixLayout for RowMajor {
fn get_index(row: usize, col: usize, _m_row: usize, m_col: usize) -> Option<usize> {
if _m_row == 0 || m_col == 0 || row >= _m_row || col >= m_col {
return None;
}

Some(row * m_col + col)
}
fn row_stride(row: usize, _m_row: usize, m_col: usize) -> Option<(usize, usize)> {
if _m_row == 0 || m_col == 0 || row >= _m_row {
return None;
}

Some((row * m_col, 1))
}
fn col_stride(col: usize, _m_row: usize, m_col: usize) -> Option<(usize, usize)> {
if _m_row == 0 || m_col == 0 || col >= m_col {
return None;
}

Some((col, m_col))
}
}
pub struct ColumnMajor;
impl MatrixLayout for ColumnMajor {
fn get_index(row: usize, col: usize, m_row: usize, _m_col: usize) -> Option<usize> {
if m_row == 0 || _m_col == 0 || row >= m_row || col >= _m_col {
return None;
}
Some(col * m_row + row)
}
fn row_stride(row: usize, m_row: usize, _m_col: usize) -> Option<(usize, usize)> {
if m_row == 0 || _m_col == 0 || row >= m_row {
return None;
}
Some((row, m_row))
}
fn col_stride(col: usize, m_row: usize, _m_col: usize) -> Option<(usize, usize)> {
if m_row == 0 || _m_col == 0 || col >= _m_col {
return None;
}
Some((col * m_row, 1))
}
}