การเรียนรู้ด้วยวิธี linear least squares

Jul 10 2010

ช่วงนี้ได้มีโอกาสเรียนรู้เรื่อง machine learning แบบลึกกว่าที่ผ่านมา ที่ผ่านมาต้องเรียกว่าไม่รู้อะไรเลยมากกว่า ได้แต่ใช้ tool ของชาวบ้าน ไหนๆ ก็เรียนมาแล้วก่อนที่จะลืมก็อยากเขียนเอาไว้ซะหน่อย เผื่อคนที่เพิ่งเริ่มเหมือนกันด้วย

เป้าหมายหนึ่งของงานในด้าน supervised learning หรือการเรียนรู้แบบมีผู้สอน ก็คือการประมาณค่าของ function (real-valued function) จริงจากข้อมูลที่สุ่มมาได้ (sample) จากภาพ (จาก http://en.wikipedia.org/wiki/Least_squares) จุดแดงคือข้อมูลที่สุ่มมา เส้นน้ำเงินคือ function ที่ประมาณขึ้นมาได้จากข้อมูล วิธีหนึ่งในการประมาณค่า function คือ Linear least-squares ไอเดียคือ พยายามหา function ที่ทำให้ผลรวมของระยะห่างจากจุดแต่ละจุดมายัง function นั้นน้อยที่สุด (จริงๆแล้วคือระยะห่างกำลังสอง) ระยะห่างแต่ละจุดมาถึง function พูดง่ายๆก็คือ error นั่นเอง ถ้าระยะห่างทุกจุดเป็น 0 ก็แปลว่า function เราผ่านทุกจุดเป๊ะๆ

สมมติว่าเรามี n จุด (ตัวอย่าง) กำหนดให้  (x_i, y_i) เป็นตัวอย่างที่ i เราพยายามจะหา function ในรูปของ

 \hat{f}(x) = \sum_{j=1}^b \alpha_j \varphi_j(x)

เมื่อ  \varphi_j(x) คือสิ่งที่เรียกว่า basis function ที่รับ x แล้วคืนค่าเป็นจำนวนจริง ซึ่งเราสามารถเลือกเป็นอะไรก็ได้ มีทั้งหมด b basis function จากไอเดียของ lineat least-squares เราอยากหา  \hat{f}(x) ที่ทำให้ผลรวม error น้อยสุด ในเมื่อ basis function ของเรากำหนดตายตัวไปแล้ว ปัญหาจึงลดลงไปเป็นการหา  \alpha_j ทั้งหมดที่ทำให้ผลรวม error น้อยสุด ดังนี้ (เรียก  \alpha_j ดังกล่าวว่า  \hat{\alpha}_j)

 \hat{\boldsymbol \alpha} = \arg \min_{\boldsymbol \alpha} \sum_{i=1}^n (\hat{f}(x_i) - y_i)^2

ในสมการข้างบน \arg \min_{\boldsymbol \alpha}  แปลว่า หา \boldsymbol \alpha ที่ทำให้สิ่งที่ตามมามีค่าน้อยที่สุด ในที่นี้ \boldsymbol \alpha = (\alpha_1, \alpha_2, \ldots, \alpha_b)^T เป็น vector ของ \alpha_j ทั้งหมด เขียนเป็น vector จะได้สะดวก ต่อมาจากข้างบนแทนนิยามของ function ลงไปเราจะได้

  \hat{\boldsymbol \alpha} = \arg \min_{\boldsymbol \alpha} \sum_{i=1}^n (\sum_{j=1}^b \alpha_j \varphi_j(x_i) - y_i)^2

ถ้าเรากำหนดให้  \boldsymbol X_{i,j} = \varphi_j(x_i) โดยที่  \boldsymbol X_{n \times b} เป็น matrix แล้วสมการข้างบนจะสามารถเขียนได้เป็น

  \hat{\boldsymbol \alpha} = \arg \min_{\boldsymbol \alpha} \| (\boldsymbol X \boldsymbol \alpha - \boldsymbol y) \|^2

โดยที่ \boldsymbol y = (y_1, y_2, \ldots, y_n)^T เป็น vector ของค่า y ทั้งหมด ต่อไปเราก็ diff ค่าข้างบนนั่นเทียบกับ \boldsymbol \alpha แล้วจับเท่ากับ 0 (การ diff เทียบกับ vector สามารถหาดูได้ที่ matrixcookbook.com) เราจะได้

  2\boldsymbol X^T (\boldsymbol X \boldsymbol \alpha - \boldsymbol y) = 0

จัดรูปไปมา เราจะได้

  \hat{\boldsymbol \alpha} = (\boldsymbol X^T \boldsymbol X)^{-1} \boldsymbol X^T \boldsymbol y

จะเห็นว่าเราสามารถหา parameter ที่ดีที่สุดในแง่ของ Least-squares ได้โดยการคูณ matrix เมื่อเราได้ \hat{\boldsymbol \alpha} แล้วที่เหลือก็แค่แทนลงไปในนิยามของ \hat{f}(x) ข้างบน เราก็ได้ function แล้ว สังเกตว่าถึงเราจะเรียกว่า linear least squares แต่ \varphi_j(x) จะเป็น function ที่ไม่ linear กับ x ก็ได้ เช่น log คำว่า linear least squares ในที่นี้คือ function ที่ได้นั้น linear เมื่อเทียบกับ parameter นั่นคือ \boldsymbol \alpha

ถ้าเลือก basis function ดีๆหลากหลายหน่อย เราสามารถประมาณค่า function ได้หลากหลายมาก วิธีนี้ดูเผินๆแล้วน่าจะดี แต่จริงๆแล้ว function ที่ได้มักจะมีปัญหาที่เรียกว่า overfit หรือประมาณค่าใกล้กับข้อมูลมากจนเกินไป (เกิดขึ้นเมื่อมี basis function เยอะ จนสามารถประมาณค่า function อะไรก็ได้) นั่นหมายความว่าถ้าข้อมูลที่เราได้มามี noise มาก เช่น ได้จากเครื่องมือวัดค่าที่เก่าแล้ว คลาดเคลื่อนบ่อย function ที่เราได้มันก็จะซิกแซกไปตามจุดพวกนั้นด้วย ซึ่งในความเป็นจริงเราไม่อยากได้

ตัวอย่างคือรูปนี้ http://upload.wikimedia.org/wikipedia/commons/5/5d/Overfit.png จะเห็นว่ามี 2 เส้น เส้นโค้งกับเส้นตรง ซึ่งได้จากข้อมูลชุดเดียวกันแต่ basis function ต่างกัน ในความเป็นจริงเราอยากได้เส้นตรง ถึงแม้เส้นโค้งจะผ่านทุกจุดเป๊ะ แต่เห็นได้ชัดว่าข้อมูลมีโครงสร้างเชิงเส้น ที่จุดพวกนั้นไม่เรียงเป็นแนวเส้นตรงน่าจะเกิดจาก noise มากกว่า ด้วยเหตุนี้เราจึงกล่าวว่าเส้นโค้งฟิตกับข้อมูลมากไปหรือ overfit นั่นเอง วันหลังจะมีพูดถึงวิธีแก้ปัญหา overfit

จบแล้ว หากผู้รู้จริงผ่านมาแล้วมีคำแนะนำก็เชิญ comment ครับ

\sum_{j=1}^b \alpha_i \varphi_j(x)[

5 responses so far

  • Ake says:

    you are very fit. I think linear least square is a good introductory model if one tries to understand the concept of machine learning. Many more advanced ML models have been proposed based on the concept of lls. Ridge regression is also an interesting model, which remedies the problem found in lls and has interpretation in Bayesian method.

  • admin says:

    Although I know ridge regression, I have no idea that it can be interpreted in a Bayesian way. Actually, I have no idea on Bayesian approaches. Maybe I need to read that some days. :)

  • Chotika says:

    555 U r so brilliant Nuke. I remember when I studied all those things, it took me for ages to figure it out. Seriously, I have no idea how come any person in this world like to learn those stuff. It’s killing me every time I start touching the ML book. U should be my tutor :)

  • Jirach says:

    what are you guy talking about 5555+

  • admin says:

    @Chotika
    Ake is one who likes these stuffs. Anyway, I think you also like it from the fact that you have a name “Miss Bayesian”.

Leave a Reply