C++如何自己实现一个shared_ptr
1. shared_ptr介绍
C++中的shared_ptr智能指针是行为类似于指针的类对象,封装了原始指针并提供了自动内存管理的功能(不用手动delete),从而实现了RAII的思想。
shared_ptr 内部是利用引用计数来实现内存的自动管理,每当复制一个 shared_ptr,引用计数会 + 1。当一个 shared_ptr 离开作用域时,引用计数会 - 1。当引用计数为 0 的时候,则 delete 内存。
2. 实现的功能
- 构造函数
- 析构函数
- 拷贝构造函数
- 拷贝赋值运算符
- 移动构造函数
- 移动赋值运算符
- 解引用、箭头运算符
- 引用计数、原始指针、重置指针
3. 具体实现
shared_ptr.h 如下:
#pragma once#include <atomic> // 引入原子操作template <typename T>
class shared_ptr {
private:T* ptr; // 指向管理的对象std::atomic<std::size_t>* ref_count; // 原子引用计数// 释放资源void release() {// P.S. 这里使用 std::memory_order_acq_rel 内存序,保证释放资源的同步if (ref_count && ref_count->fetch_sub(1, std::memory_order_acq_rel) == 1) {delete ptr;delete ref_count;}}public:// 默认构造函数shared_ptr() : ptr(nullptr), ref_count(nullptr) {}// 构造函数// P.S. 这里使用 explicit 关键字,防止隐式类型转换// shared_ptr<int> ptr1 = new int(10); 不允许出现explicit shared_ptr(T* p) : ptr(p), ref_count(p ? new std::atomic<std::size_t>(1) : nullptr) {}// 析构函数~shared_ptr() {release();}// 拷贝构造函数shared_ptr(const shared_ptr<T>& other) : ptr(other.ptr), ref_count(other.ref_count) {if (ref_count) {ref_count->fetch_add(1, std::memory_order_relaxed); // 引用计数增加,不需要强内存序}}// 拷贝赋值运算符shared_ptr<T>& operator=(const shared_ptr<T>& other) {if (this != &other) {release(); // 释放当前资源ptr = other.ptr;ref_count = other.ref_count;if (ref_count) {ref_count->fetch_add(1, std::memory_order_relaxed); // 引用计数增加}}return *this;}// 移动构造函数// P.S. noexcept 关键字表示该函数不会抛出异常。// 标准库中的某些操作(如 std::swap)要求移动操作是 noexcept 的,以确保异常安全。// noexcept 可以帮助编译器生成更高效的代码,因为它不需要为异常处理生成额外的代码。shared_ptr(shared_ptr<T>&& other) noexcept : ptr(other.ptr), ref_count(other.ref_count) {other.ptr = nullptr;other.ref_count = nullptr;}// 移动赋值运算符shared_ptr<T>& operator=(shared_ptr<T>&& other) noexcept {if (this != &other) {release(); // 释放当前资源ptr = other.ptr;ref_count = other.ref_count;other.ptr = nullptr;other.ref_count = nullptr;}return *this;}// 解引用运算符// P.S. const 关键字表示该函数不会修改对象的状态。T& operator*() const {return *ptr;}// 箭头运算符T* operator->() const {return ptr;}// 获取引用计数std::size_t ./use_count() const {return ref_count ? ref_count->load(std::memory_order_acquire) : 0;}// 获取原始指针T* get() const {return ptr;}// 重置指针void reset(T* p = nullptr) {release();ptr = p;ref_count = p ? new std::atomic<std::size_t>(1) : nullptr;}
};
测试代码testExample.cc如下
#include <iostream>
#include "shared_ptr.h"
#include <thread>
#include <vector>
#include <chrono>void test_shared_ptr_thread_safety() {shared_ptr<int> ptr(new int(42));// 创建多个线程,每个线程都增加和减少引用计数const int num_threads = 10;std::vector<std::thread> threads;for (int i = 0; i < num_threads; ++i) {threads.emplace_back([&ptr]() {for (int j = 0; j < 100; ++j) {shared_ptr<int> local_ptr(ptr);// 短暂暂停,增加线程切换的可能性std::this_thread::sleep_for(std::chrono::milliseconds(1));}});}// 等待所有线程完成for (auto& thread : threads) {thread.join();}// 检查引用计数是否正确std::cout << "use_count: " << ptr.use_count() << std::endl;if (ptr.use_count() == 1) {std::cout << "Test passed: shared_ptr is thread-safe!" << std::endl;} else {std::cout << "Test failed: shared_ptr is not thread-safe!" << std::endl;}
}// 测试代码
int main() {shared_ptr<int> ptr1(new int(10));std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl; // 1{shared_ptr<int> ptr2 = ptr1;std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl; // 2std::cout << "ptr2 use_count: " << ptr2.use_count() << std::endl; // 2}std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl; // 1shared_ptr<int> ptr3(new int(20));ptr1 = ptr3;std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl; // 2std::cout << "ptr3 use_count: " << ptr3.use_count() << std::endl; // 2ptr1.reset();std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl; // 0std::cout << "ptr3 use_count: " << ptr3.use_count() << std::endl; // 1test_shared_ptr_thread_safety();return 0;
}
输出结果:
ptr1 use_count: 1
ptr1 use_count: 2
ptr2 use_count: 2
ptr1 use_count: 1
ptr1 use_count: 2
ptr3 use_count: 2
ptr1 use_count: 0
ptr3 use_count: 1
use_count: 1
Test passed: shared_ptr is thread-safe!
补充知识
(1)shared_ptr中的引用计数,可以使用std::atomic来管理
std::atomic<std::size_t>* ref_count; // 原子引用计数
进而达到以下目的:
- 原子操作:对原子变量的操作是不可分割的,意味着在多线程中不会被打断。
- 原子变量:一个变量可以被声明为原子类型(如
std::atomic<int>
),它保证在多线程环境下对该变量的操作是安全的
更具体的,我们会用到atomic的以下方法:
fetch_add() 和 fetch_sub()
fetch_add()
:执行原子加法操作,并返回旧值。fetch_sub()
:执行原子减法操作,并返回旧值。
memory_order(内存序)
std::atomic 的操作可以指定不同的内存顺序(memory ordering),控制不同线程之间的操作顺序。这对于高效并发编程非常重要。常见的内存顺序有:
- memory_order_relaxed:不保证其他线程与该线程的操作顺序。
- memory_order_consume:保证后续操作依赖于当前操作。
- memory_order_acquire:保证所有的读取操作不会在当前操作之前执行。
- memory_order_release:保证所有的写操作不会在当前操作之后执行。
- memory_order_acq_rel:同时保证 acquire 和 release。
- memory_order_seq_cst:最强的内存顺序,保证所有操作的顺序一致。
注:在shared_ptr的实现代码中,如果对内存序这个概念不熟,所有出现它的地方都可以不填,默认使用memory_order_seq_cst
(2)线程安全
如果不用std::atomic来管理引用计数,那么可以用mutex(互斥锁),所有对ref_count的操作都要加上mutex。
(3)构造函数与析构函数
- 1. 默认构造函数 (`Default Constructor`): 用于创建类的对象。如果没有定义任何构造函数,编译器会生成一个默认构造函数。
- 2. 拷贝构造函数 (`Copy Constructor`): 接收同类的另一个对象的引用,用于通过已存在的对象来初始化新对象的成员。
- 3. 拷贝赋值操作符 (`Copy Assignment Operator`): 用于将一个对象的内容复制到另一个已经存在的对象中。
- 4. 移动构造函数 (`Move Constructor`): C++11 引入。如果可能,用于将一个对象的资源“移动”到新创建的对象中,而非复制。
- 5. 移动赋值操作符 (`Move Assignment Operator`): C++11 引入。用于将一个对象的资源转移给另一个已经存在的对象。
- 6. 析构函数 (`Destructor`): 当对象的生命周期结束时被调用,用于执行清理工作,如释放资源。