Skip to content
Go back

lottery_rust_practice

简易抽奖系统学习


代码如下

mod test;

use std::collections::HashMap;
use std::fmt;
use std::fmt::Formatter;
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use rand::distr::Distribution;
use rand::distr::weighted::WeightedIndex;
use rand::Rng;
use serde::{Serialize, Serializer};
use serde::ser::SerializeStruct;
use thiserror::Error;

fn main() -> Result<(), LotteryError> {
    // 创建并发安全的加权抽奖系统
    let lottery = ConcurrentLottery::new(WeightRandom);

    // 添加参与者
    lottery.add_participant(
        "user1".to_string(),
        Participant {
            name: "Alice".to_string(),
            email: "alice@example.com".to_string(),
        },
        3, // 高权重
    )?;

    lottery.add_participant(
        "user2".to_string(),
        Participant {
            name: "Bob".to_string(),
            email: "bob@example.com".to_string(),
        },
        1, // 低权重
    )?;

    // 单次抽奖
    println!("单次抽奖结果: {:?}", lottery.draw()?);

    // 多次抽奖(无放回)
    println!("两次抽奖结果: {:?}", lottery.draw_multiple(2)?);

    // 更新权重
    lottery.update_weight("user1", 5)?;
    println!("更新权重后抽奖: {:?}", lottery.draw()?);

    Ok(())
    // 单次抽奖结果: Participant { name: "Bob", email: "bob@example.com" }
    // 两次抽奖结果: [Participant { name: "Alice", email: "alice@example.com" }, Participant { name: "Bob", email: "bob@example.com" }]
    // 更新权重后抽奖: Participant { name: "Alice", email: "alice@example.com" }
}



//核心设计思路:
// 参与者管理:支持添加/移除参与者,可配置权重
// 抽奖策略:
// 简单随机抽取
// 加权随机抽取
// 不放回抽取(多次抽奖)
// 并发安全:使用RwLock处理并发访问
// 可扩展性:通过trait支持自定义抽奖算法

// 错误类型
#[derive(Debug, Error)]
pub enum LotteryError {
    #[error("没有可用的参与者")]
    NoParticipants,
    #[error("无效权重值:{0} (必须大于0)")]
    InvalidWeight(u32),
    #[error("找不到指定参与者")]
    ParticipantNotFound,
    #[error("权重分布错误,参与者与权重数量不匹配")]
    NotEnoughParticipants,
}

/// 抽奖算法 trait
/// 定义统一的抽奖接口,支持不同类型参与者的抽奖
pub trait LotteryAlgorithm<T>: Send + Sync {
    /// 执行单次抽奖
    /// 参与:
    ///   - participants: 参与者列表
    ///   - weights: 对应权重列表
    /// 返回: 抽中的参与者或错误
    fn draw(&self, participants: &[T], weights: &[u32]) -> Result<T, LotteryError>;
}

/// 简单随机抽奖算法
/// 所有参与者等概率中奖
pub struct SimpleRandom;

impl<T: Clone> LotteryAlgorithm<T> for SimpleRandom {
    fn draw(&self, participants: &[T], weights: &[u32]) -> Result<T, LotteryError> {
        if participants.is_empty() {
            return Err(LotteryError::NoParticipants);
        }

        let mut rng = rand::rng();
        // 生成0到参与者数量之间的随机索引
        let index = rng.random_range(0..participants.len());
        // 返回克隆的参与者 实例
        Ok(participants[index].clone())
    }
}


/// 加权随机抽奖算法
/// 参与者权重越高,中奖概率越大
pub struct WeightRandom;

impl<T: Clone> LotteryAlgorithm<T> for WeightRandom {
    fn draw(&self, participants: &[T], weights: &[u32]) -> Result<T, LotteryError> {
        if participants.is_empty() {
            return Err(LotteryError::NoParticipants);
        }
        // 验证参与者与权重数量匹配
        if participants.len() != weights.len() {
            return Err(LotteryError::NotEnoughParticipants);
        }

        // 创建加权索引分布
        let dist = WeightedIndex::new(weights).map_err(|_| LotteryError::InvalidWeight(0))?;
        let mut rng = rand::rng();
        // 根据权重分布随机选择索引
        let index = dist.sample(&mut rng);
        Ok(participants[index].clone())
    }
}


