Compare commits

..

8 Commits

Author SHA1 Message Date
787f74e3ec break out module 2024-07-25 21:16:42 +02:00
9b68b32354 prefer box 2024-07-25 21:07:31 +02:00
e5c9cb6024 prefer match 2024-07-25 21:03:43 +02:00
0cd2d364aa break out network_change_listener 2024-07-25 21:01:40 +02:00
f34c4609a5 exit listener 2024-07-25 20:46:30 +02:00
5466090256 remove main 2024-07-25 20:35:08 +02:00
b6f447e39f add check every so other 2024-07-25 20:32:52 +02:00
8e5d472018 tests and time support 2024-07-25 18:16:12 +02:00
12 changed files with 357 additions and 96 deletions

View File

@ -1,21 +1,26 @@
// SPDX: BSD-2-Clause
use crate::utils::duration_to_string;
use anyhow;
use dirs;
use log::warn;
use serde::{self, Deserialize, Serialize};
use std::env;
use std::path::{Path, PathBuf};
use std::time::Duration;
use tokio::{fs, io::AsyncReadExt};
use crate::PROGRAM_NAME;
#[derive(Deserialize, Serialize, Debug)]
#[derive(Debug, Deserialize, Serialize)]
pub struct Config {
pub zone_id: Box<str>,
pub api_key: Box<str>,
pub zone_id: String,
pub api_key: String,
pub domains: Vec<Box<str>>,
#[serde(default)]
pub max_errors_in_row: Option<usize>,
#[serde(with = "duration_format", default)]
pub max_duration: Option<Duration>,
}
pub async fn get_config_path() -> Result<PathBuf, Vec<PathBuf>> {
@ -67,3 +72,58 @@ pub async fn read_config<P: AsRef<Path>>(path: &P) -> anyhow::Result<Config> {
file.read_to_string(&mut buf).await?;
Ok(toml::from_str(&buf)?)
}
mod duration_format {
use super::*;
use serde::{self, Deserialize, Deserializer, Serializer};
use std::time::Duration;
pub fn serialize<S>(duration: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match duration.as_ref() {
Some(duration) => serializer.serialize_str(duration_to_string(duration).trim()),
None => serializer.serialize_none(),
}
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
D: Deserializer<'de>,
{
Option::<Box<str>>::deserialize(deserializer)?.map_or(Ok(None), |s| {
parse_duration(&s)
.map(Some)
.map_err(serde::de::Error::custom)
})
}
fn parse_duration(s: &str) -> Result<Duration, String> {
let mut total_duration = Duration::new(0, 0);
let units = [("d", 86400), ("h", 3600), ("m", 60), ("s", 1)];
let mut remainder = s;
for &(unit, factor) in &units {
if let Some(idx) = remainder.find(unit) {
let (value, rest) = remainder.split_at(idx);
let value: u64 = value.trim().parse().map_err(|_| "Invalid number")?;
total_duration += Duration::from_secs(value * factor);
remainder = &rest[unit.len()..];
}
}
if let Some(idx) = remainder.find("ns") {
let (value, rest) = remainder.split_at(idx);
let value: u32 = value.trim().parse().map_err(|_| "Invalid number")?;
total_duration += Duration::new(0, value);
remainder = &rest["ns".len()..];
}
if !remainder.trim().is_empty() {
return Err("Invalid duration format".to_string());
}
Ok(total_duration)
}
}

43
src/exit_listener.rs Normal file
View File

@ -0,0 +1,43 @@
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use tokio::{
self, signal,
sync::{futures::Notified, Notify},
};
pub struct ExitListener {
should_exit: Arc<AtomicBool>,
notify: Arc<Notify>,
}
impl ExitListener {
pub fn new() -> Self {
let this = Self {
should_exit: Arc::new(AtomicBool::new(false)),
notify: Arc::new(Notify::new()),
};
let should_exit = this.should_exit.clone();
let notify = this.notify.clone();
tokio::spawn(async move {
signal::ctrl_c()
.await
.expect("Failed to install CTRL+C signal handler");
should_exit.store(true, Ordering::SeqCst);
notify.notify_one();
});
this
}
pub fn notified(&self) -> Notified<'_> {
self.notify.notified()
}
pub fn should_exit(&self) -> bool {
self.should_exit.load(Ordering::SeqCst)
}
}

View File

@ -3,40 +3,14 @@
use anyhow::{Context, Result};
use log::{error, info};
use reqwest::Client;
use serde::{self, Deserialize, Serialize};
use std::{
collections::HashMap,
fmt,
net::{IpAddr, Ipv4Addr},
};
use super::cloudflare_responses::{CloudflareResponse, DnsRecord};
use crate::get_current_public_ipv4;
#[derive(Serialize, Deserialize, Clone, Debug)]
struct DnsRecord {
id: String,
#[serde(rename = "type")]
record_type: Box<str>,
name: Box<str>,
content: Box<str>,
ttl: u32,
proxied: bool,
locked: bool,
zone_id: Box<str>,
zone_name: Box<str>,
modified_on: Box<str>,
created_on: Box<str>,
meta: HashMap<Box<str>, serde_json::Value>,
}
#[derive(Serialize, Deserialize, Debug)]
struct CloudflareResponse {
success: bool,
errors: Vec<HashMap<String, serde_json::Value>>,
messages: Vec<HashMap<String, serde_json::Value>>,
result: Option<Vec<DnsRecord>>,
}
pub struct CloudflareClient<A, Z>
where
A: fmt::Display,

View File

@ -0,0 +1,27 @@
use std::collections::HashMap;
use serde;
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
pub struct DnsRecord {
pub id: Box<str>,
#[serde(rename = "type")]
pub record_type: Box<str>,
pub name: Box<str>,
pub content: Box<str>,
pub ttl: u32,
pub proxied: bool,
pub locked: bool,
pub zone_id: Box<str>,
pub zone_name: Box<str>,
pub modified_on: Box<str>,
pub created_on: Box<str>,
pub meta: HashMap<Box<str>, serde_json::Value>,
}
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct CloudflareResponse {
pub success: bool,
pub errors: Box<[HashMap<Box<str>, serde_json::Value>]>,
pub messages: Box<[HashMap<Box<str>, serde_json::Value>]>,
pub result: Option<Vec<DnsRecord>>,
}

3
src/internet/mod.rs Normal file
View File

@ -0,0 +1,3 @@
mod cloudflare;
pub mod cloudflare_responses;
pub use cloudflare::CloudflareClient;

View File

@ -1,17 +1,23 @@
// SPDX: BSD-2-Clause
mod cloudflare;
mod config;
mod exit_listener;
mod logging;
mod public_ip;
mod message_handler;
mod network_change_listener;
mod public_ip;
pub mod utils;
pub use cloudflare::CloudflareClient;
mod tests;
mod internet;
pub use internet::CloudflareClient;
pub use config::{get_config_path, read_config, Config};
pub use exit_listener::ExitListener;
pub use logging::init_logger;
pub use public_ip::get_current_public_ipv4;
pub use message_handler::MessageHandler;
pub use network_change_listener::NetworkChangeListener;
pub use public_ip::get_current_public_ipv4;
pub const PROGRAM_NAME: &'static str = "dynip-cloudflare";
pub const MAX_ERORS_IN_ROW_DEFAULT: usize = 10;

View File

@ -1,32 +1,14 @@
// SPDX: BSD-2-Clause
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use futures::stream::StreamExt;
use futures::future::{self, Either};
use log::{error, info};
use netlink_sys::{AsyncSocket, SocketAddr};
use rtnetlink::new_connection;
use dynip_cloudflare::{utils, CloudflareClient, MessageHandler, MAX_ERORS_IN_ROW_DEFAULT};
use scopeguard::defer;
use tokio::{signal, sync::Notify};
use tokio::time;
const RTNLGRP_LINK: u32 = 1;
const RTNLGRP_IPV4_IFADDR: u32 = 5;
const fn nl_mgrp(group: u32) -> u32 {
if group > 31 {
panic!("use netlink_sys::Socket::add_membership() for this group");
}
if group == 0 {
0
} else {
1 << (group - 1)
}
}
use dynip_cloudflare::{
utils, CloudflareClient, ExitListener, MessageHandler, NetworkChangeListener,
MAX_ERORS_IN_ROW_DEFAULT,
};
#[tokio::main]
async fn main() {
@ -34,24 +16,11 @@ async fn main() {
defer! {
log::logger().flush();
}
let should_exit = Arc::new(AtomicBool::new(false));
let notify = Arc::new(Notify::new());
let exit_listener = ExitListener::new();
let should_exit_clone = should_exit.clone();
let notify_clone = notify.clone();
tokio::spawn(async move {
signal::ctrl_c()
.await
.expect("Failed to install CTRL+C signal handler");
should_exit_clone.store(true, Ordering::SeqCst);
notify_clone.notify_one();
});
let config = if let Some(aux) = utils::get_config().await {
aux
} else {
return;
let config = match utils::get_config().await {
Some(aux) => aux,
None => return,
};
let mut cloudflare =
@ -63,29 +32,48 @@ async fn main() {
}
};
let (mut conn, mut _handle, mut messages) = new_connection().unwrap();
let groups = nl_mgrp(RTNLGRP_LINK) | nl_mgrp(RTNLGRP_IPV4_IFADDR);
let addr = SocketAddr::new(0, groups);
if let Err(e) = conn.socket_mut().socket_mut().bind(&addr) {
error!("Failed to bind to socket: {:?}", &e);
let mut network_change_listener = match NetworkChangeListener::new() {
Some(aux) => {
info!("Listening for IPv4 address changes and interface connect/disconnect events...");
aux
}
None => {
error!("Failed to initialize networkchangelistener");
return;
}
tokio::spawn(conn);
info!("Listening for IPv4 address changes and interface connect/disconnect events...");
};
let mut message_handler = MessageHandler::new(
&mut cloudflare,
config.max_errors_in_row.unwrap_or(MAX_ERORS_IN_ROW_DEFAULT),
);
while !should_exit.load(Ordering::SeqCst) {
let mut interval = config.max_duration.map(|duration| {
let mut interval = time::interval(duration);
interval.set_missed_tick_behavior(time::MissedTickBehavior::Delay);
interval
});
while !exit_listener.should_exit() {
let tick_future = match interval.as_mut() {
Some(interval) => Either::Left(interval.tick()),
None => Either::Right(future::pending::<tokio::time::Instant>()),
};
tokio::select! {
_ = notify.notified() => break,
message = messages.next() => {
_ = exit_listener.notified() => break,
_ = tick_future => {
if let Some(duration) = config.max_duration.as_ref() {
let duration_string = utils::duration_to_string(duration);
let log_string = format!("{} has passed since last check, checking...", duration_string.trim());
message_handler.log_and_check(Some(&log_string), Option::<&&str>::None).await;
}
}
message = network_change_listener.next_message() => {
if let Some((message, _)) = message {
if let Some(interval) = interval.as_mut() {
interval.reset();
}
message_handler.handle_message(message).await;
}
}

View File

@ -1,7 +1,6 @@
// SPDX: BSD-2-Clause
use std::fmt;
use log::{debug, error, info};
use netlink_packet_core::{NetlinkMessage, NetlinkPayload};
use netlink_packet_route::RouteNetlinkMessage as RtnlMessage;
@ -34,13 +33,14 @@ where
pub async fn handle_message(&mut self, message: NetlinkMessage<RtnlMessage>) -> Option<()> {
match message.payload {
NetlinkPayload::InnerMessage(RtnlMessage::NewAddress(msg)) => {
self.log_and_check("New IPv4 address", &msg).await
self.log_and_check(Some("New IPv4 address"), Some(&msg))
.await
}
NetlinkPayload::InnerMessage(RtnlMessage::DelAddress(msg)) => {
self.log_info("Deleted IPv4 address", &msg).await
}
NetlinkPayload::InnerMessage(RtnlMessage::NewLink(link)) => {
self.log_and_check("New link (interface connected)", &link)
self.log_and_check(Some("New link (interface connected)"), Some(&link))
.await
}
NetlinkPayload::InnerMessage(RtnlMessage::DelLink(link)) => {
@ -54,13 +54,21 @@ where
}
}
async fn log_and_check<D, M>(&mut self, log_msg: &D, msg: &M) -> Option<()>
pub async fn log_and_check<D, M>(&mut self, log_msg: Option<&D>, msg: Option<&M>) -> Option<()>
where
D: fmt::Display + ?Sized,
M: fmt::Debug,
{
info!("{}", log_msg);
debug!("{}: {:?}", log_msg, msg);
if let Some(s) = log_msg {
info!("{}", s);
}
if let Some(m) = msg {
if let Some(lm) = log_msg {
debug!("{}: {:?}", lm, m);
} else {
debug!("{:?}", m);
}
}
if let Err(e) = self.cloudflare.check().await {
self.errs_counter += 1;
error!(

View File

@ -0,0 +1,60 @@
use futures::{
channel::mpsc::UnboundedReceiver,
stream::{Next, StreamExt},
};
use log::error;
use netlink_packet_core::NetlinkMessage;
use netlink_packet_route::RouteNetlinkMessage;
use netlink_sys::{AsyncSocket, SocketAddr};
use rtnetlink::new_connection;
use tokio::task::JoinHandle;
const RTNLGRP_LINK: u32 = 1;
const RTNLGRP_IPV4_IFADDR: u32 = 5;
const fn nl_mgrp(group: u32) -> u32 {
if group > 31 {
panic!("use netlink_sys::Socket::add_membership() for this group");
}
if group == 0 {
0
} else {
1 << (group - 1)
}
}
type Messages = UnboundedReceiver<(NetlinkMessage<RouteNetlinkMessage>, SocketAddr)>;
pub struct NetworkChangeListener {
messages: Messages,
thread_handle: JoinHandle<()>,
}
impl NetworkChangeListener {
pub fn new() -> Option<Self> {
let (mut conn, mut _handle, messages) = new_connection().ok()?;
let groups = nl_mgrp(RTNLGRP_LINK) | nl_mgrp(RTNLGRP_IPV4_IFADDR);
let addr = SocketAddr::new(0, groups);
if let Err(e) = conn.socket_mut().socket_mut().bind(&addr) {
error!("Failed to bind to socket: {:?}", &e);
return None;
}
let thread_handle = tokio::spawn(conn);
Some(Self {
messages,
thread_handle,
})
}
pub fn next_message(&mut self) -> Next<'_, Messages> {
self.messages.next()
}
}
impl Drop for NetworkChangeListener {
fn drop(&mut self) {
self.thread_handle.abort();
}
}

View File

@ -0,0 +1,59 @@
use crate::Config;
use std::time::Duration;
use toml;
const TOML_STR_ONE: &str = r#"
zone_id = ""
api_key = ""
domains = [""]
max_duration = "1d 2h 30m 45s 500000000ns"
"#;
#[test]
fn test_deserialize() {
let config: Config = toml::from_str(TOML_STR_ONE).unwrap();
assert_eq!(config.max_duration, Some(Duration::new(95445, 500000000)));
}
#[test]
fn test_serialize() {
let config = Config {
zone_id: "".into(),
api_key: "".into(),
domains: vec!["".into()],
max_errors_in_row: None,
max_duration: Some(Duration::new(95445, 500000000)),
};
let toml_str = toml::to_string(&config).unwrap();
assert_eq!(TOML_STR_ONE.trim(), toml_str.trim());
}
#[test]
fn test_deserialize_none() {
let toml_str = r#"
zone_id = ""
api_key = ""
domains = [""]
max_errors_in_row = 5
"#;
let config: Config = toml::from_str(toml_str).unwrap();
assert_eq!(config.max_duration, None);
}
#[test]
fn test_serialize_none() {
let toml_to_be = r#"
zone_id = ""
api_key = ""
domains = [""]
"#;
let config = Config {
zone_id: "".into(),
api_key: "".into(),
domains: vec!["".into()],
max_errors_in_row: None,
max_duration: None,
};
let toml_str = toml::to_string(&config).unwrap();
assert_eq!(toml_to_be.trim(), toml_str.trim());
}

2
src/tests/mod.rs Normal file
View File

@ -0,0 +1,2 @@
#[cfg(test)]
mod config_serialization;

View File

@ -1,5 +1,7 @@
// SPDX: BSD-2-Clause
use std::time::Duration;
use crate::{get_config_path, read_config, Config};
use log::error;
@ -28,3 +30,32 @@ pub async fn get_config() -> Option<Config> {
}
read_result.ok()
}
pub fn duration_to_string(duration: &Duration) -> String {
let mut secs = duration.as_secs();
let nanos = duration.subsec_nanos();
let days = secs / 86400;
secs %= 86400;
let hours = secs / 3600;
secs %= 3600;
let minutes = secs / 60;
secs %= 60;
let mut ret = String::new();
if days > 0 {
ret.push_str(&format!("{}d ", days));
}
if hours > 0 {
ret.push_str(&format!("{}h ", hours));
}
if minutes > 0 {
ret.push_str(&format!("{}m ", minutes));
}
if secs > 0 {
ret.push_str(&format!("{}s ", secs));
}
if nanos > 0 {
ret.push_str(&format!("{}ns", nanos));
}
ret
}