🐯 Liger GRPO meets TRL
Thank you for your great work.
Anyway, I tested the Liger loss with DeepSpeed ZeRO-3 using Qwen/Qwen2.5-0.5B-Instruct in bf16.
I ran into a shape mismatch error, shown below:
[rank0]: Traceback (most recent call last):
[rank0]: File "/workspace/temp.py", line 22, in
[rank0]: trainer.train()
[rank0]: File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2238, in train
[rank0]: return inner_training_loop(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2553, in _inner_training_loop
[rank0]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 3730, in training_step
[rank0]: loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/dist-packages/trl/extras/profiling.py", line 87, in wrapper
[rank0]: return func(self, *args, **kwargs)
[rank0]: