diff --git a/src/cli/sort_by.rs b/src/cli/sort_by.rs index 3086333..f37effa 100644 --- a/src/cli/sort_by.rs +++ b/src/cli/sort_by.rs @@ -1,5 +1,7 @@ use clap::ValueEnum; +use crate::models::Currency; + #[derive(Debug, ValueEnum, Clone)] pub enum SortBy { Currency, @@ -7,9 +9,9 @@ pub enum SortBy { } impl SortBy { - pub fn get_comparer(&self) -> fn(&(&str, f64), &(&str, f64)) -> std::cmp::Ordering { + pub fn get_comparer(&self) -> fn(&(&Currency, f64), &(&Currency, f64)) -> std::cmp::Ordering { match self { - Self::Currency => |a, b| a.0.cmp(&b.0), + Self::Currency => |a, b| a.0.as_ref().cmp(b.0.as_ref()), Self::Rate => |a, b| a.1.total_cmp(&b.1), } } diff --git a/src/main.rs b/src/main.rs index 573c9af..6ace144 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,9 +3,10 @@ use ecb_rates::cache::{Cache, CacheLine}; use ecb_rates::HeaderDescription; use reqwest::{Client, IntoUrl}; use std::process::ExitCode; +use std::str::FromStr; use ecb_rates::cli::{Cli, FormatOption}; -use ecb_rates::models::ExchangeRateResult; +use ecb_rates::models::{Currency, ExchangeRateResult}; use ecb_rates::parsing::parse; use ecb_rates::table::{TableRef, TableTrait as _}; use ecb_rates::utils_calc::{change_perspective, filter_currencies, invert_rates, round}; @@ -79,8 +80,18 @@ async fn main() -> ExitCode { }; cli.perspective = cli.perspective.map(|s| s.to_uppercase()); - if let Some(currency) = cli.perspective.as_ref() { - header_description.replace_eur(¤cy); + let parsed_currency = match cli.perspective.as_ref() { + Some(currency) => match Currency::from_str(currency) { + Ok(k) => Some(k), + Err(e) => { + eprintln!("The currency code '{}' is invalid: {:?}", currency, e); + return ExitCode::FAILURE; + } + }, + None => None, + }; + if let Some(currency) = parsed_currency.as_ref() { + header_description.replace_eur(currency.as_ref()); let error_occured = change_perspective(&mut parsed, ¤cy).is_none(); if error_occured { eprintln!("The currency wasn't in the data from the ECB!"); @@ -96,11 +107,19 @@ async fn main() -> ExitCode { round(&mut parsed, cli.max_decimals); if !cli.currencies.is_empty() { - let currencies = cli + let currencies = match cli .currencies .iter() .map(|x| x.to_uppercase()) - .collect::>(); + .map(|x| Currency::from_str(&x)) + .collect::>>() + { + Ok(k) => k, + Err(e) => { + eprintln!("Failed to parse currenc(y/ies): {:?}", e); + return ExitCode::FAILURE; + } + }; filter_currencies(&mut parsed, ¤cies); } diff --git a/src/models/currency.rs b/src/models/currency.rs new file mode 100644 index 0000000..d739494 --- /dev/null +++ b/src/models/currency.rs @@ -0,0 +1,97 @@ +use std::{ + fmt, + ops::Index, + slice::Iter, + str::{self, FromStr}, +}; + +use serde::{de, Deserialize, Serialize}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Currency { + name: [u8; 3], +} + +impl Currency { + pub fn as_str(&self) -> &str { + // SAFETY: We validate that bytes are ASCII in FromStr. + unsafe { str::from_utf8_unchecked(&self.name) } + } + + pub fn iter(&self) -> Iter<'_, u8> { + self.name.iter() + } +} + +impl AsRef for Currency { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl Serialize for Currency { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(self.as_str()) + } +} + +impl<'de> Deserialize<'de> for Currency { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + FromStr::from_str(&s).map_err(de::Error::custom) + } +} + +impl FromStr for Currency { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + if s.len() != 3 { + anyhow::bail!("Currency code must be exactly 3 chars"); + } + if !s.is_ascii() { + anyhow::bail!("Currency code must be ASCII"); + } + + let b = s.as_bytes(); + Ok(Self { + name: [b[0], b[1], b[2]], + }) + } +} + +impl TryFrom<&str> for Currency { + type Error = anyhow::Error; + fn try_from(value: &str) -> Result { + Currency::from_str(value) + } +} + +impl fmt::Display for Currency { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl Index for Currency { + type Output = u8; + + fn index(&self, index: usize) -> &Self::Output { + &self.name[index] + } +} + +impl<'a> IntoIterator for &'a Currency { + type Item = &'a u8; + type IntoIter = Iter<'a, u8>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} diff --git a/src/models.rs b/src/models/exchange_rate_result.rs similarity index 75% rename from src/models.rs rename to src/models/exchange_rate_result.rs index 2b0dedc..d8f9f10 100644 --- a/src/models.rs +++ b/src/models/exchange_rate_result.rs @@ -1,8 +1,10 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use super::Currency; + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct ExchangeRateResult { pub time: String, - pub rates: HashMap, + pub rates: HashMap, } diff --git a/src/models/mod.rs b/src/models/mod.rs new file mode 100644 index 0000000..f4c8406 --- /dev/null +++ b/src/models/mod.rs @@ -0,0 +1,5 @@ +mod currency; +mod exchange_rate_result; + +pub use currency::Currency; +pub use exchange_rate_result::ExchangeRateResult; diff --git a/src/parsing.rs b/src/parsing.rs index daa6cac..e94ef6a 100644 --- a/src/parsing.rs +++ b/src/parsing.rs @@ -1,9 +1,9 @@ -use std::collections::HashMap; +use std::{collections::HashMap, str::FromStr}; use quick_xml::events::Event; use quick_xml::Reader; -use crate::models::ExchangeRateResult; +use crate::models::{Currency, ExchangeRateResult}; pub fn parse(xml: &str) -> anyhow::Result> { let mut reader = Reader::from_str(xml); @@ -18,7 +18,7 @@ pub fn parse(xml: &str) -> anyhow::Result> { e: &quick_xml::events::BytesStart, current_time: &mut Option, inside_cube_time: &mut bool, - current_rates: &mut HashMap, + current_rates: &mut HashMap, results: &mut Vec, ) -> anyhow::Result<()> { if e.name().local_name().as_ref() != b"Cube" { @@ -26,7 +26,7 @@ pub fn parse(xml: &str) -> anyhow::Result> { } let mut time_attr: Option = None; - let mut currency_attr: Option = None; + let mut currency_attr: Option = None; let mut rate_attr: Option = None; for attr_result in e.attributes() { @@ -39,7 +39,7 @@ pub fn parse(xml: &str) -> anyhow::Result> { time_attr = Some(val); } b"currency" => { - currency_attr = Some(val); + currency_attr = Some(Currency::from_str(&val)?); } b"rate" => { rate_attr = Some(val); diff --git a/src/table/table_owned.rs b/src/table/table_owned.rs index f4f143a..c0ca533 100644 --- a/src/table/table_owned.rs +++ b/src/table/table_owned.rs @@ -1,7 +1,7 @@ use std::fmt::Display; use crate::cli::SortBy; -use crate::models::ExchangeRateResult; +use crate::models::{Currency, ExchangeRateResult}; use crate::DEFAULT_WIDTH; use super::table_display::helper_table_print; @@ -11,7 +11,7 @@ pub struct Table { pub(super) header: Option, pub(super) column_left: String, pub(super) column_right: String, - pub(super) rows: Vec<(String, f64)>, + pub(super) rows: Vec<(Currency, f64)>, pub color: bool, pub width: usize, pub left_offset: usize, @@ -21,7 +21,7 @@ impl<'a> TableTrait<'a> for Table { type Header = String; type ColumnLeft = String; type ColumnRight = String; - type RowLeft = String; + type RowLeft = Currency; fn new( header: Option, @@ -59,7 +59,7 @@ impl<'a> TableTrait<'a> for Table { } impl TableGet for Table { - type RowLeftRef = String; + type RowLeftRef = Currency; type RowRightRef = String; fn get_header(&self) -> Option<&str> { @@ -77,7 +77,7 @@ impl TableGet for Table { fn get_width(&self) -> usize { self.width } - + fn get_left_offset(&self) -> usize { self.left_offset } diff --git a/src/table/table_ref.rs b/src/table/table_ref.rs index b7b13ac..3c888e9 100644 --- a/src/table/table_ref.rs +++ b/src/table/table_ref.rs @@ -1,7 +1,7 @@ use std::fmt::Display; use crate::cli::SortBy; -use crate::models::ExchangeRateResult; +use crate::models::{Currency, ExchangeRateResult}; use crate::DEFAULT_WIDTH; use super::table_display::helper_table_print; @@ -13,7 +13,7 @@ pub struct TableRef<'a> { header: Option<&'a str>, column_left: &'a str, column_right: &'a str, - rows: Vec<(&'a str, f64)>, + rows: Vec<(&'a Currency, f64)>, pub color: bool, pub width: usize, pub left_offset: usize, @@ -23,7 +23,7 @@ impl<'a> TableTrait<'a> for TableRef<'a> { type Header = &'a str; type ColumnLeft = &'a str; type ColumnRight = &'a str; - type RowLeft = &'a str; + type RowLeft = &'a Currency; fn new( header: Option, @@ -60,7 +60,7 @@ impl<'a> TableTrait<'a> for TableRef<'a> { } impl<'a> TableGet for TableRef<'a> { - type RowLeftRef = &'a str; + type RowLeftRef = &'a Currency; type RowRightRef = &'a str; fn get_header(&self) -> Option<&str> { @@ -78,7 +78,7 @@ impl<'a> TableGet for TableRef<'a> { fn get_width(&self) -> usize { self.width } - + fn get_left_offset(&self) -> usize { self.left_offset } @@ -106,7 +106,7 @@ impl<'a> From<&'a Table> for TableRef<'a> { let rows = table .rows .iter() - .map(|(left, right)| (left.as_str(), *right)) + .map(|(left, right)| (left, *right)) .collect(); TableRef { diff --git a/src/utils_calc.rs b/src/utils_calc.rs index 81b9e3e..9365bd6 100644 --- a/src/utils_calc.rs +++ b/src/utils_calc.rs @@ -1,10 +1,13 @@ -use std::{borrow::BorrowMut, collections::HashMap, ops::Deref}; +use std::{borrow::BorrowMut, collections::HashMap, ops::Deref, str::FromStr}; -use crate::models::ExchangeRateResult; +use crate::models::{Currency, ExchangeRateResult}; -pub fn filter_currencies(exchange_rate_results: &mut [ExchangeRateResult], currencies: &[String]) { +pub fn filter_currencies( + exchange_rate_results: &mut [ExchangeRateResult], + currencies: &[Currency], +) { for exchange_rate in exchange_rate_results { - let rates_ptr: *mut HashMap = &mut exchange_rate.rates; + let rates_ptr: *mut HashMap = &mut exchange_rate.rates; exchange_rate .rates .keys() @@ -22,7 +25,7 @@ pub fn filter_currencies(exchange_rate_results: &mut [ExchangeRateResult], curre pub fn change_perspective( exchange_rate_results: &mut [ExchangeRateResult], - currency: &str, + currency: &Currency, ) -> Option<()> { for rate_res in exchange_rate_results { let currency_rate = rate_res.rates.remove(currency)?; @@ -32,7 +35,10 @@ pub fn change_perspective( *iter_rate = eur_rate * iter_rate.deref(); } - rate_res.rates.insert("EUR".to_string(), eur_rate); + rate_res.rates.insert( + unsafe { Currency::from_str("EUR").unwrap_unchecked() }, + eur_rate, + ); } Some(()) }