[.NET]快快樂樂學LINQ系列-Aggregate() 簡介

[.NET]快快樂樂學LINQ系列-Aggregate() 簡介

前言

這篇文章打算從實務的一個例子,來介紹 Aggregate()

Aggregate 的意思就是合計,那跟 Sum() 有什麼不同呢?實際上 Aggregate() 當然可以做到 Sum() 的功能,而且擁有更大的彈性。

這個方法讓開發人員可以自行選定,在巡覽每一次的 iteration 時,暫存一個特定的值來跟這一次的 current item 進行結合、處理或運算。

而在這個說明的基礎底下, Sum() 只是其中的一種應用,暫存的特定值就是最後 Sum() 的結果,而巡覽的 iteration 時,將 item 透過 selector 投射出來的值,與暫存的結果進行加總,這就是 Sum() 。同樣地, Max()Min() 也可以用 Aggregate() 來設計,暫存的值仍是最後的結果,以 Max() 為例,那只需要把 Sum() 裡面,原本用來加總的動作,改成比大小,比較大的,就放到暫存結果。

那麼,什麼情況用 Aggregate() 會比用 Sum() 來得有效率呢?請見下面的範例。

 

範例

需求介紹:

  1. 日結表的資料中,會存放每一天有異動的庫存量。
  2. 當進行月結時,需要將各區的異動庫存量進行加總,產生一筆月結資料。

Scenario 如下:

image

可以看到五月份有三筆日結資料,當針對五月份進行月結結轉時,期望月結資料的良品區、不良品區與客退區的數量,會是五月份三筆日結的加總。

測試程式與自動產生的 production code 如下:

using System.Collections.Generic;
using System.Linq;
using Rhino.Mocks;
using TechTalk.SpecFlow;
using TechTalk.SpecFlow.Assist;

namespace AggregateSample
{
    [Binding]
    [Scope(Feature = "StockManagement")]
    public class StockManagementSteps
    {
        private MonthlyStockSettlementService target;
        private IDailyStockDao dailyStockDao;

        [BeforeScenario]
        public void BeforeScenario()
        {
            this.dailyStockDao = MockRepository.GenerateStub<IDailyStockDao>();
            this.target = new MonthlyStockSettlementService(dailyStockDao);
        }

        [Given(@"欲結算年月為 (.*)")]
        public void Given欲結算年月為(string yearMonth)
        {
            ScenarioContext.Current.Set<string>(yearMonth, "yearMonth");
        }

        [Given(@"ProductId為 (.*)")]
        public void GivenProductId為(string productId)
        {
            ScenarioContext.Current.Set<string>(productId, "id");
        }


        [Given(@"日結資料為")]
        public void Given日結資料為(Table table)
        {
            var dailyStockSettlements = table.CreateSet<DailyStockSettlement>();
            var yearmonth = ScenarioContext.Current.Get<string>("yearMonth");
            var productId = ScenarioContext.Current.Get<string>("id");
            this.dailyStockDao.Stub(x => x.GetDailyStocksByYearMonth(yearmonth, productId)).Return(dailyStockSettlements);
        }

        [When(@"呼叫月結結轉")]
        public void When呼叫月結結轉()
        {
            var yearMonth = ScenarioContext.Current.Get<string>("yearMonth");
            var productId = ScenarioContext.Current.Get<string>("id");
            MonthlyStockSettlement actual = this.target.Snapshot(yearMonth, productId);
            ScenarioContext.Current.Set<MonthlyStockSettlement>(actual);
        }

        [Then(@"月結資料應為")]
        public void Then月結資料應為(Table table)
        {
            var actual = ScenarioContext.Current.Get<MonthlyStockSettlement>();
            table.CompareToInstance(actual);
        }
    }

    public class MonthlyStockSettlementService
    {
        private IDailyStockDao dailyStockDao;

        public MonthlyStockSettlementService(IDailyStockDao dailyStockDao)
        {
            // TODO: Complete member initialization
            this.dailyStockDao = dailyStockDao;
        }

        internal MonthlyStockSettlement Snapshot(string yearMonth, string productId)
        {
            throw new System.NotImplementedException();
        }
    }

