实现引用计数线程安全的shared_ptr
c++11引入了三个智能指针,用来自动管理内存,使用智能指针可以有效地减少内存泄漏。
其中,shared_ptr是共享智能指针,可以被多次拷贝,拷贝时其内部的引用计数+1,被销毁时引用计数-1,如果引用计数为0,那么释放其所管理的资源
在线程安全上,shared_ptr具有如下特点:
- shared_ptr的引用计数是线程安全的
- 修改shared_ptr不是线程安全的
- 读写shared_ptr管理的数据不是线程安全的
具体可以参考:https://zhuanlan.zhihu.com/p/664993437
在网上找到的shared_ptr的手动实现都是线程不安全的,那么如何实现一个引用计数线程安全的shared_ptr呢?
参考:从零简单实现一个线程安全的C++共享指针(shared_ptr)-CSDN博客,本文在这篇博客的基础上增加了验证代码,并指出原有实现一个潜在的bug
#include <iostream>
#include <atomic>
#include <mutex>
#include <thread>
#include <vector>using namespace std;#define N 10000class Counter
{
public:Counter() { count = 1; }void add() {lock_guard<std::mutex> lk(mutex_);count++; }void sub() {lock_guard<std::mutex> lk(mutex_);count--;}int get() {lock_guard<std::mutex> lk(mutex_);return count; }private:int count;std::mutex mutex_;
};template <typename T>
class Sp
{
public:Sp(); //默认构造函数~Sp(); //析构函数Sp(T *ptr); //参数构造函数Sp(const Sp &obj); //复制构造函数Sp &operator=(const Sp &obj); //重载=T *get(); //得到共享指针指向的类int getcount(); //得到引用计数器
private:T *my_ptr; //共享指针所指向的对象Counter* counter; //引用计数器void clear(); //清理函数
};//默认构造函数,参数为空,构造一个引用计数器
template<typename T>
Sp<T>::Sp()
{my_ptr = nullptr;counter = new Counter();
}//复制构造函数,新的共享指针指向旧的共享指针所指对象
template<typename T>
Sp<T>::Sp(const Sp &obj)
{//将所指对象也变为目标所指的对象my_ptr = obj.my_ptr;//获取引用计数器,使得两个共享指针用一个引用计数器counter = obj.counter;//使这个对象的引用计数器+1counter->add();
};//重载=
template<typename T>
Sp<T> &Sp<T>::operator=(const Sp&obj)
{//清理当前所引用对象和引用计数器clear();//指向新的对象,并获取目标对象的引用计数器my_ptr = obj.my_ptr;counter = obj.counter;//引用计数器+1counter->add();//返回自己return *this;
}//创建一个共享指针指向目标类,构造一个新的引用计数器
template<typename T>
Sp<T>::Sp(T *ptr)
{my_ptr = ptr;counter = new Counter();
}//析构函数,出作用域的时候,调用清理函数
template<typename T>
Sp<T>:: ~Sp()
{clear();
}//清理函数,调用时将引用计数器的值减1,若减为0,清理指向的对象内存区域
template<typename T>
void Sp<T>::clear()
{//引用计数器-1counter->sub();//如果引用计数器变为0,清理对象if(0 == counter->get()){// 这里有个bug,如果在此间隙处,有另外一个地方执行了share ptr的copy操作,则会crashif(my_ptr){delete my_ptr;}delete counter;}
}//当前共享指针指向的对象,被几个共享指针所引用
template<typename T>
int Sp<T>::getcount()
{return counter->get();
};class A{
public:A(){ cout<<"A construct!"<<endl; };~A() { cout<<"A destruct!"<<endl; };
};Sp<A> sp(new A);
std::vector<Sp<A>> vec1(N);
std::vector<Sp<A>> vec2(N);Sp<A> sp1(new A);
Sp<A> sp2(new A);
Sp<A> sp3(new A);void thread_func1() {for(int i = 0; i < N; i++) {vec1[i] = sp;}
}void thread_func2() {for(int i = 0; i < N; i++) {vec2[i] = sp;}
}void test_crash_func1() {sp1 = sp2;
}void test_crash_func2() {sp3 = sp1;
}void test_crash() {for(int i = 0; i < 10 * N; i++) {std::thread t1(test_crash_func1);std::thread t2(test_crash_func2);t1.join();t2.join();}
}int main()
{std::thread t1(thread_func1);std::thread t2(thread_func2);t1.join();t2.join();std::cout<<"the count is:"<<sp.getcount()<<std::endl;test_crash();
}
按理说调用test_crash应该会导致crash才对,但是不知道为什么没有crash
TODO:使用原子操作实现,对比性能