简易抽奖系统学习
代码如下
mod test;
use std::collections::HashMap;
use std::fmt;
use std::fmt::Formatter;
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use rand::distr::Distribution;
use rand::distr::weighted::WeightedIndex;
use rand::Rng;
use serde::{Serialize, Serializer};
use serde::ser::SerializeStruct;
use thiserror::Error;
fn main() -> Result<(), LotteryError> {
// 创建并发安全的加权抽奖系统
let lottery = ConcurrentLottery::new(WeightRandom);
// 添加参与者
lottery.add_participant(
"user1".to_string(),
Participant {
name: "Alice".to_string(),
email: "alice@example.com".to_string(),
},
3, // 高权重
)?;
lottery.add_participant(
"user2".to_string(),
Participant {
name: "Bob".to_string(),
email: "bob@example.com".to_string(),
},
1, // 低权重
)?;
// 单次抽奖
println!("单次抽奖结果: {:?}", lottery.draw()?);
// 多次抽奖(无放回)
println!("两次抽奖结果: {:?}", lottery.draw_multiple(2)?);
// 更新权重
lottery.update_weight("user1", 5)?;
println!("更新权重后抽奖: {:?}", lottery.draw()?);
Ok(())
// 单次抽奖结果: Participant { name: "Bob", email: "bob@example.com" }
// 两次抽奖结果: [Participant { name: "Alice", email: "alice@example.com" }, Participant { name: "Bob", email: "bob@example.com" }]
// 更新权重后抽奖: Participant { name: "Alice", email: "alice@example.com" }
}
//核心设计思路:
// 参与者管理:支持添加/移除参与者,可配置权重
// 抽奖策略:
// 简单随机抽取
// 加权随机抽取
// 不放回抽取(多次抽奖)
// 并发安全:使用RwLock处理并发访问
// 可扩展性:通过trait支持自定义抽奖算法
// 错误类型
#[derive(Debug, Error)]
pub enum LotteryError {
#[error("没有可用的参与者")]
NoParticipants,
#[error("无效权重值:{0} (必须大于0)")]
InvalidWeight(u32),
#[error("找不到指定参与者")]
ParticipantNotFound,
#[error("权重分布错误,参与者与权重数量不匹配")]
NotEnoughParticipants,
}
/// 抽奖算法 trait
/// 定义统一的抽奖接口,支持不同类型参与者的抽奖
pub trait LotteryAlgorithm<T>: Send + Sync {
/// 执行单次抽奖
/// 参与:
/// - participants: 参与者列表
/// - weights: 对应权重列表
/// 返回: 抽中的参与者或错误
fn draw(&self, participants: &[T], weights: &[u32]) -> Result<T, LotteryError>;
}
/// 简单随机抽奖算法
/// 所有参与者等概率中奖
pub struct SimpleRandom;
impl<T: Clone> LotteryAlgorithm<T> for SimpleRandom {
fn draw(&self, participants: &[T], weights: &[u32]) -> Result<T, LotteryError> {
if participants.is_empty() {
return Err(LotteryError::NoParticipants);
}
let mut rng = rand::rng();
// 生成0到参与者数量之间的随机索引
let index = rng.random_range(0..participants.len());
// 返回克隆的参与者 实例
Ok(participants[index].clone())
}
}
/// 加权随机抽奖算法
/// 参与者权重越高,中奖概率越大
pub struct WeightRandom;
impl<T: Clone> LotteryAlgorithm<T> for WeightRandom {
fn draw(&self, participants: &[T], weights: &[u32]) -> Result<T, LotteryError> {
if participants.is_empty() {
return Err(LotteryError::NoParticipants);
}
// 验证参与者与权重数量匹配
if participants.len() != weights.len() {
return Err(LotteryError::NotEnoughParticipants);
}
// 创建加权索引分布
let dist = WeightedIndex::new(weights).map_err(|_| LotteryError::InvalidWeight(0))?;
let mut rng = rand::rng();
// 根据权重分布随机选择索引
let index = dist.sample(&mut rng);
Ok(participants[index].clone())
}
}
/// 抽奖系统核心结构
/// T: 参与者类型,需实现Clone和PartialEq
// #[derive(Debug, Serialize)] LotterySystem结构体尝试实现Serialize特性时,其中包含的Arc<dyn LotteryAlgorithm<T>>字段无法被序列化引起的。让我们修复这个问题:
pub struct LotterySystem<T> {
// 参与者存储:ID -> 参与者对象
participants: HashMap<String, T>,
// 权重存储: ID -> 权重值
weights: HashMap<String, u32>,
/// 抽奖算法实现,使用Arc实现共享所有权
algorithm: Arc<dyn LotteryAlgorithm<T>>,
}
// 手动实现Debug,避免算法字段的Debug约束
impl<T: fmt::Debug> fmt::Debug for LotterySystem<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("LotterySystem")
.field("participants", &self.participants)
.field("weights", &self.weights)
.field("algorithm", &"[dynamic algorithm]") // 避免尝试调试trait对象
.finish()
}
}
// 手动实现序列化,仅序列化数据和状态
impl<T: Serialize> Serialize for LotterySystem<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer
{
// 创建包含2个字段的结构体
let mut state = serializer.serialize_struct("LotterySystem", 2)?;
state.serialize_field("participants", &self.participants)?;
state.serialize_field("weights", &self.weights)?;
state.end()
}
}
impl<T: Clone + PartialEq> LotterySystem<T> {
/// 创建新的抽奖系统
/// 参数:
/// - algorithm: 抽奖算法实现
pub fn new(algorithm: impl LotteryAlgorithm<T> + 'static) -> Self {
Self {
participants: HashMap::new(),
weights: HashMap::new(),
algorithm: Arc::new(algorithm),
}
}
/// 添加参与者
/// 参数:
/// - id: 参与者唯一ID
/// - participant: 参与者对象
/// - weight: 权重值 (必须 > 0)
pub fn add_participant(&mut self, id: String, participant: T, weight: u32) -> Result<(), LotteryError> {
if weight == 0 {
return Err(LotteryError::InvalidWeight(weight));
}
// 插入参与者和对应权重
self.participants.insert(id.clone(), participant); // 这里id需要 clone, 不然后面就用不了了.
self.weights.insert(id, weight);
Ok(())
}
/// 移除参与者
/// 参数:
/// - id: 要移除的参与者ID
pub fn remove_participant(&mut self, id: &str) -> Result<(), LotteryError> {
// 从参与者和权重映射中同时移除
self.participants
.remove(id)
.ok_or(LotteryError::ParticipantNotFound)?;
self.weights.remove(id);
Ok(())
}
/// 更新参与者权重
/// 参数:
/// - id: 参与者ID
/// - weight: 新的权重值 (必须 > 0)
pub fn update_weight(&mut self, id: &str, weight: u32) -> Result<(), LotteryError> {
if weight == 0 {
return Err(LotteryError::InvalidWeight(weight));
}
// 检查参与者是否存在
if self.weights.contains_key(id) {
self.weights.insert(id.to_string(), weight);
Ok(())
} else {
Err(LotteryError::ParticipantNotFound)
}
}
/// 执行单次抽奖
pub fn draw(&self) -> Result<T, LotteryError> {
// 准备数据:(id、参与者、权重)
let data: Vec<(&String, &T, u32)> = self.participants.iter().map(|(id, p)| (id, p, *self.weights.get(id).unwrap_or(&1)))
.collect();
// 检查是否有参与者
if data.is_empty() {
return Err(LotteryError::NoParticipants);
}
// 分离出参与者和权重列表
let (participants, weights): (Vec<_>, Vec<_>) = data.into_iter().map(|(_, p, w)| (p.clone(), w)).unzip();
self.algorithm.draw(&participants, &weights)
}
/// 执行多次抽奖(无放回)
/// 参数:
/// - count: 要抽取的数量
pub fn draw_multiple(&self, count: usize) -> Result<Vec<T>, LotteryError> {
// 检查是否有足够参与者
if count > self.participants.len() {
return Err(LotteryError::NotEnoughParticipants);
}
// 克隆当前系统状态作为临时工作副本
let mut temp_system = self.clone();
let mut results = Vec::with_capacity(count);
for _ in 0..count {
// 执行单次抽奖
let winner = temp_system.draw()?;
// 查找获奖者id
let winner_id = temp_system.participants.iter()
.find(|(_, p)| **p == winner)
.map(|(id, _)| id.clone())
.ok_or(LotteryError::NotEnoughParticipants)?;
// 移除获奖者,确保后续不会再次中奖
temp_system.remove_participant(&winner_id)?;
results.push(winner);
}
Ok(results)
}
}
impl<T:Clone> Clone for LotterySystem<T> {
fn clone(&self) -> Self {
Self{
participants: self.participants.clone(),
weights: self.weights.clone(),
algorithm: self.algorithm.clone(),
}
}
}
/// 并发安全的抽奖系统包装
/// 使用读写锁保证线程安全
#[derive(Clone)]
pub struct ConcurrentLottery<T>(Arc<RwLock<LotterySystem<T>>>);
impl<T: Clone + PartialEq> ConcurrentLottery<T> {
/// 创建新的并发抽奖系统
pub fn new(algorithm: impl LotteryAlgorithm<T> + 'static) -> Self {
Self(Arc::new(RwLock::new(LotterySystem::new(algorithm))))
}
/// 获取读锁
fn read(&self) -> RwLockReadGuard<LotterySystem<T>> {
self.0.read().unwrap()
}
/// 获取写锁
fn write(&self) -> RwLockWriteGuard<LotterySystem<T>> {
self.0.write().unwrap()
}
/// 添加参与者(线程安全)
pub fn add_participant(&self, id: String, participant: T, weight: u32)
-> Result<(), LotteryError> {
self.write().add_participant(id, participant, weight)
}
/// 移除参与者(线程安全)
pub fn remove_participant(&self, id: &str) -> Result<(), LotteryError> {
self.write().remove_participant(id)
}
/// 更新权重(线程安全)
pub fn update_weight(&self, id: &str, weight: u32) -> Result<(), LotteryError> {
self.write().update_weight(id, weight)
}
/// 执行单次抽奖(线程安全)
pub fn draw(&self) -> Result<T, LotteryError> {
self.read().draw()
}
/// 执行多次抽奖(线程安全)
pub fn draw_multiple(&self, count: usize) -> Result<Vec<T>, LotteryError> {
// 注意:这里使用读锁,但内部会克隆系统状态
self.read().draw_multiple(count)
}
}
/// 示例参与者结构
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct Participant {
pub name: String,
pub email: String,
}
测试用例
#[cfg(test)]
mod tests {
use crate::{ConcurrentLottery, LotterySystem, SimpleRandom, WeightRandom};
use super::*;
/// 测试简单随机抽奖
#[test]
fn test_simple_random() {
let mut lottery = LotterySystem::new(SimpleRandom);
lottery.add_participant("1".to_string(), "Alice".to_string(), 1).unwrap();
lottery.add_participant("2".to_string(), "Bob".to_string(), 1).unwrap();
let winner = lottery.draw().unwrap();
assert!(winner == "Alice" || winner == "Bob");
}
/// 测试加权随机抽奖
#[test]
fn test_weighted_random() {
let mut lottery = LotterySystem::new(WeightRandom);
lottery.add_participant("1".to_string(), "Alice".to_string(), 3).unwrap();
lottery.add_participant("2".to_string(), "Bob".to_string(), 1).unwrap();
// 统计中奖次数(概率上Alice应该多于Bob)
let mut alice_wins = 0;
for _ in 0..1000 {
let winner = lottery.draw().unwrap();
if winner == "Alice" {
alice_wins += 1;
}
}
assert!(alice_wins > 600); // 3:1的权重,Alice应该大约占75%
}
/// 测试多次抽奖(无放回)
#[test]
fn test_draw_multiple() {
let mut lottery = LotterySystem::new(SimpleRandom);
lottery.add_participant("1".to_string(), "A".to_string(), 1).unwrap();
lottery.add_participant("2".to_string(), "B".to_string(), 1).unwrap();
lottery.add_participant("3".to_string(), "C".to_string(), 1).unwrap();
let winners = lottery.draw_multiple(3).unwrap();
assert_eq!(winners.len(), 3);
assert_ne!(winners[0], winners[1]);
assert_ne!(winners[0], winners[2]);
assert_ne!(winners[1], winners[2]);
}
/// 测试并发安全性
#[test]
fn test_concurrent_access() {
use std::thread;
let lottery = ConcurrentLottery::new(SimpleRandom);
lottery.add_participant("1".to_string(), "Alice".to_string(), 1).unwrap();
let handles: Vec<_> = (0..10)
.map(|_| {
let lottery = lottery.clone();
thread::spawn(move || lottery.draw().unwrap())
})
.collect();
for handle in handles {
assert!(handle.join().unwrap() == "Alice");
}
}
}
impl<T: Clone + PartialEq>
泛型约束的组合(类似 Java 的多个接口约束)
Clone:允许复制值(类似 Java 的 clone())
PartialEq:允许部分相等比较(类似 Java 的 equals())
pub trait LotteryAlgorithm
: Send + Sync {
Send:表示类型可以安全地跨线程传递(转移所有权)
Sync:表示类型可以安全地在多线程间共享引用
trait 类似于 Java 的 interface
关键区别: Rust 的 trait 支持关联类型(类似于泛型接口),可以用于运算符重载(如 Add trait 重载 + 运算符)
pub fn new(algorithm: impl LotteryAlgorithm
+ ‘static) -> Self
'static 表示”整个程序运行期间都有效”的生命周期
这里要求传入的 algorithm 必须满足:
要么是拥有所有权的值(不包含任何引用)
要么包含的引用必须是 'static 的(全局有效)
pub fn new(algorithm: impl LotteryAlgorithm<T> + 'static) -> Self {
Self {
participants: HashMap::new(),
weights: HashMap::new(),
algorithm: Box::new(algorithm),
}
}
为什么需要 ‘static
想象抽奖系统的使用场景:
fn main() { let lottery = LotterySystem::new(WeightedRandom);
// 假设这里启动新线程使用抽奖系统
std::thread::spawn(move || {
lottery.draw();
});
}
如果不加 ‘static 约束: 算法可能包含对局部变量的引用 当主线程结束时,这些引用会失效 新线程可能访问到已释放的内存 → 崩溃!
加上 ‘static 后: 编译器保证算法不依赖短生命周期的数据 可以安全地在多线程间传递
Rust 的 ‘static 相当于在编译时防止了这种情况:
fn main() {
let config = LocalConfig::new(); // 局部变量
// 编译错误!因为算法依赖局部变量config
// 不满足 'static 约束
let lottery = LotterySystem::new(WeightedRandom::new(&config));
}
‘static 的两种形式
在 Rust 中,‘static 有两种实现方式:
(1) 拥有所有权的类型
// SimpleRandom 不包含任何引用 → 自动满足 'static
let lottery = LotterySystem::new(SimpleRandom);
(2) 静态引用
// 全局静态配置
static GLOBAL_CONFIG: Config = Config::default();
struct ConfigBasedAlgorithm;
impl<T> LotteryAlgorithm<T> for ConfigBasedAlgorithm {
// ...
}
// 使用全局配置的算法
let lottery = LotterySystem::new(ConfigBasedAlgorithm);
总结:‘static 是 Rust 向编译器做出的承诺:“这个值不依赖任何短期数据,可以安全地活到程序结束”。编译器会严格验证这个承诺,防止悬垂引用。