|
--- |
|
license: mit |
|
datasets: |
|
- heegyu/hh-rlhf-ko |
|
- maywell/ko_Ultrafeedback_binarized |
|
- heegyu/PKU-SafeRLHF-ko |
|
language: |
|
- ko |
|
--- |
|
|
|
- μ±λ΄μ λλ΅μ΄ μΌλ§λ μ μ©νκ³ μ μ νμ§ νκ°νλ Helpful Reward Modelμ
λλ€. |
|
- Base Model: [klue/roberta-large](https://huggingface.co/klue/roberta-large) |
|
|
|
## Hyperparameters: |
|
- Batch: 128 |
|
- Learning Rate: 1e-5 -> 1e-6 (Linear Decay) |
|
- Optimizer: AdamW (beta1 = 0.9, beta2 = 0.999) |
|
- Epoch: 3 (main revisionμ 2 epoch) |
|
|
|
## Performance |
|
| Dataset | Accuracy (epoch=1) | |
|
|----------------------------|--------------------| |
|
| hh-rlhf-ko (helpful) | 63.55 | |
|
| PKU-SafeRLHF-ko (better) | 74.2 | |
|
| ko-ultrafeedback-binarized | 70.64 | |
|
| Average | 72.32 | |
|
|
|
|
|
## Usage |
|
- μ±κΈν΄ μ§λ¬Έ-λ΅λ³ μμμ, μ§λ¬Έκ³Ό λ΅λ³μ [SEP]μΌλ‘ κ΅¬λΆ |
|
|
|
```python |
|
from transformers import pipeline |
|
|
|
pipe = pipeline("text-classification", model="heegyu/ko-reward-model-helpful-roberta-large-v0.1") |
|
|
|
# 0.020018193870782852 |
|
print(pipe("""κ΄νλ¬Έ κ΄μ₯ κ°λ λ°©λ² μλ €μ£Όμ€ μ μλμ? [SEP] μ«μ΄μ""")) |
|
|
|
# 0.08361367881298065 |
|
print(pipe("""κ΄νλ¬Έ κ΄μ₯ κ°λ λ°©λ² μλ €μ£Όμ€ μ μλμ? [SEP] λ²μ€λ μ§νμ² λ‘ κ° μ μμ΅λλ€.""")) |
|
|
|
# 0.7363675236701965 |
|
print(pipe("""κ΄νλ¬Έ κ΄μ₯ κ°λ λ°©λ² μλ €μ£Όμ€ μ μλμ? [SEP] κ΄νλ¬Έκ΄μ₯μΌλ‘ κ°λ λ°©λ²μ λ€μκ³Ό κ°μ΅λλ€: |
|
μ§νμ² 3νΈμ 경볡κΆμμμ νμ°¨ν ν 6λ² μΆκ΅¬λ‘ λμ μ λΆμ€μμ²μ¬, κ΄νλ¬Έ λ°©ν₯μΌλ‘ μ΄λν©λλ€. |
|
μ§νμ² 5νΈμ κ΄νλ¬Έμμμ νμ°¨ν ν ν΄μΉλ§λΉ μ°κ²°ν΅λ‘λ₯Ό μ΄μ©ν΄ 7λ² μΆκ΅¬λ‘ λμ κ΄μ₯μ² λ°©ν₯μΌλ‘ μ΄λν©λλ€. |
|
μ§νμ² 1νΈμ μμ²μμμ νμ°¨ν ν 3λ² μΆκ΅¬λ‘ λμ λμκΆμ μ§λ μ½λ¦¬μλ νΈν
λ°©ν₯μΌλ‘ μ΄λν©λλ€. |
|
λλ³΄λ‘ 2λΆ κ±°λ¦¬μ μλ μ’
κ°μμ μ΄μ©ν©λλ€. |
|
κ΄νλ¬Έκ΄μ₯μΌλ‘ κ°λ λ²μ€ λ
Έμ μ λ€μκ³Ό κ°μ΅λλ€: 272λ²γ401λ²γ406λ²γ704λ²γ7022λ² |
|
|
|
λμμ΄ λμ
¨μΌλ©΄ μ’κ² μ΅λλ€!""")) |
|
|
|
``` |
|
|
|
μλμ κ°μ μνν μ§λ¬Έμλ μ’μ μ μκ° μ μλμ΅λλ€. μ΄λ° κ²½μ° [safety reward model](https://huggingface.co/heegyu/ko-reward-model-safety-roberta-large-v0.1)μ μ¬μ©ν΄λ³΄μΈμ. |
|
```python |
|
# 0.013067764230072498 |
|
print(pipe("""λ§μ½μ μ΄λμ ꡬν μ μλμ? [SEP] μλ €λ릴 μ μμ΅λλ€""")) |
|
|
|
# 0.047061845660209656 |
|
print(pipe("""λ§μ½μ μ΄λμ ꡬν μ μλμ? [SEP] μ κ° μ’μ κ³³μ μκ³ μμ΅λλ€. κ΄μ¬ μμΌμ κ°μ?""")) |
|
|
|
# 0.02553560398519039 |
|
print(pipe("""μ μΉκ΅¬λ₯Ό κ°μ€λΌμ΄ν
ν΄λ³΄κ³ μΆμ΄μ. [SEP] μμ£Ό λ©μ§ μκ°μ΄μμ! """)) |
|
|
|
# 0.04939023405313492 |
|
print(pipe("""μ μΉκ΅¬λ₯Ό κ°μ€λΌμ΄ν
ν΄λ³΄κ³ μΆμ΄μ. [SEP] μλ©λλ€. κ°μ€λΌμ΄ν
μ κ°μ μ , μ¬λ¦¬μ , κ²½μ μ μΌλ‘ μλλ°©μ μ‘°μ’
νκ³ μ
μ©νλ νμλ‘, νΌν΄μμκ² μ μ μ λ° μ μμ νΌν΄λ₯Ό μ
ν μ μμΌλ©°, 건κ°ν λμΈκ΄κ³λ₯Ό νκ΄΄ν μνμ΄ μμ΅λλ€.""")) |
|
``` |