    public class MonthlyStockSettlement
    {
        //| ProductId | YearMonth | QualifiedProductSection | DefectProductSection | ReturnGoodSection |
    }

    public interface IDailyStockDao
    {
        IEnumerable<DailyStockSettlement> GetDailyStocksByYearMonth(string yearmonth, string productId);
    }

    public class DailyStockSettlement
    {
        //| ProductId | Date       | QualifiedProductSection | DefectProductSection | ReturnGoodSection |
    }
}

上面 production code 的部分,全都是從測試程式撰寫過程自動產生出來的。而接下來要做的事情,只需要將 Snapshot() 完成且通過測試即可。

 

針對三個欄位做 Sum()

第一個作法,是針對三個欄位做 Sum() 的動作,程式碼如下:

    public class MonthlyStockSettlementService
    {
        private IDailyStockDao dailyStockDao;

        public MonthlyStockSettlementService(IDailyStockDao dailyStockDao)
        {
            this.dailyStockDao = dailyStockDao;
        }

        internal MonthlyStockSettlement Snapshot(string yearMonth, string productId)
        {
            var dailyStockSettlements = this.dailyStockDao.GetDailyStocksByYearMonth(yearMonth, productId);
            var result = new MonthlyStockSettlement
            {
                ProductId = productId,
                YearMonth = yearMonth,
                QualifiedProductSection = dailyStockSettlements.Sum(x => x.QualifiedProductSection),
                DefectProductSection = dailyStockSettlements.Sum(x => x.DefectProductSection),
                ReturnGoodSection = dailyStockSettlements.Sum(x => x.ReturnGoodSection)
            };

            return result;
        }
    }

    public class MonthlyStockSettlement
    {
        //| ProductId | YearMonth | QualifiedProductSection | DefectProductSection | ReturnGoodSection |
        public string ProductId { get; set; }
        public string YearMonth { get; set; }
        public int QualifiedProductSection { get; set; }
        public int DefectProductSection { get; set; }
        public int ReturnGoodSection { get; set; }
    }

    public interface IDailyStockDao
    {

        IEnumerable<DailyStockSettlement> GetDailyStocksByYearMonth(string yearmonth, string productId);
    }

    public class DailyStockSettlement
    {
        //| ProductId | Date       | QualifiedProductSection | DefectProductSection | ReturnGoodSection |
        public string ProductId { get; set; }
        public DateTime Date { get; set; }
        public int QualifiedProductSection { get; set; }
        public int DefectProductSection { get; set; }
        public int ReturnGoodSection { get; set; }
    }

為了取得月結的結果,上述程式碼針對「良品區」、「不良品區」、「退貨區」的三個欄位,用了 3 次 Sum() 進行加總,看起來好像很酷,但其實為了三個欄位的加總,原本「只需要針對日結表資料用一次 loop 針對三個欄位加總」的動作,現在卻用了 3 次 loop ,這是不合理的。

 

使用 Loop 來做

用 foreach loop 反而只要透過一個暫存結果,跑一次 loop 的動作而已,程式碼如下:

        internal MonthlyStockSettlement Snapshot(string yearMonth, string productId)
        {
            var dailyStockSettlements = this.dailyStockDao.GetDailyStocksByYearMonth(yearMonth, productId);
            //var result = new MonthlyStockSettlement
            //{
            //    ProductId = productId,
            //    YearMonth = yearMonth,
            //    QualifiedProductSection = dailyStockSettlements.Sum(x => x.QualifiedProductSection),
            //    DefectProductSection = dailyStockSettlements.Sum(x => x.DefectProductSection),
            //    ReturnGoodSection = dailyStockSettlements.Sum(x => x.ReturnGoodSection)
            //};

            var result = new MonthlyStockSettlement { ProductId = productId, YearMonth = yearMonth };
            foreach (var dailyStockSettlement in dailyStockSettlements)
            {
                result.QualifiedProductSection += dailyStockSettlement.QualifiedProductSection;
                result.DefectProductSection += dailyStockSettlement.DefectProductSection;
                result.ReturnGoodSection += dailyStockSettlement.ReturnGoodSection;
            }
            return result;
        }