/// 抽奖系统核心结构
/// T: 参与者类型,需实现Clone和PartialEq
// #[derive(Debug, Serialize)]  LotterySystem结构体尝试实现Serialize特性时,其中包含的Arc<dyn LotteryAlgorithm<T>>字段无法被序列化引起的。让我们修复这个问题:
pub struct LotterySystem<T> {
    // 参与者存储:ID -> 参与者对象
    participants: HashMap<String, T>,

    // 权重存储: ID -> 权重值
    weights: HashMap<String, u32>,

    /// 抽奖算法实现,使用Arc实现共享所有权
    algorithm: Arc<dyn LotteryAlgorithm<T>>,
}

// 手动实现Debug,避免算法字段的Debug约束
impl<T: fmt::Debug> fmt::Debug for LotterySystem<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("LotterySystem")
            .field("participants", &self.participants)
            .field("weights", &self.weights)
            .field("algorithm", &"[dynamic algorithm]") // 避免尝试调试trait对象
            .finish()
    }
}


// 手动实现序列化,仅序列化数据和状态
impl<T: Serialize> Serialize for LotterySystem<T> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer
    {
        // 创建包含2个字段的结构体
        let mut state = serializer.serialize_struct("LotterySystem", 2)?;
        state.serialize_field("participants", &self.participants)?;
        state.serialize_field("weights", &self.weights)?;
        state.end()
    }
}


impl<T: Clone + PartialEq> LotterySystem<T> {
    /// 创建新的抽奖系统
    /// 参数:
    ///   - algorithm: 抽奖算法实现
    pub fn new(algorithm: impl LotteryAlgorithm<T> + 'static) -> Self {
        Self {
            participants: HashMap::new(),
            weights: HashMap::new(),
            algorithm: Arc::new(algorithm),
        }
    }

    /// 添加参与者
    /// 参数:
    ///   - id: 参与者唯一ID
    ///   - participant: 参与者对象
    ///   - weight: 权重值 (必须 > 0)
    pub fn add_participant(&mut self, id: String, participant: T, weight: u32) -> Result<(), LotteryError> {
        if weight == 0 {
            return Err(LotteryError::InvalidWeight(weight));
        }

        // 插入参与者和对应权重
        self.participants.insert(id.clone(), participant); // 这里id需要 clone, 不然后面就用不了了.
        self.weights.insert(id, weight);
        Ok(())
    }

    /// 移除参与者
    /// 参数:
    ///   - id: 要移除的参与者ID
    pub fn remove_participant(&mut self, id: &str) -> Result<(), LotteryError> {
        // 从参与者和权重映射中同时移除
        self.participants
            .remove(id)
            .ok_or(LotteryError::ParticipantNotFound)?;
        self.weights.remove(id);
        Ok(())
    }

    /// 更新参与者权重
    /// 参数:
    ///   - id: 参与者ID
    ///   - weight: 新的权重值 (必须 > 0)
    pub fn update_weight(&mut self, id: &str, weight: u32) -> Result<(), LotteryError> {
        if weight == 0 {
            return Err(LotteryError::InvalidWeight(weight));
        }

        // 检查参与者是否存在
        if self.weights.contains_key(id) {
            self.weights.insert(id.to_string(), weight);
            Ok(())
        } else {
            Err(LotteryError::ParticipantNotFound)
        }
    }


    /// 执行单次抽奖
    pub fn draw(&self) -> Result<T, LotteryError> {
        // 准备数据:(id、参与者、权重)
        let data: Vec<(&String, &T, u32)> = self.participants.iter().map(|(id, p)| (id, p, *self.weights.get(id).unwrap_or(&1)))
            .collect();
        // 检查是否有参与者
        if data.is_empty() {
            return Err(LotteryError::NoParticipants);
        }
        // 分离出参与者和权重列表
        let (participants, weights): (Vec<_>, Vec<_>) = data.into_iter().map(|(_, p, w)| (p.clone(), w)).unzip();
        self.algorithm.draw(&participants, &weights)
    }


