Diving deeper into the TaskScheduler work item pool

Two days ago, I spent my hours of insomnia thinking about the question: Just how parallel is my ParallelTaskScheduler custom TaskScheduler implementation, really?

It divides its work items into batches (of Task objects), and each batch is then handed to a single ThreadPool thread, which processes the batch’s items in parallel with Parallel.ForEach… But thinking about that a little further: a call to Parallel.ForEach accepts an optional ParallelOptions parameter, and one of the properties that can be set on ParallelOptions is a TaskScheduler. I don’t pass it one, so I can safely assume that the loop’s work is scheduled by TaskScheduler.Default. Although I haven’t seen its source code, I know how I would have implemented a default scheduler: one optimization I would have made is that if a Task gets queued from a ThreadPool thread, since it is already on a background thread, I’d run it inline. In fact, I’m pretty sure that it does just that.
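
Short of reading the reference source, the easiest way to check is to log which threads actually execute the loop body. Here is a small sketch of the idea (the class name and the console logging are just placeholders of mine): it calls Parallel.ForEach from a ThreadPool thread and prints whether each iteration ran inline on the calling thread.

using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

static class InlineCheck
{
    static void Main()
    {
        // Run the parallel loop from a ThreadPool thread, which is the situation my
        // scheduler is in when it calls Parallel.ForEach for a batch.
        Task.Run(() =>
        {
            int callerThreadId = Thread.CurrentThread.ManagedThreadId;
            Console.WriteLine("Caller is thread {0} (ThreadPool thread: {1})",
                callerThreadId, Thread.CurrentThread.IsThreadPoolThread);

            // No ParallelOptions passed, so the loop uses TaskScheduler.Default.
            Parallel.ForEach(Enumerable.Range(0, 8), i =>
                Console.WriteLine("Item {0} ran on thread {1} (inline on the caller: {2})",
                    i, Thread.CurrentThread.ManagedThreadId,
                    Thread.CurrentThread.ManagedThreadId == callerThreadId));
        }).Wait();
    }
}

If the default scheduler really does inline work queued from ThreadPool threads, at least some of the items should report the caller’s thread id.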

So what does that mean for my ParallelTaskScheduler? Nothing good, I’m afraid. If the calls run inline, I could just as well have written them in an ordinary for or foreach loop. This is a considerable deviation from the behaviour I intended. I have improved it, and I will include that code at the end, but first, I’ll write about yesterday’s futile exercise of implementing a TaskScheduler using Windows operating system fibers.

An experiment: implementing a custom TaskScheduler using Fibers

Fibers are often referred to as lightweight threads, although they aren’t threads at all. They are units of execution, each with its own stack and set of registers, that all run on a single thread and that you schedule manually by switching between them.

They do seem to be very popular, and besides, they are not implemented in .Net (they are only to be found in the unmanaged Windows API), so I thought it would be a fun exercise to wrap Fibers in C# by platform invoking the unmanaged functions.

Actually it was a waste of a day. My FiberPerTaskScheduler runs nicely, but for a reason beyond my control, unmanaged Fibers should not be used in .Net. I’ll explain why further on, but for interest, here is the code anyway.

First off, here is my very simple managed Fiber wrapper. (It’s very simple because I do no scheduling from this code. If you search for Fiber implementations on the web, you’re likely to find coroutine schedulers written from scratch. I didn’t need to do that, because I’m calling this from a TaskScheduler, which does the scheduling already. All this does is the bare minimum to run a number of Tasks on Fibers.)

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;

namespace Romy.Core
{
    public static class ManagedFiberWrapper
    {
        public static void InvokeFibers(IEnumerable<Task> tasks, Action<Task> action)
        {
            var taskList = tasks.ToList();
            int remaining = taskList.Count;

            if (remaining == 0) return;

            var exceptions = new List<Exception>();

            using (ManualResetEvent mre = new ManualResetEvent(false))
            {
                // A dedicated thread converts itself into the "primary" fiber, then creates
                // one fiber per task and switches to each of them in turn.
                var primaryFiberThread = new Thread(() =>
                {
                    IntPtr primaryFiber = NativeMethods.ConvertThreadToFiber(IntPtr.Zero);

                    if (primaryFiber != IntPtr.Zero)
                    {
                        try
                        {
                            foreach (var task in taskList)
                            {
                                try
                                {
                                    // The fiber routine must not simply return (that would exit
                                    // the thread); it switches back to the primary fiber instead.
                                    NativeMethods.FiberProc fiberProc = lParam =>
                                    {
                                        action(task);
                                        NativeMethods.SwitchToFiber(primaryFiber);
                                    };

                                    IntPtr fiber = NativeMethods.CreateFiber(0, fiberProc, IntPtr.Zero);

                                    if (fiber == IntPtr.Zero)
                                        throw new InvalidOperationException("CreateFiber failed.");

                                    NativeMethods.SwitchToFiber(fiber);
                                    NativeMethods.DeleteFiber(fiber);

                                    // Keep the delegate reachable while native code may still call it.
                                    GC.KeepAlive(fiberProc);
                                }
                                catch (Exception exc)
                                {
                                    lock (exceptions) exceptions.Add(exc);
                                }
                                finally
                                {
                                    if (Interlocked.Decrement(ref remaining) == 0) mre.Set();
                                }
                            }
                        }
                        finally
                        {
                            // Turn the primary fiber back into an ordinary thread.
                            NativeMethods.ConvertFiberToThread();
                        }
                    }
                    else mre.Set();
                });

                primaryFiberThread.Start();

                // Wait for the whole batch to finish before the event is disposed.
                mre.WaitOne();
            }

            if (exceptions.Count > 0) throw new AggregateException(exceptions);
        }

        static class NativeMethods
        {
            public delegate void FiberProc(IntPtr lpParameter);

            [DllImport("Kernel32.dll")]
            public static extern IntPtr ConvertThreadToFiber(IntPtr lpParameter);

            [DllImport("Kernel32.dll")]
            [return: MarshalAs(UnmanagedType.Bool)]
            public static extern bool ConvertFiberToThread();

            [DllImport("Kernel32.dll")]
            public static extern IntPtr CreateFiber(uint dwStackSize, FiberProc lpStartAddress, IntPtr lpParameter);

            [DllImport("Kernel32.dll")]
            public static extern void DeleteFiber(IntPtr lpFiber);

            [DllImport("Kernel32.dll")]
            public static extern void SwitchToFiber(IntPtr lpFiber);
        }
    }
}

And here is my FiberPerTaskScheduler that runs the above code.

using Romy.Core;
using System.Collections.Generic;
using System.Linq;

namespace System.Threading.Tasks.Schedulers
{
    /// <summary>Custom TaskScheduler that processes work items in batches, 
    /// with each work item in the batch processed by a Fiber.</summary>
    public class FiberPerTaskScheduler : TaskScheduler
    {
        [ThreadStatic]
        private static bool currentThreadIsProcessingItems;

        private int maxDegreeOfParallelism;

        private volatile int runningOrQueuedCount;

        private readonly LinkedList<Task> tasks = new LinkedList<Task>();

        public FiberPerTaskScheduler(int maxDegreeOfParallelism)
        {
            if (maxDegreeOfParallelism < 1)
                throw new ArgumentOutOfRangeException("maxDegreeOfParallelism");

            this.maxDegreeOfParallelism = maxDegreeOfParallelism;
        }

        public FiberPerTaskScheduler() : this(Environment.ProcessorCount) { }

        public override int MaximumConcurrencyLevel
        {
            get { return maxDegreeOfParallelism; }
        }

        protected override bool TryDequeue(Task task)
        {
            lock (tasks) return tasks.Remove(task);
        }

        protected override bool TryExecuteTaskInline(Task task,
            bool taskWasPreviouslyQueued)
        {
            if (!currentThreadIsProcessingItems) return false;

            if (taskWasPreviouslyQueued) TryDequeue(task);

            return base.TryExecuteTask(task);
        }

        protected override IEnumerable<Task> GetScheduledTasks()
        {
            var lockTaken = false;
            try
            {
                Monitor.TryEnter(tasks, ref lockTaken);
                if (lockTaken) return tasks.ToArray();
                else throw new NotSupportedException();
            }
            finally { if (lockTaken) Monitor.Exit(tasks); }
        }

        protected override void QueueTask(Task task)
        {
            lock (tasks) tasks.AddLast(task);

            if (runningOrQueuedCount < maxDegreeOfParallelism)
            {
                runningOrQueuedCount++;
                RunTasks();
            }
        }

        private void RunTasks()
        {
            List<Task> taskList = new List<Task>();

            currentThreadIsProcessingItems = true;
            try
            {
                while (true)
                {
                    lock (tasks)
                    {
                        if (tasks.Count == 0)
                        {
                            runningOrQueuedCount--;
                            break;
                        }

                        var t = tasks.First.Value;
                        taskList.Add(t);
                        tasks.RemoveFirst();
                    }
                }

                if (taskList.Count > 0)
                {
                    var batches = taskList.GroupBy(
                        task => taskList.IndexOf(task) / maxDegreeOfParallelism);

                    foreach (var batch in batches)
                    {
                        ManagedFiberWrapper.InvokeFibers(batch, t => base.TryExecuteTask(t));
                    }
                }
            }
            finally { currentThreadIsProcessingItems = false; }
        }
    }
}
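
For completeness, here is a rough sketch of how Tasks end up on that scheduler, by way of a TaskFactory (the class name and the work inside the tasks are placeholders of mine; and, as will become clear below, that work had better not throw):

using System;
using System.Threading.Tasks;
using System.Threading.Tasks.Schedulers;

static class FiberSchedulerDemo
{
    static void Main()
    {
        // Any Task started through this factory gets queued to the Fiber scheduler.
        var factory = new TaskFactory(new FiberPerTaskScheduler());

        var tasks = new Task[4];
        for (int i = 0; i < tasks.Length; i++)
        {
            int n = i; // capture a copy for the closure

            tasks[n] = factory.StartNew(() =>
                Console.WriteLine("Task {0} ran on a fiber", n));
        }

        Task.WaitAll(tasks);
    }
}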

Even though the above code works well, there is one problem: if an exception is thrown in the Action executed on a Fiber (and I don’t mean only an unhandled exception; any exception at all), the managed exception handler fails to catch it. Not only does it fail to catch it, the process throws an unhandled StackOverflowException in “module unknown”, and goes down in a blaze of… well, nothing.

The issue is apparently related to TLS (Thread Local Storage), which is used heavily by both Fibers and managed exception handling, and it seems to have something to do with the fact that all the Fibers run on the same thread.

Since I make heavy use of Cancellation in my Tasks, my old friend OperationCanceledException is not welcome in the Fiber family home.
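
To illustrate, a hypothetical repro is as small as this (the class name is mine; don’t run it in a process you care about, because the catch block never gets its turn):

using System;
using System.Threading;
using System.Threading.Tasks;
using Romy.Core;

static class FiberExceptionRepro
{
    static void Main()
    {
        var cts = new CancellationTokenSource();
        cts.Cancel();

        try
        {
            // The action throws OperationCanceledException on the fiber. Instead of
            // surfacing here as an AggregateException, the process dies with an
            // unhandled StackOverflowException in "module unknown".
            ManagedFiberWrapper.InvokeFibers(
                new[] { new Task(() => { }) },
                task => cts.Token.ThrowIfCancellationRequested());
        }
        catch (AggregateException)
        {
            Console.WriteLine("Never reached.");
        }
    }
}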

That sucks, almost as much as the fact that I spent most of the day yesterday reading about Fibers, and was initially chuffed when my managed wrapper implementation of them seemed to work perfectly straight away.

I have seen some managed Fiber implementations using C# iterators, but I don’t like the code, so I will not use such a pattern myself.

And thus, I fixed my ParallelTaskScheduler implementation so that it really does invoke its Tasks in parallel; it is now very much faster than it was.

The improved ParallelTaskScheduler

using System.Collections.Generic;
using System.Linq;

namespace System.Threading.Tasks.Schedulers
{
    /// <summary>Custom TaskScheduler that processes work items in batches, with the 
    /// work items of each batch processed by a ThreadPool thread, in parallel.</summary>
    public class ParallelTaskScheduler : TaskScheduler
    {
        [ThreadStatic]
        private static bool currentThreadIsProcessingItems;

        private int maxDegreeOfParallelism;

        private volatile int runningOrQueuedCount;

        private readonly LinkedList<Task> tasks = new LinkedList<Task>();

        public ParallelTaskScheduler(int maxDegreeOfParallelism)
        {
            if (maxDegreeOfParallelism < 1)
                throw new ArgumentOutOfRangeException("maxDegreeOfParallelism");

            this.maxDegreeOfParallelism = maxDegreeOfParallelism;
        }

        public ParallelTaskScheduler() : this(Environment.ProcessorCount) { }

        public override int MaximumConcurrencyLevel
        {
            get { return maxDegreeOfParallelism; }
        }

        public static void ParallelInvoke(Action<Task> action, params Task[] tasks)
        {
            if (tasks == null) throw new ArgumentNullException("tasks");
            if (action == null) throw new ArgumentNullException("action");
            if (tasks.Length == 0) return;

            using (ManualResetEvent mre = new ManualResetEvent(false))
            {
                int remaining = tasks.Length;
                var exceptions = new List<Exception>();

                foreach (var task in tasks)
                {
                    ThreadPool.UnsafeQueueUserWorkItem(_ =>
                    {
                        try
                        {
                            action(task);
                        }
                        catch (Exception ex)
                        {
                            lock (exceptions) exceptions.Add(ex);
                        }
                        finally
                        {
                            if (Interlocked.Decrement(ref remaining) == 0) mre.Set();
                        }
                    }, null);
                }
                mre.WaitOne();
                if (exceptions.Count > 0) throw new AggregateException(exceptions);
            }
        }

        protected override bool TryDequeue(Task task)
        {
            lock (tasks) return tasks.Remove(task);
        }

        protected override bool TryExecuteTaskInline(Task task,
            bool taskWasPreviouslyQueued)
        {
            if (!currentThreadIsProcessingItems) return false;

            if (taskWasPreviouslyQueued) TryDequeue(task);

            return base.TryExecuteTask(task);
        }

        protected override IEnumerable<Task> GetScheduledTasks()
        {
            var lockTaken = false;
            try
            {
                Monitor.TryEnter(tasks, ref lockTaken);
                if (lockTaken) return tasks.ToArray();
                else throw new NotSupportedException();
            }
            finally { if (lockTaken) Monitor.Exit(tasks); }
        }

        protected override void QueueTask(Task task)
        {
            lock (tasks) tasks.AddLast(task);

            if (runningOrQueuedCount < maxDegreeOfParallelism)
            {
                runningOrQueuedCount++;
                RunTasks();
            }
        }

        private void RunTasks()
        {
            List<Task> taskList = new List<Task>();

            currentThreadIsProcessingItems = true;
            try
            {
                while (true)
                {
                    lock (tasks)
                    {
                        if (tasks.Count == 0)
                        {
                            runningOrQueuedCount--;
                            break;
                        }

                        var t = tasks.First.Value;
                        taskList.Add(t);
                        tasks.RemoveFirst();
                    }
                }

                if (taskList.Count > 0)
                {
                    var batches = taskList.GroupBy(
                        task => taskList.IndexOf(task) / maxDegreeOfParallelism);

                    foreach (var batch in batches)
                    {
                        ParallelInvoke(t => base.TryExecuteTask(t), batch.ToArray());
                    }
                }
            }
            finally { currentThreadIsProcessingItems = false; }
        }
    }
}
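
And finally, a quick sketch of how I’d wire the scheduler up through a TaskFactory (again, the class name and the placeholder work are just for illustration); it could equally be handed to Parallel.ForEach via ParallelOptions.TaskScheduler, as mentioned at the top of this post.

using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Schedulers;

static class ParallelSchedulerDemo
{
    static void Main()
    {
        var factory = new TaskFactory(new ParallelTaskScheduler());

        // Queue a bunch of placeholder work items; each batch is farmed out
        // to the ThreadPool by ParallelInvoke and executed in parallel.
        var tasks = Enumerable.Range(0, 16)
            .Select(n => factory.StartNew(() =>
                Console.WriteLine("Task {0} on thread {1}",
                    n, Thread.CurrentThread.ManagedThreadId)))
            .ToArray();

        Task.WaitAll(tasks);
    }
}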