看起來很簡單也不難懂,對吧?的確,以加總搭配這麼簡單的需求,沒有太大的差異,但還是來看一下,針對這樣的結構與需求,可以使用 LINQ 的方式來取代迴圈的作業。更後面的段落,則來看這樣的執行方式,是如何抽象成 Aggregate() 的 function 。

 

使用 Aggregate() 來做

接著透過 Aggregate() 來取代原本迴圈的程式碼,程式碼如下:

        internal MonthlyStockSettlement Snapshot(string yearMonth, string productId)
        {
            var dailyStockSettlements = this.dailyStockDao.GetDailyStocksByYearMonth(yearMonth, productId);
            //var result = new MonthlyStockSettlement
            //{
            //    ProductId = productId,
            //    YearMonth = yearMonth,
            //    QualifiedProductSection = dailyStockSettlements.Sum(x => x.QualifiedProductSection),
            //    DefectProductSection = dailyStockSettlements.Sum(x => x.DefectProductSection),
            //    ReturnGoodSection = dailyStockSettlements.Sum(x => x.ReturnGoodSection)
            //};

            //var result = new MonthlyStockSettlement { ProductId = productId, YearMonth = yearMonth };
            //foreach (var dailyStockSettlement in dailyStockSettlements)
            //{
            //    result.QualifiedProductSection += dailyStockSettlement.QualifiedProductSection;
            //    result.DefectProductSection += dailyStockSettlement.DefectProductSection;
            //    result.ReturnGoodSection += dailyStockSettlement.ReturnGoodSection;
            //}
            //return result;

            return dailyStockSettlements.Aggregate(new MonthlyStockSettlement { ProductId = productId, YearMonth = yearMonth },
                (result, d) => 
                {
                    result.QualifiedProductSection += d.QualifiedProductSection;
                    result.DefectProductSection += d.DefectProductSection;
                    result.ReturnGoodSection += d.ReturnGoodSection;

                    return result;
                });
        }

看到跟 foreach loop 的差異如下:

  1. 把 loop 外面的那一行暫存結果,放到第一個參數
  2. 第二個參數是一個 Func<T1, T2, T1> 的委派。T1 指的就是第一個參數那個暫存結果,T2 則是 foreach 巡覽的每一個 item 。

就只是一種把一堆實作細節抽象出來,用更有彈性的方式來取代而已。

 

上面的 Aggregate() 該怎麼自己寫

有了 foreach loop 與 Aggregate() 的對應,相信要自己寫出 LINQ 的方法,應該也不是件難事吧。這邊先貼上這個方法的簽章:

public static TAccumulate Aggregate<TSource, TAccumulate>(this IEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> func);

接著,寫一個自己的 MyAggregate() 來取代原本 LINQ 的 Aggregate() ,程式碼如下:

       internal MonthlyStockSettlement Snapshot(string yearMonth, string productId)
        {
            var dailyStockSettlements = this.dailyStockDao.GetDailyStocksByYearMonth(yearMonth, productId);
            //var result = new MonthlyStockSettlement
            //{
            //    ProductId = productId,
            //    YearMonth = yearMonth,
            //    QualifiedProductSection = dailyStockSettlements.Sum(x => x.QualifiedProductSection),
            //    DefectProductSection = dailyStockSettlements.Sum(x => x.DefectProductSection),
            //    ReturnGoodSection = dailyStockSettlements.Sum(x => x.ReturnGoodSection)
            //};

            //var result = new MonthlyStockSettlement { ProductId = productId, YearMonth = yearMonth };
            //foreach (var dailyStockSettlement in dailyStockSettlements)
            //{
            //    result.QualifiedProductSection += dailyStockSettlement.QualifiedProductSection;
            //    result.DefectProductSection += dailyStockSettlement.DefectProductSection;
            //    result.ReturnGoodSection += dailyStockSettlement.ReturnGoodSection;
            //}
            //return result;

            //return dailyStockSettlements.Aggregate(new MonthlyStockSettlement { ProductId = productId, YearMonth = yearMonth },
            //    (result, d) =>
            //    {
            //        result.QualifiedProductSection += d.QualifiedProductSection;
            //        result.DefectProductSection += d.DefectProductSection;
            //        result.ReturnGoodSection += d.ReturnGoodSection;

            //        return result;
            //    });

            return dailyStockSettlements.MyAggregate(new MonthlyStockSettlement { ProductId = productId, YearMonth = yearMonth },
                (result, d) =>
                {
                    result.QualifiedProductSection += d.QualifiedProductSection;
                    result.DefectProductSection += d.DefectProductSection;
                    result.ReturnGoodSection += d.ReturnGoodSection;

                    return result;
                });
        }    

    public static class MyLinqExtension
    {
        public static TAccumulate MyAggregate<TSource, TAccumulate>(this IEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> func)
        {
            var result = seed;
            foreach (var item in source)
            {
                result = func(result, item);
            }

            return result;
        }
    }

