Implementation of the multinomial generalized linear model (softmax link)#3319
Open
jachymb wants to merge 9 commits intostan-dev:developfrom
Open
Implementation of the multinomial generalized linear model (softmax link)#3319jachymb wants to merge 9 commits intostan-dev:developfrom
jachymb wants to merge 9 commits intostan-dev:developfrom
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This is the implementation of the
multinomial_logit_glm_lpmf, including the OpenCL variant. A solution to #3149It computes, in an efficient vectorized way, exactly what one would expect:
I was trying to closely follow patterns in related functions like the
binomial_logit_glm_lmpfandcategorical_logit_glm_lpmfwhich this is a generalization of both.The typing:
yare a 2Darray [,] intnon-negative (rows=instances, columns=outcome classes)xarematrix(rows=instances, columns=features)betaare amatrix(rows=features, columns=outcome classes)alphamay be either amatrix(same shape asy) or arow_vector(length=outcome classes). Now this may seem a little irregular at first, but it makes sense when you really think about it. Other GLMs offer either avectororrealfor the intercept. In the multivariate case, the (univariate)vectorcorresponds to thematrixand the scalar to be copied for each instance is more naturally represented here as arow_vectorrather than a column. (It can be seen as an extra row ofbetacorresponding to a constant feature, also matching the shape ofyand in consequence, the implementation algebra is just more natural this way.) However this may be argued to be a little inconsistent withcategorical_logit_glm_lpmfso it's maybe up to a debate, but then the categorical is not implemented as multivariate even though it also uses beta matrix.I chose to use "logit" in the name instead of "softmax" (even though it really is softmax) to be consistent with existing functions like
multinomial_logit.I am new to this, so if there are any improvements to be made, I will gladly listen. I tried to use efficient vectorized approach and utilize existing functions where applicable. Note however this uses its own softmax implementation, which is necessary for efficient vectorization and in my understanding can't be done by the existing library implementation the way it's needed here. Maybe if my other PR #3313 is approved, this could be made a few LOC shorter with a call to that.
As a personal note, this is a function I want available in Stan for my actual work. :)
Tests
Standard testing of the new functions introduced, both for the prim and cl variants.
Side Effects
None.
Release notes
Add
multinomial_logit_glm_lpmf, including OpenCL support.AI use disclosure
I used claude code w/ Sonnet to help with the work, but I critically reviewed/edited every single line, striving to match the general code quality and patterns within the library.
Checklist
Copyright holder: me, jachymb@gmail.com
- Code: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
- Documentation: CC-BY 4.0 (https://creativecommons.org/licenses/by/4.0/)
the basic tests are passing
./runTests.py test/unit)make test-headers)make test-math-dependencies)make doxygen)make cpplint)the code is written in idiomatic C++ and changes are documented in the doxygen
the new changes are tested