    /// 执行多次抽奖(无放回)
    /// 参数:
    ///   - count: 要抽取的数量
    pub fn draw_multiple(&self, count: usize) -> Result<Vec<T>, LotteryError> {
        // 检查是否有足够参与者
        if count > self.participants.len() {
            return Err(LotteryError::NotEnoughParticipants);
        }
        // 克隆当前系统状态作为临时工作副本
        let mut temp_system = self.clone();
        let mut results = Vec::with_capacity(count);

        for _ in 0..count {
            // 执行单次抽奖
            let winner = temp_system.draw()?;

            // 查找获奖者id
            let winner_id = temp_system.participants.iter()
                .find(|(_, p)| **p == winner)
                .map(|(id, _)| id.clone())
                .ok_or(LotteryError::NotEnoughParticipants)?;

            // 移除获奖者,确保后续不会再次中奖
            temp_system.remove_participant(&winner_id)?;
            results.push(winner);
        }
        Ok(results)
    }
}

impl<T:Clone> Clone for LotterySystem<T> {
    fn clone(&self) -> Self {
        Self{
            participants: self.participants.clone(),
            weights: self.weights.clone(),
            algorithm: self.algorithm.clone(),
        }
    }
}

/// 并发安全的抽奖系统包装
/// 使用读写锁保证线程安全
#[derive(Clone)]
pub struct ConcurrentLottery<T>(Arc<RwLock<LotterySystem<T>>>);


impl<T: Clone + PartialEq> ConcurrentLottery<T>  {

    /// 创建新的并发抽奖系统
    pub fn new(algorithm: impl LotteryAlgorithm<T> + 'static) -> Self {
        Self(Arc::new(RwLock::new(LotterySystem::new(algorithm))))
    }

    /// 获取读锁
    fn read(&self) -> RwLockReadGuard<LotterySystem<T>> {
        self.0.read().unwrap()
    }

    /// 获取写锁
    fn write(&self) -> RwLockWriteGuard<LotterySystem<T>> {
        self.0.write().unwrap()
    }

    /// 添加参与者(线程安全)
    pub fn add_participant(&self, id: String, participant: T, weight: u32)
        -> Result<(), LotteryError> {
        self.write().add_participant(id, participant, weight)
    }


    /// 移除参与者(线程安全)
    pub fn remove_participant(&self, id: &str) -> Result<(), LotteryError> {
        self.write().remove_participant(id)
    }

    /// 更新权重(线程安全)
    pub fn update_weight(&self, id: &str, weight: u32) -> Result<(), LotteryError> {
        self.write().update_weight(id, weight)
    }

    /// 执行单次抽奖(线程安全)
    pub fn draw(&self) -> Result<T, LotteryError> {
        self.read().draw()
    }

    /// 执行多次抽奖(线程安全)
    pub fn draw_multiple(&self, count: usize) -> Result<Vec<T>, LotteryError> {
        // 注意:这里使用读锁,但内部会克隆系统状态
        self.read().draw_multiple(count)
    }

}

/// 示例参与者结构
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct Participant {
    pub name: String,
    pub email: String,
}


测试用例

#[cfg(test)]
mod tests {
    use crate::{ConcurrentLottery, LotterySystem, SimpleRandom, WeightRandom};
    use super::*;

    /// 测试简单随机抽奖
    #[test]
    fn test_simple_random() {
        let mut lottery = LotterySystem::new(SimpleRandom);
        lottery.add_participant("1".to_string(), "Alice".to_string(), 1).unwrap();
        lottery.add_participant("2".to_string(), "Bob".to_string(), 1).unwrap();

        let winner = lottery.draw().unwrap();
        assert!(winner == "Alice" || winner == "Bob");
    }

    /// 测试加权随机抽奖
    #[test]
    fn test_weighted_random() {
        let mut lottery = LotterySystem::new(WeightRandom);
        lottery.add_participant("1".to_string(), "Alice".to_string(), 3).unwrap();
        lottery.add_participant("2".to_string(), "Bob".to_string(), 1).unwrap();

        // 统计中奖次数(概率上Alice应该多于Bob)
        let mut alice_wins = 0;
        for _ in 0..1000 {
            let winner = lottery.draw().unwrap();
            if winner == "Alice" {
                alice_wins += 1;
            }
        }

        assert!(alice_wins > 600); // 3:1的权重,Alice应该大约占75%
    }