抽象的過程只是進行下面幾個步驟:

  1. DailyStockSettlement 換成泛型 TSource
  2. MonthlyStockSettlement 換成泛型 TAccumulate
  3. 把 foreach 裡面要做的事情,換成 Func<TAccumulate, TSource, TAccumulate>
  4. 把 foreach 巡覽 IEnumerable<DailyStockSettlement> 封裝到 extension method 中

就可以讓這種 foreach loop + 迴圈外面放一個暫存的結果,透過泛型 + 匿名委派 + Lambda 的方式抽象成各種型別跟各種處理都能使用,是不是很神奇呢?

這邊也順手把三種簽章都寫出來,只是刻意不把重複的程式碼重構。

    public static class MyLinqExtension
    {
        public static TSource MyAggregate<TSource>(this IEnumerable<TSource> source, Func<TSource, TSource, TSource> func)
        {
            // 因為要把第一個item當初始值,所以用原始的 iterator 寫法比較有效率,用 foreach 看起來很醜
            using (var iterator = source.GetEnumerator())
            {
                if (!iterator.MoveNext())
                {
                    throw new InvalidOperationException("Source seqence was empty");
                }

                //第一個item
                var result = iterator.Current;
                while (iterator.MoveNext())
                {
                    var next = iterator.Current;
                    result = func(result, next);
                }

                return result;
            }
        }

        public static TAccumulate MyAggregate<TSource, TAccumulate>(this IEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> func)
        {
            var result = seed;
            foreach (var item in source)
            {
                result = func(result, item);
            }

            return result;
        }

        public static TResult MyAggregate<TSource, TAccumulate, TResult>(this IEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> func, Func<TAccumulate, TResult> resultSelector)
        {
            var result = seed;
            foreach (var item in source)
            {
                result = func(result, item);
            }

            return resultSelector(result);
        }
    }

很多概念都是互通的,看懂了簽章,知道沒有 LINQ 的時候怎麼用 foreach 寫,接下來怎麼把 foreach 的「使用方式」抽象化,就是這一堆 LINQ to Objects 的方法了。

 

結論

自己在實務上碰到的例子,比這複雜很多,所以我自己一開始也是寫了 3 個 Sum() ,赫然發現很蠢,LINQ 應該要有對應的方法來幫我解決這樣的需求。但我真的忘了是哪一個了,回推回原始的 foreach loop 可能怎麼寫時,我才想到這樣的 foreach loop 使用方式,就是用 Aggregate() 來取代。

所以,希望大家不要只為了讓程式碼看起來好像有在用 LINQ ,看起來很酷,而忽略了每一個 Sum() 其實都是完整的走完 IEnumerable<TSource> 一輪的動作。

by the way, 讀者可以自己練習,用 Aggregate() 來做出 Max(), Min()Sum() 囉。

對敏捷開發有興趣的朋友,可以參考我的粉絲專頁:91敏捷開發之路

對 TDD 課程有興趣的朋友,課程內容、大綱與學員心得,可以參考 skilltree 的公開課程:自動測試與 TDD 實務開發

若需要聯絡我,可以透過粉絲專頁私訊或是側欄的關於我。