Home

LinkWorld Algorithm Blog

21 Jan 2018

使用神经网络的Spam分类器

人工神经网络在机器学习领域得到了非常广泛的应用,一个神经网络由多个层组成。有输入层,输入的参数为$x$,输入的数据通过多个隐含层,最终在最后一层输出层得到一个输出。

现在给定一个非常简单的神经网络,只有$N$个隐含层,每一个层只有一个神经元,每个神经元有两个与它关联的值:$w_i$和$b_i$,分别表示权重和偏差,如果你给一个神经元一个输入$x$,它会产生输出$(w_i*x)+b_i$。

因此,一个输入$x$在神经网络中经过如下的方式进行变换。第一层隐含神经元接收输入$x$并产生输出,而这个输出作为第二个神经元的输入,然后,第二个神经元接收输入$y$并产生输出$z=w_2*y+b_2$,以此类推直到第$N$个神经元。

现有一些用户,我们想知道他们是否在发送垃圾邮件(判断是否是spammer)。他们每个人都有一个范围从$minX$到$maxX$的整数ID。所以我们将每个人的ID作为输入放进神经网络中,最终得到一个结果,如果结果是偶数,那么这个用户就不是spammer,否则是spammer。你必须计算出spammers和non-spammers的数量。

输入

输入的第一行包含一个整数$T$,表示测试情况的数量,对于每一个测试情况的输入如下:

每个测试情况的第一行包含三个空格分隔开的整数$N$、$minX$、$maxX$。

接下来的每$N$行包含两个空格分隔开的整数$w_i$和$b_i$,分别表示第$i$个神经元的权重和偏差。

输出

对于每一个测试情况,输出两个空格分隔开的整数,分别表示non-spammers和spammers的数量。

限制

  • $1 \leq T \leq 10$
  • $1 \leq N \leq 10^5$
  • $1 \leq minX \leq maxX \leq 10^9$
  • $1 \leq wi, bi \leq 10^9$

例子

Input
3
1 1 2
1 2
2 1 4
2 4
2 3
3 2 1000000000
2 4
2 2
5 4

Output
1 1
0 4
999999999 0

解释

例子一:这里有一个简单的神经元,其权重为1,偏差为2。我们来检查一下输出:当$x=1$时,输出为。当$x=2$时,输出为。可见其中一个是偶数另一个是奇数,所以这里有一个spammer和一个non-spammer。

例子二:这里有两个权重为2的神经元,但是他们的偏差分别为4和3。

可见这里所有的输出都是奇数,所以这里没有non-spammers,有4个spammers。

解法

这里我们使用C++语言实现。

第一行输入是测试情况的数量,像往常一样的写法:

#include <iostream>

int main()
{
    int T;
    cin>>T:
    while(T--)
    {
        //code
    }
    return 0;
}

接着我们需要获取每个测试情况的权重和偏差。可以使用数组存储权重和偏差,为了安全和方便,为了内存的世界和平,我们使用vector作为数组。vector相对于普通数组的特点是可变长度。

同时,我们发现限制中的数据范围很大,所以用int做数据类型肯定是不够的,为了方便,我们统统把所有的数据用long long作为它的类型。但是long long写起来太长了,不如用typedef将其用ll表示。

typedef long long int ll;

于是,代码如下:

ll N,minX,maxX;
cin>>N>>minX>>maxX;
vector<ll> wl,bl;//定义两个数组,尖括号内表示数组的元素类型
for(ll i = 0;i<N;++i)
{
    ll w,b;
    cin>>w>>b;
    wl.push_back(w);//往wl数组后添加一个数字w
    bl.push_back(b);//往bl数组后添加一个数字b
}

顺便写一个函数,用于神经网络的计算,它的参数是一个输入值,和权重和偏差数组,输出输出值。

ll compute(ll input,ll depth,vector<ll> wl,vector<ll> bl)
{
    ll output = input;
    for(ll i = 0;i<depth;++i)
        output = output * wl[i] + bl[i];
    return output%2;
}

看起来很不错,接下来只要逐个计算$minX$和$maxX$范围内的所有数据的spammers数量即可:

ll notspammer = 0,spammer = 0;
for(ll i = minX;i<=maxX;i++)
{
    if(compute(i,wl,bl)%2==0)//如果结果可被2整除,也就是说结果是偶数
        notspammer++;
    else
        spammer++;
}

很好,看起来一切都写好了呢。运行一下上面提供的例子,然后发现程序卡在了最后一个测试情况上。定睛一看,卧槽,这个数据竟然有十多亿那么大,按照电脑的速度这得算到何年何月。

这时,突然想到,初中老师告诉我们奇数和偶数运算的规律,突然发现我们的代码根本不用计算每一个数据在神经网络中的输出,按照运算规律,所有的奇数经过这一神经网络出来的结果的奇偶性都是相同的,反之亦然。

最后只要计算$minX$和$maxX$中有多少个奇数和偶数,然后就能得到结果啦。

所以我们把代码改写如下:

#include <iostream>
#include <vector>

typedef long long int ll;

using namespace std;

ll compute(ll input,ll depth,vector<ll> wl,vector<ll> bl)
{
    ll output = input;
    for(ll i = 0;i<depth;++i)
        output = output * wl[i] + bl[i];
    return output%2;
}

int main()
{
    int testCase;
    cin>>testCase;
    while(testCase--)
    {
        ll T,minX,maxX;
        cin>>T>>minX>>maxX;
        vector<ll> wl,bl;
        for(ll i = 0;i<T;++i)
        {
            ll w,b;
            cin>>w>>b;
            wl.push_back(w);
            bl.push_back(b);
        }
        ll notspammer = 0 ,spammer = 0;
        ll evenCount = (maxX - minX + 1)/2;
        ll oddCount = maxX - minX + 1 - evenCount;
        if(compute(0,T,wl,bl)==0)
            notspammer+=evenCount;
        else
            spammer+=evenCount;
        if(compute(1,T,wl,bl)==0)
            notspammer+=oddCount;
        else
            spammer+=oddCount;
        cout<<notspammer<<" "<<spammer<<endl;
    }
    return 0;
}

用上面的代码例子就通过啦。

使用神经网络的Spam分类器

Til next time,
LinkWorld at 21:00

scribble