1, 添加创建更新城市轨道交通仿真功能接口

2, 添加mqtt客户端crate,提供订阅、发布、发送请求等功能
This commit is contained in:
soul-walker 2024-10-22 17:24:50 +08:00
parent 3798572b0c
commit 7226904e04
17 changed files with 1049 additions and 82 deletions

View File

@ -2,6 +2,8 @@
"cSpell.words": [
"chrono",
"cpus",
"dashmap",
"eventloop",
"Graphi",
"graphiql",
"hashbrown",
@ -11,7 +13,11 @@
"Joylink",
"jsonwebtoken",
"mplj",
"Mqtt",
"mqttbytes",
"Neng",
"nextval",
"oneshot",
"plpgsql",
"prost",
"proto",
@ -20,6 +26,8 @@
"repr",
"reqwest",
"rtss",
"rumqtt",
"rumqttc",
"sqlx",
"sysinfo",
"thiserror",

169
Cargo.lock generated
View File

@ -121,9 +121,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.88"
version = "1.0.90"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e1496f8fb1fbf272686b8d37f523dab3e4a7443300055e74cdaa449f3114356"
checksum = "37bf3594c4c988a53154954629820791dde498571819ae4ca50ca811e060cc95"
[[package]]
name = "ascii_utils"
@ -270,9 +270,9 @@ checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de"
[[package]]
name = "async-trait"
version = "0.1.81"
version = "0.1.83"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107"
checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
dependencies = [
"proc-macro2",
"quote",
@ -619,9 +619,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "bytes"
version = "1.7.1"
version = "1.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50"
checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3"
dependencies = [
"serde",
]
@ -658,9 +658,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.5.17"
version = "4.5.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e5a21b8495e732f1b3c364c9949b201ca7bae518c502c80256c96ad79eaf6ac"
checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8"
dependencies = [
"clap_builder",
"clap_derive",
@ -668,9 +668,9 @@ dependencies = [
[[package]]
name = "clap_builder"
version = "4.5.17"
version = "4.5.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73"
checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54"
dependencies = [
"anstream",
"anstyle",
@ -680,9 +680,9 @@ dependencies = [
[[package]]
name = "clap_derive"
version = "4.5.13"
version = "4.5.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0"
checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
dependencies = [
"heck",
"proc-macro2",
@ -776,6 +776,16 @@ dependencies = [
"unicode-segmentation",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.7"
@ -1407,10 +1417,10 @@ dependencies = [
"http",
"hyper",
"hyper-util",
"rustls",
"rustls 0.23.13",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tokio-rustls 0.26.0",
"tower-service",
"webpki-roots",
]
@ -1816,6 +1826,12 @@ version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "openssl-probe"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]]
name = "ordered-multimap"
version = "0.6.0"
@ -2119,7 +2135,7 @@ dependencies = [
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"rustls 0.23.13",
"socket2",
"thiserror",
"tokio",
@ -2136,7 +2152,7 @@ dependencies = [
"rand",
"ring",
"rustc-hash",
"rustls",
"rustls 0.23.13",
"slab",
"thiserror",
"tinyvec",
@ -2301,7 +2317,7 @@ dependencies = [
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls 0.23.13",
"rustls-pemfile",
"rustls-pki-types",
"serde",
@ -2309,7 +2325,7 @@ dependencies = [
"serde_urlencoded",
"sync_wrapper 1.0.1",
"tokio",
"tokio-rustls",
"tokio-rustls 0.26.0",
"tower-service",
"url",
"wasm-bindgen",
@ -2440,6 +2456,20 @@ dependencies = [
"tracing-wasm",
]
[[package]]
name = "rtss_mqtt"
version = "0.1.0"
dependencies = [
"async-trait",
"bytes",
"lazy_static",
"rtss_db",
"rtss_log",
"rumqttc",
"thiserror",
"tokio",
]
[[package]]
name = "rtss_sim_manage"
version = "0.1.0"
@ -2482,6 +2512,25 @@ dependencies = [
"rtss_log",
]
[[package]]
name = "rumqttc"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1568e15fab2d546f940ed3a21f48bbbd1c494c90c99c4481339364a497f94a9"
dependencies = [
"bytes",
"flume",
"futures-util",
"log",
"rustls-native-certs",
"rustls-pemfile",
"rustls-webpki",
"thiserror",
"tokio",
"tokio-rustls 0.25.0",
"url",
]
[[package]]
name = "rust-ini"
version = "0.19.0"
@ -2517,6 +2566,20 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "rustls"
version = "0.22.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432"
dependencies = [
"log",
"ring",
"rustls-pki-types",
"rustls-webpki",
"subtle",
"zeroize",
]
[[package]]
name = "rustls"
version = "0.23.13"
@ -2531,6 +2594,19 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rustls-native-certs"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5"
dependencies = [
"openssl-probe",
"rustls-pemfile",
"rustls-pki-types",
"schannel",
"security-framework",
]
[[package]]
name = "rustls-pemfile"
version = "2.1.3"
@ -2570,12 +2646,44 @@ version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "schannel"
version = "0.1.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "security-framework"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags 2.6.0",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "serde"
version = "1.0.210"
@ -2598,9 +2706,9 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.125"
version = "1.0.131"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83c8e735a073ccf5be70aa8066aa984eaf2fa000db6c8d0100ae605b366d31ed"
checksum = "67d42a0bd4ac281beff598909bb56a86acaf979b84483e1c79c10dcaf98f8cf3"
dependencies = [
"itoa",
"memchr",
@ -3072,18 +3180,18 @@ dependencies = [
[[package]]
name = "thiserror"
version = "1.0.63"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.63"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
dependencies = [
"proc-macro2",
"quote",
@ -3182,13 +3290,24 @@ dependencies = [
"syn",
]
[[package]]
name = "tokio-rustls"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f"
dependencies = [
"rustls 0.22.4",
"rustls-pki-types",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4"
dependencies = [
"rustls",
"rustls 0.23.13",
"rustls-pki-types",
"tokio",
]

View File

@ -15,7 +15,7 @@ bevy_ecs = "0.14"
bevy_time = "0.14"
rayon = "1.10"
tokio = { version = "1.40", features = ["macros", "rt-multi-thread"] }
thiserror = "1.0"
thiserror = "1.0.64"
sqlx = { version = "0.8", features = [
"runtime-tokio",
"postgres",
@ -23,8 +23,12 @@ sqlx = { version = "0.8", features = [
"chrono",
] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.125"
anyhow = "1.0"
serde_json = "1.0.131"
anyhow = "1.0.90"
async-trait = "0.1.83"
bytes = "1.7.2"
lazy_static = "1.5.0"
[dependencies]
tokio = { version = "1.39.3", features = ["macros", "rt-multi-thread"] }
@ -33,6 +37,6 @@ rtss_api = { path = "crates/rtss_api" }
rtss_db = { path = "crates/rtss_db" }
serde = { workspace = true }
config = "0.14.0"
clap = { version = "4.5", features = ["derive"] }
clap = { version = "4.5.20", features = ["derive"] }
enum_dispatch = "0.3"
anyhow = { workspace = true }

View File

@ -10,7 +10,7 @@ pub trait DataOptions: InputType + OutputType + Serialize + DeserializeOwned {
impl DataOptions for Value {
fn to_data_options_filter_clause(&self) -> String {
format!("options @> '{}'", self)
format!("{} @> '{}'", DraftDataColumn::Options.name(), self)
}
}

View File

@ -11,7 +11,7 @@ use serde_json::Value;
use crate::apis::{PageDto, PageQueryDto};
use crate::loader::RtssDbLoader;
use super::common::{DataOptions, IscsDataOptions};
use super::data_options_def::{DataOptions, IscsDataOptions};
use super::release_data::ReleaseDataId;
use super::user::UserId;
@ -115,7 +115,7 @@ impl DraftDataMutation {
.data::<UserAuthCache>()?
.query_user(&ctx.data::<Token>()?.0)
.await?;
input = input.with_user_id(user.id_i32());
input = input.with_data_type_and_user_id(DataType::Iscs, user.id_i32());
let db_accessor = ctx.data::<RtssDbAccessor>()?;
let draft_data = db_accessor.create_draft_data(input.into()).await?;
Ok(draft_data.into())
@ -207,6 +207,8 @@ impl DraftDataMutation {
#[derive(Debug, InputObject)]
#[graphql(concrete(name = "CreateDraftIscsDto", params(IscsDataOptions)))]
pub struct CreateDraftDataDto<T: DataOptions> {
#[graphql(skip)]
pub data_type: Option<DataType>,
pub name: String,
pub options: Option<T>,
#[graphql(skip)]
@ -214,7 +216,8 @@ pub struct CreateDraftDataDto<T: DataOptions> {
}
impl<T: DataOptions> CreateDraftDataDto<T> {
pub fn with_user_id(mut self, id: i32) -> Self {
pub fn with_data_type_and_user_id(mut self, data_type: DataType, id: i32) -> Self {
self.data_type = Some(data_type);
self.user_id = Some(id);
self
}
@ -224,7 +227,7 @@ impl<T: DataOptions> From<CreateDraftDataDto<T>> for rtss_db::CreateDraftData {
fn from(value: CreateDraftDataDto<T>) -> Self {
let cdd = Self::new(
&value.name,
DataType::Iscs,
value.data_type.expect("need data_type"),
value.user_id.expect("CreateDraftDataDto need user_id"),
);
if value.options.is_some() {

View File

@ -1,18 +1,22 @@
use crate::{
apis::{PageDto, PageQueryDto},
loader::RtssDbLoader,
user_auth::{RoleGuard, Token, UserAuthCache},
};
use async_graphql::{
dataloader::DataLoader, ComplexObject, Context, InputObject, Object, SimpleObject,
};
use chrono::NaiveDateTime;
use rtss_db::{FeatureAccessor, RtssDbAccessor};
use rtss_db::{CreateFeature, FeatureAccessor, RtssDbAccessor, UpdateFeature};
use rtss_dto::common::FeatureType;
use rtss_dto::common::Role;
use serde_json::Value;
use crate::{
apis::{PageDto, PageQueryDto},
loader::RtssDbLoader,
use super::{
feature_config_def::{FeatureConfig, UrFeatureConfig},
user::UserId,
};
use super::user::UserId;
#[derive(Default)]
pub struct FeatureQuery;
@ -21,7 +25,8 @@ pub struct FeatureMutation;
#[Object]
impl FeatureQuery {
/// 分页查询特征(系统管理)
/// 分页查询功能feature(系统管理)
#[graphql(guard = "RoleGuard::new(Role::Admin)")]
async fn feature_paging(
&self,
ctx: &Context<'_>,
@ -35,14 +40,16 @@ impl FeatureQuery {
Ok(paging.into())
}
/// id获取特征
/// id获取功能feature
#[graphql(guard = "RoleGuard::new(Role::User)")]
async fn feature(&self, ctx: &Context<'_>, id: i32) -> async_graphql::Result<FeatureDto> {
let dba = ctx.data::<RtssDbAccessor>()?;
let feature = dba.get_feature(id).await?;
Ok(feature.into())
}
/// id列表获取特征
/// id列表获取功能feature列表
#[graphql(guard = "RoleGuard::new(Role::User)")]
async fn features(
&self,
ctx: &Context<'_>,
@ -56,7 +63,8 @@ impl FeatureQuery {
#[Object]
impl FeatureMutation {
/// 上下架特征
/// 上下架功能feature
#[graphql(guard = "RoleGuard::new(Role::Admin)")]
async fn publish_feature(
&self,
ctx: &Context<'_>,
@ -67,6 +75,90 @@ impl FeatureMutation {
let feature = dba.set_feature_published(id, is_published).await?;
Ok(feature.into())
}
/// 创建城轨仿真功能feature
#[graphql(guard = "RoleGuard::new(Role::Admin)")]
async fn create_ur_feature(
&self,
ctx: &Context<'_>,
mut input: CreateFeatureDto<UrFeatureConfig>,
) -> async_graphql::Result<FeatureDto> {
let dba = ctx.data::<RtssDbAccessor>()?;
let user = ctx
.data::<UserAuthCache>()?
.query_user(&ctx.data::<Token>()?.0)
.await?;
input = input.with_feature_type_and_user_id(FeatureType::Ur, user.id_i32());
let feature = dba.create_feature(&input.into()).await?;
Ok(feature.into())
}
/// 更新城轨仿真功能feature
#[graphql(guard = "RoleGuard::new(Role::Admin)")]
async fn update_ur_feature(
&self,
ctx: &Context<'_>,
input: UpdateFeatureDto<UrFeatureConfig>,
) -> async_graphql::Result<FeatureDto> {
let dba = ctx.data::<RtssDbAccessor>()?;
let feature = dba.update_feature(&input.into()).await?;
Ok(feature.into())
}
}
#[derive(Debug, InputObject)]
#[graphql(concrete(name = "UpdateUrFeatureDto", params(UrFeatureConfig)))]
pub struct UpdateFeatureDto<T: FeatureConfig> {
pub id: i32,
pub name: String,
pub description: String,
pub config: T,
#[graphql(skip)]
pub user_id: Option<i32>,
}
impl<T: FeatureConfig> From<UpdateFeatureDto<T>> for UpdateFeature {
fn from(value: UpdateFeatureDto<T>) -> Self {
Self {
id: value.id,
name: value.name,
description: value.description,
config: serde_json::to_value(&value.config).expect("config is to_value failed"),
updater_id: value.user_id.expect("user_id must be set"),
}
}
}
#[derive(Debug, InputObject)]
#[graphql(concrete(name = "CreateUrFeatureDto", params(UrFeatureConfig)))]
pub struct CreateFeatureDto<T: FeatureConfig> {
#[graphql(skip)]
pub feature_type: Option<FeatureType>,
pub name: String,
pub description: String,
pub config: T,
#[graphql(skip)]
pub user_id: Option<i32>,
}
impl<T: FeatureConfig> From<CreateFeatureDto<T>> for CreateFeature {
fn from(value: CreateFeatureDto<T>) -> Self {
Self {
feature_type: value.feature_type.expect("feature_type must be set"),
name: value.name,
description: value.description,
config: serde_json::to_value(&value.config).expect("config is to_value failed"),
creator_id: value.user_id.expect("user_id must be set"),
}
}
}
impl<T: FeatureConfig> CreateFeatureDto<T> {
fn with_feature_type_and_user_id(mut self, feature_type: FeatureType, uid: i32) -> Self {
self.feature_type = Some(feature_type);
self.user_id = Some(uid);
self
}
}
#[derive(Debug, InputObject)]

View File

@ -0,0 +1,17 @@
use async_graphql::{InputObject, InputType, OutputType, SimpleObject};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value;
pub trait FeatureConfig: InputType + OutputType + Serialize + DeserializeOwned {}
impl FeatureConfig for Value {}
/// UR功能配置
#[derive(Debug, Clone, InputObject, SimpleObject, Serialize, Deserialize)]
#[graphql(input_name = "UrFeatureConfigInput")]
pub struct UrFeatureConfig {
/// 电子地图id
pub ems: Vec<i32>,
}
impl FeatureConfig for UrFeatureConfig {}

View File

@ -8,9 +8,10 @@ mod sys_info;
use simulation_definition::*;
use user::{UserMutation, UserQuery};
mod common;
mod data_options_def;
mod draft_data;
mod feature;
mod feature_config_def;
mod release_data;
mod simulation;
mod user;

View File

@ -14,7 +14,7 @@ use serde_json::Value;
use crate::apis::draft_data::DraftDataDto;
use crate::loader::RtssDbLoader;
use super::common::{DataOptions, IscsDataOptions};
use super::data_options_def::{DataOptions, IscsDataOptions};
use super::user::UserId;
use super::{PageDto, PageQueryDto};

View File

@ -10,6 +10,7 @@ use axum::{
};
use dataloader::DataLoader;
use http::{playground_source, GraphQLPlaygroundConfig};
use rtss_db::RtssDbAccessor;
use rtss_log::tracing::{debug, info};
use tokio::net::TcpListener;
use tower_http::cors::CorsLayer;
@ -47,7 +48,12 @@ impl ServerConfig {
}
pub async fn serve(config: ServerConfig) -> anyhow::Result<()> {
let schema = new_schema(config.clone()).await;
let client = config
.user_auth_client
.clone()
.expect("user auth client not configured");
let dba = rtss_db::get_db_accessor(&config.database_url).await;
let schema = new_schema(SchemaOptions::new(client, dba));
let app = Router::new()
.route("/", get(graphiql).post(graphql_handler))
@ -88,18 +94,52 @@ async fn graphiql() -> impl IntoResponse {
pub type RtssAppSchema = Schema<Query, Mutation, EmptySubscription>;
pub async fn new_schema(config: ServerConfig) -> RtssAppSchema {
let client = config
.user_auth_client
.expect("user auth client not configured");
let user_info_cache = crate::user_auth::UserAuthCache::new(client.clone());
let dba = rtss_db::get_db_accessor(&config.database_url).await;
let loader = RtssDbLoader::new(dba.clone());
pub struct SchemaOptions {
pub user_auth_client: UserAuthClient,
pub user_info_cache: user_auth::UserAuthCache,
pub rtss_dba: RtssDbAccessor,
}
impl SchemaOptions {
pub fn new(user_auth_client: UserAuthClient, rtss_dba: RtssDbAccessor) -> Self {
let user_info_cache = user_auth::UserAuthCache::new(user_auth_client.clone());
Self {
user_auth_client,
user_info_cache,
rtss_dba,
}
}
}
pub fn new_schema(options: SchemaOptions) -> RtssAppSchema {
let loader = RtssDbLoader::new(options.rtss_dba.clone());
Schema::build(Query::default(), Mutation::default(), EmptySubscription)
.data(client)
.data(user_info_cache)
.data(dba)
.data(options.user_auth_client)
.data(options.user_info_cache)
.data(options.rtss_dba)
.data(DataLoader::new(loader, tokio::spawn))
// .data(MutexSimulationManager::default())
.finish()
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_new_schema() {
let dba =
rtss_db::get_db_accessor("postgresql://joylink:Joylink@0503@localhost:5432/joylink")
.await;
let _ = new_schema(SchemaOptions::new(
crate::UserAuthClient {
base_url: "".to_string(),
login_url: "".to_string(),
logout_url: "".to_string(),
user_info_url: "".to_string(),
sync_user_url: "".to_string(),
},
dba,
));
}
}

View File

@ -276,9 +276,6 @@ impl UserAuthCache {
#[cfg(test)]
mod tests {
use anyhow::Ok;
use rtss_log::tracing::Level;
use super::*;
@ -302,25 +299,25 @@ mod tests {
println!("{:?}", dt);
}
#[tokio::test]
async fn test_user_auth_cache() -> anyhow::Result<()> {
rtss_log::Logging::default().with_level(Level::DEBUG).init();
let client = UserAuthClient {
base_url: "http://192.168.33.233/rtss-server".to_string(),
login_url: "/api/login".to_string(),
logout_url: "/api/login/logout".to_string(),
user_info_url: "/api/login/getUserInfo".to_string(),
sync_user_url: "/api/userinfo/list/all".to_string(),
};
let cache = UserAuthCache::new(client.clone());
let token = cache.client.login(LoginInfo::default()).await?;
let user = cache.query_user(&token).await?;
println!("token: {}, {:?}", token, user);
assert_eq!(cache.len(), 1);
// #[tokio::test]
// async fn test_user_auth_cache() -> anyhow::Result<()> {
// rtss_log::Logging::default().with_level(Level::DEBUG).init();
// let client = UserAuthClient {
// base_url: "http://192.168.33.233/rtss-server".to_string(),
// login_url: "/api/login".to_string(),
// logout_url: "/api/login/logout".to_string(),
// user_info_url: "/api/login/getUserInfo".to_string(),
// sync_user_url: "/api/userinfo/list/all".to_string(),
// };
// let cache = UserAuthCache::new(client.clone());
// let token = cache.client.login(LoginInfo::default()).await?;
// let user = cache.query_user(&token).await?;
// println!("token: {}, {:?}", token, user);
// assert_eq!(cache.len(), 1);
let user_list = client.query_all_users(&Token(token)).await?;
println!("{:?}", user_list);
// let user_list = client.query_all_users(&Token(token)).await?;
// println!("{:?}", user_list);
Ok(())
}
// Ok(())
// }
}

View File

@ -7,6 +7,8 @@ pub use user::*;
mod feature;
pub use feature::*;
use crate::{model::MqttClientIdSeq, DbAccessError};
#[derive(Clone)]
pub struct RtssDbAccessor {
pool: sqlx::PgPool,
@ -16,9 +18,37 @@ impl RtssDbAccessor {
pub fn new(pool: sqlx::PgPool) -> Self {
RtssDbAccessor { pool }
}
pub async fn get_next_mqtt_client_id(&self) -> Result<i64, DbAccessError> {
let seq_name = MqttClientIdSeq::Name.name();
let next = sqlx::query_scalar(&format!("SELECT nextval('{}')", seq_name))
.fetch_one(&self.pool)
.await?;
Ok(next)
}
}
pub async fn get_db_accessor(url: &str) -> RtssDbAccessor {
let pool = sqlx::PgPool::connect(url).await.expect("连接数据库失败");
RtssDbAccessor::new(pool)
}
#[cfg(test)]
mod tests {
use super::*;
use rtss_log::tracing::{self, Level};
use sqlx::PgPool;
// You could also do `use foo_crate::MIGRATOR` and just refer to it as `MIGRATOR` here.
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_get_mqtt_client_id(pool: PgPool) -> Result<(), DbAccessError> {
rtss_log::Logging::default().with_level(Level::DEBUG).init();
let accessor = crate::db_access::RtssDbAccessor::new(pool);
for _ in 0..10 {
let id = accessor.get_next_mqtt_client_id().await?;
tracing::info!("id = {}", id);
assert!(id > 0);
}
Ok(())
}
}

View File

@ -7,6 +7,18 @@ use sqlx::types::{
use crate::common::TableColumn;
pub enum MqttClientIdSeq {
Name,
}
impl MqttClientIdSeq {
pub fn name(&self) -> &str {
match self {
MqttClientIdSeq::Name => "rtss.mqtt_client_id_seq",
}
}
}
#[derive(Debug)]
pub enum UserColumn {
Table,
@ -101,7 +113,7 @@ pub struct ReleaseDataModel {
/// 数据库表 rtss.release_data_version 列映射
#[derive(Debug)]
pub(crate) enum ReleaseDataVersionColumn {
pub enum ReleaseDataVersionColumn {
Table,
Id,
ReleaseDataId,
@ -128,7 +140,7 @@ pub struct ReleaseDataVersionModel {
/// 数据库表 rtss.feature 列映射
#[derive(Debug)]
#[allow(dead_code)]
pub(crate) enum FeatureColumn {
pub enum FeatureColumn {
Table,
Id,
FeatureType,
@ -160,7 +172,7 @@ pub struct FeatureModel {
/// 数据库表 rtss.user_config 列映射
#[derive(Debug)]
#[allow(dead_code)]
pub(crate) enum UserConfigColumn {
pub enum UserConfigColumn {
Table,
Id,
UserId,

View File

@ -0,0 +1,15 @@
[package]
name = "rtss_mqtt"
version = "0.1.0"
edition = "2021"
[dependencies]
rumqttc = { version = "0.24.0", features = ["url"] }
tokio = { workspace = true }
async-trait = { workspace = true }
bytes = { workspace = true }
lazy_static = { workspace = true }
thiserror = { workspace = true }
rtss_db = { path = "../rtss_db" }
rtss_log = { path = "../rtss_log" }

View File

@ -0,0 +1,14 @@
use rumqttc::v5::ClientError;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum MqttClientError {
#[error("未知的Mqtt客户端错误")]
Unknown,
#[error("客户端已设置")]
AlreadySet,
#[error("rumqttc 错误: {0}")]
ClientError(#[from] ClientError),
#[error("全局客户端未设置")]
NoClient,
}

612
crates/rtss_mqtt/src/lib.rs Normal file
View File

@ -0,0 +1,612 @@
use std::{
any::TypeId,
collections::HashMap,
sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
},
task::Waker,
time::Duration,
};
use bytes::Bytes;
use lazy_static::lazy_static;
use rtss_log::tracing::{debug, error, info};
use rumqttc::{
v5::{
mqttbytes::{
v5::{Packet, Publish, PublishProperties},
QoS,
},
AsyncClient, Event, EventLoop, MqttOptions,
},
Outgoing,
};
use tokio::{sync::oneshot, time::timeout};
mod error;
use error::MqttClientError;
lazy_static! {
/// 全局静态MqttClient实例
static ref MQTT_CLIENT: tokio::sync::Mutex<Option<MqttClient>> = tokio::sync::Mutex::new(None);
}
/// 设置全局MqttClient实例
pub async fn set_global_mqtt_client(client: MqttClient) -> Result<(), MqttClientError> {
let mut mqtt_client = MQTT_CLIENT.lock().await;
if mqtt_client.is_some() {
return Err(MqttClientError::AlreadySet);
}
*mqtt_client = Some(client);
Ok(())
}
pub async fn get_global_mqtt_client() -> Option<MqttClient> {
let mqtt_client = MQTT_CLIENT.lock().await;
mqtt_client.clone()
}
pub struct MqttClientOptions {
id: String,
options: MqttOptions,
request_timeout: Duration,
}
impl MqttClientOptions {
pub fn new(id: &str, url: &str) -> Self {
Self {
id: id.to_string(),
options: MqttOptions::parse_url(format!("{}?client_id={}", url, id))
.expect("解析mqtt url失败"),
request_timeout: Duration::from_secs(5),
}
}
pub fn set_request_timeout(&mut self, timeout: Duration) -> &mut Self {
self.request_timeout = timeout;
self
}
pub fn set_credentials(&mut self, username: &str, password: &str) -> &mut Self {
self.options.set_credentials(username, password);
self
}
pub fn build(&mut self) -> MqttClient {
self.options.set_keep_alive(Duration::from_secs(10));
let (client, eventloop) = AsyncClient::new(self.options.clone(), 10);
let subscriptions = SubscribeHandlerMap::new();
let loop_sub = subscriptions.clone();
tokio::spawn(async move {
MqttClient::handle_connection_loop(eventloop, loop_sub).await;
});
MqttClient {
id: self.id.clone(),
request_timeout: self.request_timeout,
client,
request_id: Arc::new(AtomicU64::new(0)),
subscriptions,
}
}
}
/// MQTT客户端
/// id: 客户端ID,从数据库的id序列中获取
/// 客户端具有的功能:
/// 1. 启动
/// 2. 订阅
/// 3. 发布
/// 4. 实现类似http的请求相应功能
/// 5. 断开连接
#[derive(Clone)]
pub struct MqttClient {
id: String,
request_timeout: Duration,
client: AsyncClient,
request_id: Arc<AtomicU64>,
subscriptions: SubscribeHandlerMap,
}
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct HandlerId(TypeId);
#[derive(Clone)]
pub struct SubscribeHandlerMap {
sub_handlers: Arc<Mutex<HashMap<String, MessageHandlerMap>>>,
}
impl SubscribeHandlerMap {
fn new() -> Self {
Self {
sub_handlers: Arc::new(Mutex::new(HashMap::new())),
}
}
fn insert(
&self,
topic: &str,
handler_id: HandlerId,
handler: Arc<dyn MessageHandler>,
) -> HandlerId {
self.sub_handlers
.lock()
.unwrap()
.entry(topic.to_string())
.or_insert_with(MessageHandlerMap::new)
.insert(handler_id, handler);
handler_id
}
fn remove(&self, topic: &str, handler_id: HandlerId) {
if let Some(topic_handlers) = self.sub_handlers.lock().unwrap().get_mut(topic) {
topic_handlers.remove(handler_id);
}
}
#[allow(dead_code)]
fn remove_all(&self, topic: &str) {
if let Some(topic_handlers) = self.sub_handlers.lock().unwrap().get_mut(topic) {
topic_handlers.remove_all();
}
}
fn get_handlers(&self, topic: &str) -> Option<Vec<Arc<dyn MessageHandler>>> {
if let Some(topic_handlers) = self.sub_handlers.lock().unwrap().get(topic) {
Some(topic_handlers.values())
} else {
None
}
}
#[allow(dead_code)]
fn get_mut(&self, topic: &str) -> Option<MessageHandlerMap> {
self.sub_handlers.lock().unwrap().get(topic).cloned()
}
#[allow(dead_code)]
fn is_topic_empty(&self, topic: &str) -> bool {
if let Some(topic_handlers) = self.sub_handlers.lock().unwrap().get(topic) {
topic_handlers.is_empty()
} else {
true
}
}
#[allow(dead_code)]
fn is_empty(&self) -> bool {
self.sub_handlers.lock().unwrap().is_empty()
}
fn clear(&self) {
self.sub_handlers.lock().unwrap().clear();
}
}
#[derive(Clone)]
struct MessageHandlerMap {
handlers: Arc<Mutex<HashMap<HandlerId, Arc<dyn MessageHandler>>>>,
}
impl MessageHandlerMap {
fn new() -> Self {
Self {
handlers: Arc::new(Mutex::new(HashMap::new())),
}
}
fn insert(&self, handler_id: HandlerId, handler: Arc<dyn MessageHandler>) {
self.handlers.lock().unwrap().insert(handler_id, handler);
}
/// 移除处理器,返回剩余处理器数量
fn remove(&self, handler_id: HandlerId) -> Option<Arc<dyn MessageHandler>> {
self.handlers.lock().unwrap().remove(&handler_id)
}
#[allow(dead_code)]
fn remove_all(&self) {
self.handlers.lock().unwrap().clear();
}
fn values(&self) -> Vec<Arc<dyn MessageHandler>> {
self.handlers.lock().unwrap().values().cloned().collect()
}
#[allow(dead_code)]
fn is_empty(&self) -> bool {
self.handlers.lock().unwrap().is_empty()
}
}
#[must_use = "this `SubscribeTopicHandler` implements `Drop`, which will unregister the handler"]
#[derive(Clone)]
pub struct SubscribeTopicHandler {
topic: String,
handler_id: HandlerId,
handler_map: SubscribeHandlerMap,
}
impl SubscribeTopicHandler {
pub fn new(topic: &str, handler_id: HandlerId, handler_map: SubscribeHandlerMap) -> Self {
Self {
topic: topic.to_string(),
handler_id,
handler_map,
}
}
pub fn unregister(&self) {
self.handler_map.remove(&self.topic, self.handler_id);
}
}
/// 订阅消息处理器
#[async_trait::async_trait]
pub trait MessageHandler: Send + Sync {
async fn handle(&self, publish: Publish);
}
/// 为闭包实现消息处理器
#[async_trait::async_trait]
impl<F> MessageHandler for F
where
F: Fn(Publish) + Sync + Send,
{
async fn handle(&self, publish: Publish) {
self(publish);
}
}
impl MqttClient {
pub async fn close(&self) -> Result<(), MqttClientError> {
self.client.disconnect().await?;
// 清空订阅处理器
self.subscriptions.clear();
Ok(())
}
pub fn id(&self) -> &str {
&self.id
}
pub async fn subscribe(&self, topic: &str, qos: QoS) -> Result<(), MqttClientError> {
self.client.subscribe(topic, qos).await?;
Ok(())
}
pub async fn unsubscribe(&self, topic: &str) -> Result<(), MqttClientError> {
self.client.unsubscribe(topic).await?;
Ok(())
}
pub fn register_topic_handler<H>(&self, topic: &str, handler: H) -> SubscribeTopicHandler
where
H: MessageHandler + 'static,
{
let handler_id = HandlerId(TypeId::of::<H>());
self.subscriptions
.insert(topic, handler_id, Arc::new(handler));
SubscribeTopicHandler::new(topic, handler_id, self.subscriptions.clone())
}
pub fn unregister_topic(&self, topic: &str) {
self.subscriptions.remove_all(topic);
}
pub fn topic_handler_count(&self, topic: &str) -> usize {
if let Some(topic_handlers) = self.subscriptions.get_handlers(topic) {
topic_handlers.len()
} else {
0
}
}
pub async fn publish(
&self,
topic: &str,
qos: QoS,
payload: Vec<u8>,
) -> Result<(), MqttClientError> {
self.client.publish(topic, qos, false, payload).await?;
Ok(())
}
pub fn next_request_id(&self) -> u64 {
self.request_id.fetch_add(1, Ordering::Relaxed)
}
/// 发送请求并等待响应
pub async fn request(
&self,
topic: &str,
qos: QoS,
payload: Vec<u8>,
) -> Result<MqttResponse, MqttClientError> {
// 订阅响应主题
let response_topic = format!("{}/{}/resp/{}", self.id, topic, self.next_request_id());
self.subscribe(&response_topic, QoS::ExactlyOnce).await?;
// 创建请求future
let response_future = MqttResponseFuture::new(&response_topic, self.request_timeout);
// 注册响应处理器
let response_handler =
self.register_topic_handler(&response_topic, response_future.clone());
// 发布请求
let property = PublishProperties {
response_topic: Some(response_topic.clone().into()),
..Default::default()
};
self.client
.publish_with_properties(topic, qos, false, payload, property)
.await?;
// 等待响应
let resp = response_future.await;
// 注销响应处理器并取消订阅
response_handler.unregister();
self.unsubscribe(&response_topic).await?;
Ok(resp)
}
async fn handle_connection_loop(mut eventloop: EventLoop, subscriptions: SubscribeHandlerMap) {
while let Ok(notification) = eventloop.poll().await {
match notification {
Event::Incoming(Packet::Publish(publish)) => {
debug!("Received message: {:?}", publish);
let topic: String = String::from_utf8_lossy(&publish.topic).to_string();
if let Some(topic_handlers) = subscriptions.get_handlers(&topic) {
for handler in topic_handlers {
let handler = handler.clone();
let p = publish.clone();
tokio::spawn(async move {
handler.handle(p).await;
});
}
}
}
Event::Outgoing(Outgoing::Disconnect) => {
info!("Disconnected to the broker");
break;
}
Event::Incoming(Packet::Disconnect(disconnect)) => {
info!("Disconnected from the broker: {:?}", disconnect);
break;
}
_ => {
debug!("Unhandled event: {:?}", notification);
}
}
}
}
}
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MqttResponseState {
Waiting,
Received,
Timeout,
}
/// MQTT请求响应
#[derive(Clone, Debug)]
pub struct MqttResponse {
state: Arc<Mutex<MqttResponseState>>,
response: Arc<Mutex<Bytes>>,
}
impl MqttResponse {
pub fn new() -> Self {
Self {
state: Arc::new(Mutex::new(MqttResponseState::Waiting)),
response: Arc::new(Mutex::new(Bytes::new())),
}
}
pub fn is_waiting(&self) -> bool {
*self.state.lock().unwrap() == MqttResponseState::Waiting
}
pub fn is_received(&self) -> bool {
*self.state.lock().unwrap() == MqttResponseState::Received
}
pub fn is_timeout(&self) -> bool {
*self.state.lock().unwrap() == MqttResponseState::Timeout
}
pub fn set_timeout(&self) {
*self.state.lock().unwrap() = MqttResponseState::Timeout;
}
pub fn set(&self, response: Bytes) {
*self.state.lock().unwrap() = MqttResponseState::Received;
*self.response.lock().unwrap() = response;
}
pub fn get(&self) -> Bytes {
self.response.lock().unwrap().clone()
}
}
/// MQTT响应Future
#[derive(Clone)]
pub struct MqttResponseFuture {
pub start_time: std::time::Instant,
timeout: Duration,
tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
waker: Arc<Mutex<Option<Waker>>>,
response_topic: String,
response: MqttResponse,
}
impl MqttResponseFuture {
pub fn new(response_topic: &str, timeout: Duration) -> Self {
let (tx, rx) = oneshot::channel();
let r = Self {
start_time: std::time::Instant::now(),
timeout,
tx: Arc::new(Mutex::new(Some(tx))),
waker: Arc::new(Mutex::new(None)),
response_topic: response_topic.to_string(),
response: MqttResponse::new(),
};
// 启动超时检查
r.start_timeout_monitor(rx);
r
}
/// 启动超时监控任务逻辑
fn start_timeout_monitor(&self, rx: oneshot::Receiver<()>) {
let response = self.response.clone();
let response_topic = self.response_topic.clone();
let duration = self.timeout.clone();
let waker = self.waker.clone();
tokio::spawn(async move {
if let Err(_) = timeout(duration, rx).await {
error!("Mqtt response timeout: {:?}", response_topic);
response.set_timeout();
if let Some(waker) = waker.lock().unwrap().take() {
waker.wake();
}
}
});
}
}
#[async_trait::async_trait]
impl MessageHandler for MqttResponseFuture {
async fn handle(&self, publish: Publish) {
if publish.topic == self.response_topic {
self.response.set(publish.payload);
if let Some(tx) = self.tx.lock().unwrap().take() {
tx.send(())
.expect("Send Mqtt response timeout signal failed");
}
if let Some(waker) = self.waker.lock().unwrap().take() {
waker.wake();
}
}
}
}
impl std::future::Future for MqttResponseFuture {
type Output = MqttResponse;
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
if self.response.is_waiting() {
debug!("Response future poll waiting...");
self.waker.lock().unwrap().replace(cx.waker().clone());
std::task::Poll::Pending
} else {
debug!("Response future poll ready: {:?}", self.response.get());
std::task::Poll::Ready(self.response.clone())
}
}
}
pub fn get_publish_response_topic(publish: Option<PublishProperties>) -> Option<String> {
publish.map(|p| p.response_topic.clone()).flatten()
}
#[cfg(test)]
mod tests {
use super::*;
use rtss_log::tracing::{info, Level};
use tokio::time::{sleep, Duration};
#[tokio::test]
async fn test_subscribe_and_publish() {
rtss_log::Logging::default().with_level(Level::DEBUG).init();
let client = MqttClientOptions::new("rtss_test1", "tcp://localhost:1883")
.set_credentials("rtss_simulation", "Joylink@0503")
.build();
client
.subscribe("test/topic", QoS::AtMostOnce)
.await
.unwrap();
let handler1 = client.register_topic_handler("test/topic", |publish: Publish| {
info!(
"Handler 1 received: topic={}, payload={:?}",
String::from_utf8_lossy(&publish.topic),
String::from_utf8_lossy(&publish.payload)
);
});
let h2 = client.register_topic_handler("test/topic", |publish: Publish| {
info!(
"Handler 2 received: topic={}, payload={:?}",
String::from_utf8_lossy(&publish.topic),
String::from_utf8_lossy(&publish.payload)
);
});
assert_eq!(client.topic_handler_count("test/topic"), 2);
client
.publish("test/topic", QoS::AtMostOnce, b"Hello, MQTT!".to_vec())
.await
.unwrap();
// Wait for a moment to allow handlers to process the message
sleep(Duration::from_millis(200)).await;
// Test remove_handler
client.unsubscribe("test/topic").await.unwrap();
handler1.unregister();
assert_eq!(client.topic_handler_count("test/topic"), 1);
h2.unregister();
assert_eq!(client.topic_handler_count("test/topic"), 0);
// Test unsubscribe
client.close().await.unwrap();
}
#[tokio::test]
async fn test_request() {
rtss_log::Logging::default().with_level(Level::DEBUG).init();
let client = MqttClientOptions::new("rtss_test1", "tcp://localhost:1883")
.set_credentials("rtss_simulation", "Joylink@0503")
.build();
set_global_mqtt_client(client.clone()).await.unwrap();
if let Some(c) = get_global_mqtt_client().await {
c.subscribe("test/request", QoS::AtMostOnce).await.unwrap();
let handler = |p: Publish| {
info!(
"Request handler received: topic={}, payload={:?}",
String::from_utf8_lossy(&p.topic),
String::from_utf8_lossy(&p.payload)
);
let response = Bytes::from("Hello, response!");
let resp_topic = get_publish_response_topic(p.properties.clone());
if let Some(r_topic) = resp_topic {
tokio::spawn(async move {
if let Some(c) = get_global_mqtt_client().await {
c.publish(&r_topic, QoS::AtMostOnce, response.to_vec())
.await
.unwrap();
}
});
}
};
let _ = c.register_topic_handler("test/request", handler);
}
if let Some(c) = get_global_mqtt_client().await {
let response = c
.request("test/request", QoS::AtMostOnce, b"Hello, request!".to_vec())
.await
.unwrap();
info!("Request response: {:?}", response);
}
client.close().await.unwrap();
}
}

View File

@ -1,6 +1,9 @@
-- 初始化数据库SCHEMA(所有轨道交通信号系统仿真的表、类型等都在rtss SCHEMA下)
CREATE SCHEMA rtss;
-- 创建mqtt客户端id序列
CREATE SEQUENCE rtss.mqtt_client_id_seq;
-- 创建用户表
CREATE TABLE
rtss.user (