WIP [DeepSeek R1] Add DeepSeekV3 Base + Weight Conversion #2171
base: master
Conversation
@dataclass
class ModelArgs:
Comes from the original impl; currently here for sanity checking. Will be removed in favor of JSON configs.
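A minimal sketch of what the JSON-config replacement could look like. The field subset and the `from_json` helper are assumptions for illustration, not the PR's actual API:

```python
import json
from dataclasses import dataclass, fields


# Hypothetical subset of DeepSeekV3 hyperparameters; the real ModelArgs in the
# PR carries many more fields (MoE, MLA, rope settings, etc.).
@dataclass
class ModelArgs:
    vocab_size: int = 129280
    dim: int = 7168
    n_layers: int = 61
    n_heads: int = 128

    @classmethod
    def from_json(cls, path: str) -> "ModelArgs":
        # Keep only keys the dataclass declares, so extra entries in the
        # config file do not raise a TypeError.
        with open(path) as f:
            raw = json.load(f)
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in raw.items() if k in known})
```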
return logits

if __name__ == "__main__":
Sanity check main call - will be removed.
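For reference, a sanity-check entry point in a file like this usually just builds a tiny config and runs one forward pass. The sketch below assumes the module defines `ModelArgs` and a `Transformer` class (as in the original impl); the shapes are illustrative:

```python
if __name__ == "__main__":
    import numpy as np

    # Tiny single-block config purely for a smoke test; real DeepSeekV3
    # dimensions are far larger and need much more RAM.
    args = ModelArgs(dim=256, n_layers=1, n_heads=4, vocab_size=1000)
    model = Transformer(args)

    tokens = np.random.randint(0, args.vocab_size, size=(1, 16))
    logits = model(tokens)
    print("logits shape:", logits.shape)
```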
rank = 0

class Embedding(layers.Layer):
TODO: Remove custom class and just use layers.Embedding.
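If the custom class goes away, the swap would presumably look something like this (a sketch; the sizes and variable names are illustrative assumptions):

```python
import numpy as np
from keras import layers

vocab_size, dim = 1000, 64  # illustrative sizes
token_ids = np.random.randint(0, vocab_size, size=(2, 16))

# The stock Keras layer covers the same lookup the custom Embedding implements.
embed = layers.Embedding(input_dim=vocab_size, output_dim=dim)
hidden = embed(token_ids)  # (2, 16) -> (2, 16, 64)
```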
return linear(x, self.weight, self.bias)

class ColumnParallelLinear(Linear):
Do we still need the custom XParallel classes if we don't use torch.dist? Without it, most of them reduce back to the standard implementations.
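For context on why they collapse: in the reference impl these layers only differ from a plain Linear when the weight is sharded across ranks, and with a single rank the shard is the whole matrix. A rough sketch (the `Linear` stand-in and constructor signature are assumptions, not the PR's code):

```python
from keras import layers


class Linear(layers.Dense):
    """Stand-in for the PR's Linear layer (sketch only)."""


class ColumnParallelLinear(Linear):
    """Sketch: each rank holds out_features // world_size output columns."""

    def __init__(self, out_features, world_size=1, **kwargs):
        assert out_features % world_size == 0
        # With world_size == 1 (no torch.dist), the "shard" is the entire
        # weight matrix, so this is just a plain Linear/Dense layer.
        super().__init__(units=out_features // world_size, **kwargs)


layer = ColumnParallelLinear(32, world_size=1)
assert layer.units == 32  # identical to Linear(32)
```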
Adds the DeepSeekV3 base model and a weight conversion script.
The architecture builds and runs, but requires a large amount of RAM. Example of a one-block model running on some tokens below (5 s/token):
Needs more refactoring and simplification.
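As a rough illustration of what the conversion script mentioned above boils down to, remapping a checkpoint usually means a name map plus a few transposes. Everything here is an assumption for illustration (safetensors checkpoint, tensor names, bias-free layers, `get_layer` targets), not the script's actual mapping:

```python
from safetensors.numpy import load_file

# Hypothetical name mapping from the PyTorch checkpoint to Keras layer names;
# the real script uses its own (much longer) mapping.
NAME_MAP = {
    "model.embed_tokens.weight": "token_embedding",
    "lm_head.weight": "output_projection",
}


def convert(checkpoint_path: str, model) -> None:
    state = load_file(checkpoint_path)  # dict of tensor name -> np.ndarray
    for torch_name, keras_name in NAME_MAP.items():
        tensor = state[torch_name]
        if keras_name == "output_projection":
            # Dense kernels are stored transposed relative to torch Linear
            # weights: (vocab, dim) -> (dim, vocab). Assumes no bias term.
            tensor = tensor.T
        model.get_layer(keras_name).set_weights([tensor])
```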
WIP/TODOs