diff --git a/crates/rtss_api/src/apis/draft_data.rs b/crates/rtss_api/src/apis/draft_data.rs index 6285fd9..7175ec0 100644 --- a/crates/rtss_api/src/apis/draft_data.rs +++ b/crates/rtss_api/src/apis/draft_data.rs @@ -1,4 +1,5 @@ -use async_graphql::{Context, InputObject, Object, SimpleObject}; +use async_graphql::dataloader::DataLoader; +use async_graphql::{ComplexObject, Context, InputObject, Object, SimpleObject}; use base64::prelude::*; use base64::Engine; use chrono::NaiveDateTime; @@ -10,7 +11,8 @@ use serde_json::Value; use crate::apis::PageQueryDto; use super::common::{DataOptions, IscsDataOptions}; -use super::PageDto; +use super::release_data::ReleaseDataId; +use super::{PageDto, RtssDbLoader}; #[derive(Default)] pub struct DraftDataQuery; @@ -256,6 +258,7 @@ impl From for rtss_db::DraftDataQuery { } #[derive(Debug, SimpleObject)] +#[graphql(complex)] pub struct DraftDataDto { pub id: i32, pub name: String, @@ -270,6 +273,22 @@ pub struct DraftDataDto { pub updated_at: NaiveDateTime, } +#[ComplexObject] +impl DraftDataDto { + async fn default_release_data_name( + &self, + ctx: &Context<'_>, + ) -> async_graphql::Result> { + if let Some(version_id) = self.default_release_data_id { + let loader = ctx.data_unchecked::>(); + let name = loader.load_one(ReleaseDataId::new(version_id)).await?; + Ok(name) + } else { + Ok(None) + } + } +} + impl From for DraftDataDto { fn from(value: rtss_db::model::DraftDataModel) -> Self { Self { diff --git a/crates/rtss_api/src/apis/release_data.rs b/crates/rtss_api/src/apis/release_data.rs index fba7c68..0aa82fc 100644 --- a/crates/rtss_api/src/apis/release_data.rs +++ b/crates/rtss_api/src/apis/release_data.rs @@ -245,6 +245,38 @@ impl ReleaseDataDto { } } +#[derive(Clone, Copy, Hash, PartialEq, Eq)] +pub struct ReleaseDataId { + pub id: i32, +} + +impl ReleaseDataId { + pub fn new(id: i32) -> Self { + Self { id } + } +} + +impl Loader for RtssDbLoader { + type Value = String; + type Error = Arc; + + async fn load( + &self, + keys: &[ReleaseDataId], + ) -> Result, Self::Error> { + let ids: Vec = keys.iter().map(|k| k.id).collect(); + let rows = self + .db_accessor + .query_release_data_names(ids.as_slice()) + .await?; + let map: HashMap = rows + .into_iter() + .map(|r| (ReleaseDataId { id: r.0 }, r.1)) + .collect(); + Ok(map) + } +} + #[derive(Clone, Copy, Hash, PartialEq, Eq)] pub struct ReleaseDataVersionId { pub id: i32, diff --git a/crates/rtss_db/src/db_access/release_data.rs b/crates/rtss_db/src/db_access/release_data.rs index 329ecc4..52dfc1b 100644 --- a/crates/rtss_db/src/db_access/release_data.rs +++ b/crates/rtss_db/src/db_access/release_data.rs @@ -45,6 +45,11 @@ pub trait ReleaseDataAccessor { &self, release_id: i32, ) -> Result; + /// 根据id列表查询发布数据name + async fn query_release_data_names( + &self, + release_ids: &[i32], + ) -> Result, DbAccessError>; /// 查询发布数据所有版本信息 async fn query_release_data_version_list( &self, @@ -422,6 +427,24 @@ impl ReleaseDataAccessor for RtssDbAccessor { Ok(rd) } + async fn query_release_data_names( + &self, + release_ids: &[i32], + ) -> Result, DbAccessError> { + // 查询发布数据 + let rd_table = ReleaseDataColumn::Table.name(); + let rd_id = ReleaseDataColumn::Id.name(); + let rd_name = ReleaseDataColumn::Name.name(); + let select_columns = format!("{rd_id}, {rd_name}"); + let rd_query_clause = + format!("SELECT {select_columns} FROM {rd_table} WHERE {rd_id} = ANY($1)",); + let rd = sqlx::query_as::<_, (i32, String)>(&rd_query_clause) + .bind(release_ids) + .fetch_all(&self.pool) + .await?; + Ok(rd) + } + async fn query_release_data_version_list( &self, release_id: i32, @@ -841,16 +864,25 @@ mod tests { assert_eq!(page_result.total, 8); println!("分页查询发布数据测试成功"); + // 测试根据id列表查询发布数据name + let release_ids: Vec = page_result.data.iter().map(|d| d.id).collect(); + let names = accessor + .query_release_data_names(release_ids.as_slice()) + .await?; + println!("{:?}", names); + assert_eq!(names.len(), page_result.data.len()); + // 测试根据数据版本id查询descriptions let version_ids: Vec = page_result .data - .into_iter() + .iter() .map(|d| d.used_version_id.unwrap()) .collect(); let description_map = accessor .query_release_data_version_descriptions(version_ids.as_slice()) .await?; println!("{:?}", description_map); + assert_eq!(description_map.len(), page_result.data.len()); Ok(()) }