    /// 测试多次抽奖(无放回)
    #[test]
    fn test_draw_multiple() {
        let mut lottery = LotterySystem::new(SimpleRandom);
        lottery.add_participant("1".to_string(), "A".to_string(), 1).unwrap();
        lottery.add_participant("2".to_string(), "B".to_string(), 1).unwrap();
        lottery.add_participant("3".to_string(), "C".to_string(), 1).unwrap();

        let winners = lottery.draw_multiple(3).unwrap();
        assert_eq!(winners.len(), 3);
        assert_ne!(winners[0], winners[1]);
        assert_ne!(winners[0], winners[2]);
        assert_ne!(winners[1], winners[2]);
    }

    /// 测试并发安全性
    #[test]
    fn test_concurrent_access() {
        use std::thread;

        let lottery = ConcurrentLottery::new(SimpleRandom);
        lottery.add_participant("1".to_string(), "Alice".to_string(), 1).unwrap();

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let lottery = lottery.clone();
                thread::spawn(move || lottery.draw().unwrap())
            })
            .collect();

        for handle in handles {
            assert!(handle.join().unwrap() == "Alice");
        }
    }
}

impl<T: Clone + PartialEq>

泛型约束的组合(类似 Java 的多个接口约束) Clone:允许复制值(类似 Java 的 clone()) PartialEq:允许部分相等比较(类似 Java 的 equals())


pub trait LotteryAlgorithm: Send + Sync {

Send:表示类型可以安全地跨线程传递(转移所有权) Sync:表示类型可以安全地在多线程间共享引用


trait 类似于 Java 的 interface

关键区别: Rust 的 trait 支持关联类型(类似于泛型接口),可以用于运算符重载(如 Add trait 重载 + 运算符)



pub fn new(algorithm: impl LotteryAlgorithm + ‘static) -> Self

'static 表示”整个程序运行期间都有效”的生命周期

这里要求传入的 algorithm 必须满足:
要么是拥有所有权的值(不包含任何引用)
要么包含的引用必须是 'static 的(全局有效)

pub fn new(algorithm: impl LotteryAlgorithm<T> + 'static) -> Self {
    Self {
        participants: HashMap::new(),
        weights: HashMap::new(),
        algorithm: Box::new(algorithm),
    }
}

为什么需要 ‘static

想象抽奖系统的使用场景:

fn main() { let lottery = LotterySystem::new(WeightedRandom);

// 假设这里启动新线程使用抽奖系统
std::thread::spawn(move || {
    lottery.draw();
});

}


如果不加 ‘static 约束: 算法可能包含对局部变量的引用 当主线程结束时,这些引用会失效 新线程可能访问到已释放的内存 → 崩溃!


加上 ‘static 后: 编译器保证算法不依赖短生命周期的数据 可以安全地在多线程间传递


Rust 的 ‘static 相当于在编译时防止了这种情况:

fn main() {
    let config = LocalConfig::new(); // 局部变量
    
    // 编译错误!因为算法依赖局部变量config
    // 不满足 'static 约束
    let lottery = LotterySystem::new(WeightedRandom::new(&config));
}

‘static 的两种形式

在 Rust 中,‘static 有两种实现方式:

(1) 拥有所有权的类型

// SimpleRandom 不包含任何引用 → 自动满足 'static
let lottery = LotterySystem::new(SimpleRandom);

(2) 静态引用

// 全局静态配置
static GLOBAL_CONFIG: Config = Config::default();

struct ConfigBasedAlgorithm;

impl<T> LotteryAlgorithm<T> for ConfigBasedAlgorithm {
    // ...
}

// 使用全局配置的算法
let lottery = LotterySystem::new(ConfigBasedAlgorithm);


总结:‘static 是 Rust 向编译器做出的承诺:“这个值不依赖任何短期数据,可以安全地活到程序结束”。编译器会严格验证这个承诺,防止悬垂引用。


Share this post on: