Newsgroups: git.cc.talk.c++
Subject: More functors
Date: 3 Feb 2000 19:37:51 GMT

Sinth <sinth@shaftnet.org> once said:
>Thanks, that was a very good read!  I can't wait for Chapter 2. =)

So here's some more on functors.  An example--suppose we have a list of
grades, and we want to find out how many are As.  We could do this:

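   #include <algorithm>
   #include <functional>
   #include <iostream>
   #include <vector>
   using namespace std;
   
   int main() {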
      vector<int> v;
      v.push_back( 94 );
      v.push_back( 96 );
      v.push_back( 82 );
      v.push_back( 73 );
      v.push_back( 90 );
      v.push_back( 86 );
   
      cout << "Number of As is "
           << count_if( v.begin(), v.end(), bind2nd(greater_equal<int>(),90) )
           << endl;
   
      return 0;
   }

where the notable expression is

   count_if( v.begin(), v.end(), bind2nd(greater_equal<int>(),90) )

Let's break this up.

count_if() is an algorithm in the STL which takes a range of iterators
(as does practically every algorithm in the library) and a UnaryPredicate,
and counts how many elements match the Predicate.

A Predicate is simply a functor which returns a bool.  A UnaryPredicate
takes _one_ argument and returns a bool.
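
To make that concrete, here is a minimal sketch of a hand-written
UnaryPredicate.  The names is_an_a and count_as are made up for
illustration (they are not part of the STL), and I'm assuming an A means
90 or above, as in the example at the top.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical UnaryPredicate: takes one argument, returns a bool.
struct is_an_a {
   bool operator()( int grade ) const { return grade >= 90; }
};

// Hypothetical helper: count the grades that satisfy the predicate.
std::ptrdiff_t count_as( const std::vector<int>& grades ) {
   return std::count_if( grades.begin(), grades.end(), is_an_a() );
}
```

Calling count_as() on the grade vector from the first example should
give the same answer as the count_if() call there.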

Most of the operators in C++ have named functors associated with them.
For example, here's the actual implementation of greater_equal:

   template <class T>
   struct greater_equal : public binary_function<T, T, bool> {
       bool operator()(const T& x, const T& y) const { return x >= y; }
   };

(Ignore the inheritance for now.)  This is simply a template class,
which is a functor (defines operator()) of two arguments.  Like many
functors, this one has no state--there are no instance variables.  This
is, of course, a BinaryPredicate, as it takes _two_ arguments.
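
Since a functor is just an object with operator(), you can also call
greater_equal directly, outside of any algorithm.  A small sketch (the
wrapper name is_at_least is mine, purely for illustration):

```cpp
#include <functional>

// Hypothetical wrapper: calls the library functor by hand.
bool is_at_least( int x, int y ) {
   std::greater_equal<int> ge;   // default-construct the functor
   return ge( x, y );            // same result as x >= y
}
```

So is_at_least( 94, 90 ) is true and is_at_least( 82, 90 ) is false,
exactly as if we had written the >= ourselves.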


Back to our problem though: we want all the grades >=90.  So, we want to
take the BinaryPredicate for greater_equal, and bind one of its
arguments to a value, to turn it into a UnaryPredicate.  The function
bind2nd does just this: given a binary functor and an argument, it
returns a unary functor.  That is, something like

   (defun bind2nd (f y)
      (lambda (x) (funcall f x y)))

So as a result, if we used to call

   f(x,y)

we can permanently bind y so that now we have a new function "g" which we
can call as

   g(x)     // equivalent of f(x,y)
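
If you're curious what such a "g" might look like in C++, here is a
rough sketch.  This is _not_ the library's implementation: the names
bound_second, my_bind2nd and count_at_least are hypothetical, and I've
simplified things by assuming f is a predicate (so g always returns
bool) and that x and y have the same type.

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for the object a binder returns: it remembers
// f and y, and calling g(x) forwards to f(x,y).
template <class BinaryPredicate, class T>
class bound_second {
   BinaryPredicate f;
   T y;
public:
   bound_second( const BinaryPredicate& f_, const T& y_ ) : f(f_), y(y_) {}
   bool operator()( const T& x ) const { return f( x, y ); }
};

// Hypothetical helper playing the role of bind2nd (predicates only).
template <class BinaryPredicate, class T>
bound_second<BinaryPredicate, T>
my_bind2nd( const BinaryPredicate& f, const T& y ) {
   return bound_second<BinaryPredicate, T>( f, y );
}

// Count the elements of v that are >= threshold.
std::ptrdiff_t count_at_least( const std::vector<int>& v, int threshold ) {
   return std::count_if( v.begin(), v.end(),
                         my_bind2nd( std::greater_equal<int>(), threshold ) );
}
```

The point of the sketch is only that the bound value y lives inside the
returned object as a data member, so this functor (unlike greater_equal)
does carry state.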

Now we have all the tools to understand the statement

   count_if( v.begin(), v.end(), bind2nd(greater_equal<int>(),90) )
                                         1------------------1
                                 2------------------------------2
   3--------------------------------------------------------------3

The expression underlined as "1---1" creates a BinaryPredicate which
returns true if its first argument is >= its second argument.

The expression underlined as "2---2" binds the second argument to the
value 90, returning a UnaryPredicate which returns true if its argument
is >= 90.

The expression underlined as "3---3" passes this predicate to count_if,
which simply walks down the vector, testing each value against the
predicate, and incrementing the counter if the predicate returns true.

Indeed, the implementation of count_if() is almost readable to us
already:

   template <class InputIterator, class Predicate>
   typename iterator_traits<InputIterator>::difference_type
   count_if(InputIterator first, InputIterator last, Predicate pred) {
     typename iterator_traits<InputIterator>::difference_type n = 0;
     for ( ; first != last; ++first)
       if (pred(*first))
         ++n;
     return n;
   }

There's a bit of "noise" in there that we don't understand yet, but the
core of it is simple:

     n = 0;                            // start with count of 0
     for ( ; first != last; ++first)   // walk down the vector
       if (pred(*first))               // if this element matches predicate
         ++n;                          // increment the counter
     return n;                        

Neat.


Let's do another functor example now.  Suppose the grades from our
previous example are curved; each assignment is curved a few points,
and we have the curves stored in another vector.  We want to compute
the curved grades.

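   #include <algorithm>
   #include <functional>
   #include <iostream>
   #include <iterator>
   #include <vector>
   using namespace std;
   
   int main() {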
      vector<int> v;         // the grades
      v.push_back( 94 );
      v.push_back( 96 );
      v.push_back( 82 );
      v.push_back( 73 );
      v.push_back( 90 );
      v.push_back( 86 );
   
      vector<int> c;         // the number of points added by the curve
      c.push_back( 0 );
      c.push_back( 1 );
      c.push_back( 4 );
      c.push_back( 5 );
      c.push_back( 0 );
      c.push_back( 0 );
   
      // add the curve to the grades
      transform( v.begin(), v.end(), c.begin(), v.begin(), plus<int>() );
      // print out the curved grades
      copy( v.begin(), v.end(), ostream_iterator<int>( cout, "\n" ) );
   
      return 0;
   }

Clearly we could have walked down the vectors ourselves in a loop, and
said something like

   v[i] += c[i];

But instead, we can do it applicatively, much like a MAPCAR in LISP.
The notable line is

   transform( v.begin(), v.end(), c.begin(), v.begin(), plus<int>() );

The description of that function's signature is

   OutputIterator transform(InputIterator1 first1, InputIterator1 last1,
                            InputIterator2 first2, OutputIterator result,
                            BinaryFunction binary_op);

Abstractly, transform() takes a binary functor and two lists of values,
and creates a new list using the functor.  That is, if we pass transform
the lists [a,b,c] and [1,2,3] and a binary function f(x,y), then it
produces the list [ f(a,1), f(b,2), f(c,3) ].
 
More concretely, transform takes
 - the iterators which define the range of the first list, as arguments
   "first1" and "last1" (which would be from 'a' to 'c' in my example)
 - the iterator which defines the start of the range of the second list,
   as an argument named "first2" (which would be '1' in my example).
   Note: we don't need a "last2", as transform assumes the two lists are
   the same size (last2 is implicitly first2+(last1-first1))
 - an iterator telling us where to write the result ("result")
 - the function ("binary_op")

So, in our call

   transform( v.begin(), v.end(), c.begin(), v.begin(), plus<int>() );

we say that the first list is the entire vector v (from v.begin() to
v.end()), the second list starts at c.begin() (we don't need to specify
the end), the result should be stored starting at v.begin() (we'll
overwrite the before-the-curve grades), and the operator to apply to the
lists is +.  Voila.
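
Because transform never sees the end of the second list, it's on us to
make sure that list is long enough.  Here's a hedged sketch of a wrapper
that checks first and writes the result into a fresh vector rather than
overwriting the grades.  The name add_curve and the choice to return an
empty vector on a size mismatch are my own assumptions, not anything the
library prescribes.

```cpp
#include <algorithm>
#include <functional>
#include <iterator>
#include <vector>

// Hypothetical helper: returns grades[i] + curve[i] for every i.
std::vector<int> add_curve( const std::vector<int>& grades,
                            const std::vector<int>& curve ) {
   std::vector<int> curved;
   // transform assumes the second list is at least as long as the first
   if ( curve.size() < grades.size() )
      return curved;   // empty result: not enough curve values
   // back_inserter appends each result to curved as it is produced
   std::transform( grades.begin(), grades.end(), curve.begin(),
                   std::back_inserter( curved ), std::plus<int>() );
   return curved;
}
```

Here the "result" iterator is a back_inserter instead of v.begin(), so
the before-the-curve grades are left alone.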

(I didn't say anything about

   copy( v.begin(), v.end(), ostream_iterator<int>( cout, "\n" ) );

This is just a clever way to print all of the elements of a vector.)
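
For the skeptical, here is a small sketch comparing that copy() call
with the loop you'd write by hand.  To make the output easy to inspect I
write into a string stream instead of cout; the function names
print_with_copy and print_with_loop are hypothetical.

```cpp
#include <algorithm>
#include <iterator>
#include <sstream>
#include <string>
#include <vector>

// The copy() trick: each element is written, followed by "\n".
std::string print_with_copy( const std::vector<int>& v ) {
   std::ostringstream out;
   std::copy( v.begin(), v.end(), std::ostream_iterator<int>( out, "\n" ) );
   return out.str();
}

// The hand-written loop that produces the same text.
std::string print_with_loop( const std::vector<int>& v ) {
   std::ostringstream out;
   for ( std::vector<int>::const_iterator i = v.begin(); i != v.end(); ++i )
      out << *i << "\n";
   return out.str();
}
```

Both functions should produce identical text for any vector of ints,
including the trailing newline after the last element.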

By the way, the implementation of transform is trivial:

   template <class InputIterator1, class InputIterator2, class OutputIterator,
             class BinaryOperation>
   OutputIterator transform(InputIterator1 first1, InputIterator1 last1,
                            InputIterator2 first2, OutputIterator result,
                            BinaryOperation binary_op) {
     for ( ; first1 != last1; ++first1, ++first2, ++result)
       *result = binary_op(*first1, *first2);
     return result;
   }

Neat neat.


In the examples before, we used predefined functors like "plus".  What
if we want to make our own?

Suppose we desperately need to compute a crazy function like 3x^2+5x+1
for each x in the set {1,2,3,4,5}.  It's easy!

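   #include <algorithm>
   #include <functional>
   #include <iostream>
   #include <iterator>
   #include <vector>
   using namespace std;
   
   struct crazy_function : public unary_function<int,int> {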
      int operator() ( int x ) const {
         return 3*x*x + 5*x + 1;
      }
   };
   
   int main() {
      vector<int> v;
      v.push_back( 1 );
      v.push_back( 2 );
      v.push_back( 3 );
      v.push_back( 4 );
      v.push_back( 5 );
   
      transform( v.begin(), v.end(), v.begin(), crazy_function() );
      copy( v.begin(), v.end(), ostream_iterator<int>( cout, "\n" ) );
   
      return 0;
   }
   
First, we define a functor for our crazy function:

   struct crazy_function : public unary_function<int,int> {
      int operator() ( int x ) const {
         return 3*x*x + 5*x + 1;
      }
   };

The only non-obvious thing here is the inheritance; that will be
described in the next "installment".  :)  In any case, this gives us a
functor which takes in x and returns 3x^2+5x+1.  Now we just apply that
to our list:

   transform( v.begin(), v.end(), v.begin(), crazy_function() );

Note: this is _not_ the same transform() we used in the last problem.
This one only has four arguments (not 5).  It is exactly like MAPCAR in
LISP.  It takes a _single_ list and a _unary_ function, and applies the
function to each element in the list.  The signature

   OutputIterator transform(InputIterator first, InputIterator last,
                            OutputIterator result, UnaryFunction op);

says
 - the range [first,last) is the list of elements
 - "result" is where to start writing our results to
 - "op" is the function to be applied

Note that in the call

   transform( v.begin(), v.end(), v.begin(), crazy_function() );
                                             ^^^^^^^^^^^^^^^^
we must construct an instance of crazy_function, as it is a class, and
functors are objects.  The empty parens just call the default
constructor (this class has no data members to initialize).  We saw this
earlier with classes like "plus" but I didn't point it out then.
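
Those parens matter more once a functor has data members.  As a sketch
(the names quadratic and apply_quadratic are hypothetical, and I've left
out the inheritance), here is a version of the crazy function whose
coefficients are passed to the constructor instead of being hard-coded:

```cpp
#include <algorithm>
#include <vector>

// Hypothetical functor with state: computes a*x^2 + b*x + c.
struct quadratic {
   int a, b, c;
   quadratic( int a_, int b_, int c_ ) : a(a_), b(b_), c(c_) {}
   int operator()( int x ) const { return a*x*x + b*x + c; }
};

// Apply the quadratic to every element of a copy of v and return it.
std::vector<int> apply_quadratic( std::vector<int> v, int a, int b, int c ) {
   std::transform( v.begin(), v.end(), v.begin(), quadratic( a, b, c ) );
   return v;
}
```

With a=3, b=5, c=1 this computes the same 3x^2+5x+1 as crazy_function,
but now the constructor call quadratic( 3, 5, 1 ) carries the arguments
that the empty parens did not need.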


Ok, enough for now.  :)

-- 
 Brian M. McNamara   lorgon@acm.org  :  I am a parsing fool!
   ** Reduce - Reuse - Recycle **    :  (Where's my medication? ;) )

