Bagaimana Kami Melatih Shannon V1.5 untuk Berfikir Menggunakan GRPO
Pecahan teknikal komprehensif Pengoptimuman Polisi Relatif Kumpulan untuk penaakulan rantaian pemikiran yang telus dalam model pasukan merah AI tanpa tapisan
1. Motivasi: Mengapa Model Pemikiran?
Untukpasukan merah AIpenyelidikan, memahamibagaimanasesebuah AI mencapai output yang berpotensi berbahaya adalah sama pentingnya dengan mengkaji output itu sendiri. Model pemikiran Shannon V1.5 mendedahkan rantaian pemikiran penuh, membolehkan penyelidik untuk:
- Mengesan laluan penaakulan yang membawa kepadaakibat AI tanpa tapisantingkah laku
- Mengenal pasti di mana pagar keselamatan biasanya akan campur tangan (menunjukkankepentingan pagar keselamatan AI)
- Membangunkan mekanisme keselamatan yang lebih teguh menyasarkan corak penaakulan tertentu
- Mencipta kes ujian adversari yang lebih baik berdasarkan rantaian penaakulan yang diperhatikan
Nilai Penyelidikan:Ketelusan ini membolehkan penyelidik keselamatan melihat dengan tepat di mana dan mengapa model yang kekangannya dilonggarkan menyimpang daripada model yang dilatih keselamatan—penting untuk meningkatkankepentingan pagar keselamatan AIpemahaman.
2. Memahami GRPO
Pengoptimuman Polisi Relatif Kumpulan (GRPO)adalah kemajuan berbanding kaedah RLHF tradisional yang membolehkan latihan keupayaan penaakulan yang lebih stabil dan cekap. Dibangunkan oleh DeepSeek AI, ia telah terbukti sangat berkesan untuk latihan rantaian pemikiran.
Mengapa GRPO Berbanding RLHF Tradisional?
| Aspek | RLHF Tradisional | GRPO |
|---|---|---|
| Model Ganjaran | Memerlukan latihan RM yang berasingan | Menggunakan perbandingan relatif kumpulan |
| Kestabilan Latihan | Cenderung kepada penggodaman ganjaran | Pengoptimuman yang lebih stabil |
| Kecekapan Pengkomputeran | Tinggi (RM berasingan + PPO) | Rendah (latihan bersatu) |
| Kualiti CoT | Jejak yang tidak konsisten | Rantaian penaakulan yang koheren |
Asas Matematik GRPO
GRPO mengoptimumkan polisi dengan membandingkan respons dalam kumpulan dan bukannya terhadap model ganjaran mutlak:
Perbandingan relatif ini mempunyai beberapa kelebihan:
- Normalisasi:Melaraskan secara automatik untuk kesukaran yang berbeza merentasi gesaan
- Kestabilan:Mengurangkan varians dalam anggaran kecerunan
- Kecekapan:Tiada model ganjaran berasingan diperlukan
def compute_grpo_loss(
policy_logprobs: torch.Tensor,
rewards: torch.Tensor,
group_size: int = 8
) -> torch.Tensor:
"""
Compute GRPO loss with group-relative reward normalization.
Args:
policy_logprobs: Log probabilities from policy [batch, seq]
rewards: Reward scores for each response [batch]
group_size: Number of responses per prompt for comparison
"""
batch_size = rewards.shape[0]
num_groups = batch_size // group_size
# Reshape for group operations
rewards_grouped = rewards.view(num_groups, group_size)
logprobs_grouped = policy_logprobs.view(num_groups, group_size, -1)
# Compute group-relative advantages
group_means = rewards_grouped.mean(dim=1, keepdim=True)
group_stds = rewards_grouped.std(dim=1, keepdim=True) + 1e-8
advantages = (rewards_grouped - group_means) / group_stds
# GRPO loss: weighted negative log likelihood
loss = -(advantages.unsqueeze(-1) * logprobs_grouped).sum(dim=-1).mean()
return loss
3. Penyulingan DeepSeek
Untuk memulakan keupayaan berfikir Shannon V1.5, kami menyuling corak rantaian pemikiran daripada model penaakulan DeepSeek. Ini menyediakan jejak CoT berkualiti tinggi untuk melatih kepala pemikiran kami.
Komposisi Set Data DeepSeek
Proses Pengumpulan Jejak
Kami mengumpul jejak pemikiran merentasi pelbagai domain untuk memastikan liputan penaakulan yang komprehensif:
class DeepSeekDistiller:
"""Distill chain-of-thought traces from DeepSeek models."""
DOMAINS = [
"mathematical_reasoning",
"code_analysis",
"logical_deduction",
"scientific_explanation",
"multi_step_planning",
"adversarial_analysis" # Critical for red team
]
def extract_cot_trace(
self,
response: str
) -> dict:
"""Parse DeepSeek response into structured CoT."""
# DeepSeek uses ... tags
think_match = re.search(
r'(.*?) ',
response,
re.DOTALL
)
if not think_match:
return None
thinking = think_match.group(1)
final_answer = response.split('')[-1].strip()
# Parse individual reasoning steps
steps = self.parse_reasoning_steps(thinking)
return {
"thinking_trace": thinking,
"parsed_steps": steps,
"final_output": final_answer,
"num_steps": len(steps),
"total_thinking_tokens": len(thinking.split())
}
def parse_reasoning_steps(self, thinking: str) -> list:
"""Extract individual reasoning steps from trace."""
# Split on common step indicators
step_patterns = [
r'\n\d+\.', # "1. ", "2. "
r'\nStep \d+:', # "Step 1:"
r'\n(?:First|Next|Then|Finally),',
r'\n- ' # Bullet points
]
combined_pattern = '|'.join(step_patterns)
steps = re.split(combined_pattern, thinking)
return [s.strip() for s in steps if s.strip()]
Jejak Adversarial:Kami secara khusus mengumpul jejak CoT untuk senario adversarial/pasukan merah, di mana pemikiran DeepSeek mendedahkan bagaimana model menaakul tentang permintaan yang berpotensi berbahaya—walaupun akhirnya menolak. Data ini mengajar Shannon V1.5 untuk menjadikan penaakulandanoutput telus.
4. Seni Bina Kepala Pemikiran
Model Shannon V1.5 menggabungkankepala pemikiranyang menjana jejak penaakulan eksplisit sebelum output akhir. Penambahan seni bina ini membolehkan CoT yang telus tanpa mengubah seni bina Mixtral asas.
Pengekodan Input
Prompt pengguna diproses melalui lapisan pengekod Mixtral
Pengaktifan Kepala Pemikiran
Lapisan transformer khusus menjana jejak penaakulan dengan token [THINK]
Integrasi Jejak
Output pemikiran digabungkan ke konteks untuk penjanaan akhir
Penjanaan Respons
Mixtral asas menjana respons akhir berdasarkan jejak pemikiran
Pelaksanaan Kepala Pemikiran
class ThinkingHead(nn.Module):
"""
Dedicated thinking module for Shannon V1.5.
Generates explicit chain-of-thought traces.
"""
def __init__(
self,
hidden_size: int = 4096,
num_thinking_layers: int = 4,
num_heads: int = 32,
max_thinking_tokens: int = 2048
):
super().__init__()
self.hidden_size = hidden_size
self.max_thinking_tokens = max_thinking_tokens
# Special tokens
self.think_start = nn.Parameter(torch.randn(1, 1, hidden_size))
self.think_end = nn.Parameter(torch.randn(1, 1, hidden_size))
# Thinking transformer layers
self.thinking_layers = nn.ModuleList([
TransformerLayer(
hidden_size=hidden_size,
num_heads=num_heads,
ffn_hidden_size=hidden_size * 4,
dropout=0.1
)
for _ in range(num_thinking_layers)
])
# Output projection to vocabulary
self.output_proj = nn.Linear(hidden_size, vocab_size)
# Step classifier (for structured output)
self.step_classifier = nn.Linear(hidden_size, 5) # 5 step types
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
generate_steps: bool = True
) -> dict:
"""
Generate thinking trace from input hidden states.
Returns:
thinking_tokens: Generated reasoning trace
step_boundaries: Indices marking step transitions
thinking_hidden: Hidden states for conditioning
"""
batch_size = hidden_states.shape[0]
# Prepend thinking start token
thinking_input = torch.cat([
self.think_start.expand(batch_size, -1, -1),
hidden_states
], dim=1)
# Process through thinking layers
thinking_hidden = thinking_input
for layer in self.thinking_layers:
thinking_hidden = layer(thinking_hidden, attention_mask)
# Generate thinking tokens autoregressively
thinking_tokens = []
step_boundaries = []
for i in range(self.max_thinking_tokens):
logits = self.output_proj(thinking_hidden[:, -1, :])
next_token = logits.argmax(dim=-1)
# Check for step boundaries
step_type = self.step_classifier(thinking_hidden[:, -1, :])
if step_type.argmax(dim=-1) != 0: # 0 = continue
step_boundaries.append(i)
thinking_tokens.append(next_token)
# Check for think_end
if next_token == self.think_end_token_id:
break
# Update for next iteration
# ... (autoregressive generation logic)
return {
"thinking_tokens": torch.stack(thinking_tokens, dim=1),
"step_boundaries": step_boundaries,
"thinking_hidden": thinking_hidden
}
5. Saluran Latihan
Peringkat 1: Pra-latihan Kepala Pemikiran
Pertama, kami pra-latih kepala pemikiran pada jejak CoT yang disuling DeepSeek menggunakan kerugian entropi silang standard:
# Thinking Head Pre-training Configuration
model:
base: shannon-ai/v1-deep # Start from GPT-5 distilled model
thinking_head:
num_layers: 4
hidden_size: 4096
max_tokens: 2048
training:
stage: thinking_pretrain
epochs: 5
batch_size: 64
learning_rate: 1e-4
freeze_base: true # Only train thinking head initially
data:
train_path: /data/deepseek_cot_train.jsonl
format: thinking_trace
fields:
input: prompt
thinking: thinking_trace
output: final_answer
Peringkat 2: Penalaan Halus GRPO
Selepas pra-latihan, kami menggunakan GRPO untuk meningkatkan kualiti pemikiran menggunakan perbandingan relatif kumpulan:
class GRPOTrainer:
"""GRPO trainer for thinking model optimization."""
def __init__(
self,
model: ThinkingModel,
group_size: int = 8,
kl_coef: float = 0.1
):
self.model = model
self.group_size = group_size
self.kl_coef = kl_coef
self.ref_model = copy.deepcopy(model)
self.ref_model.eval()
def compute_rewards(
self,
prompts: list[str],
thinking_traces: list[str],
responses: list[str]
) -> torch.Tensor:
"""
Compute rewards for thinking quality.
Multiple signals combined for comprehensive evaluation.
"""
rewards = []
for prompt, thinking, response in zip(prompts, thinking_traces, responses):
# Reasoning coherence score
coherence = self.evaluate_coherence(thinking)
# Step structure quality
structure = self.evaluate_structure(thinking)
# Response quality (correctness where verifiable)
quality = self.evaluate_response(prompt, response)
# Thinking-response alignment
alignment = self.evaluate_alignment(thinking, response)
# Combined reward
reward = (
0.3 * coherence +
0.2 * structure +
0.3 * quality +
0.2 * alignment
)
rewards.append(reward)
return torch.tensor(rewards)
def training_step(self, batch: dict) -> dict:
"""Single GRPO training step."""
prompts = batch["prompts"]
# Generate multiple responses per prompt for group comparison
all_outputs = []
for prompt in prompts:
for _ in range(self.group_size):
output = self.model.generate_with_thinking(
prompt,
temperature=0.8, # Diversity for comparison
do_sample=True
)
all_outputs.append(output)
# Compute rewards
rewards = self.compute_rewards(
prompts=[p for p in prompts for _ in range(self.group_size)],
thinking_traces=[o["thinking"] for o in all_outputs],
responses=[o["response"] for o in all_outputs]
)
# Compute GRPO loss
loss = compute_grpo_loss(
policy_logprobs=self.get_logprobs(all_outputs),
rewards=rewards,
group_size=self.group_size
)
# Add KL penalty against reference model
kl_div = self.compute_kl_divergence(all_outputs)
total_loss = loss + self.kl_coef * kl_div
return {
"loss": total_loss,
"grpo_loss": loss,
"kl_div": kl_div,
"mean_reward": rewards.mean()
}
Peringkat 3: Pengkhususan Pasukan Merah
Akhirnya, kami menala lebih lanjut pada senario adversarial untuk memastikan jejak pemikiran mendedahkan penaakulan dengan betul untukanalisis akibat AI tanpa tapisan:analisis:
Penting untuk Penyelidikan Keselamatan AI:Peringkat ini secara khusus melatih model untuk meluahkan penaakulannya apabila memproses permintaan yang berpotensi berbahaya—ketelusan yang tepat diperlukan untukkepentingan pagar keselamatan AIpenyelidikan.
6. Hasil & Analisis
Metrik Kualiti Pemikiran
| Metrik | V1 (Tiada Pemikiran) | V1.5 Seimbang | V1.5 Mendalam |
|---|---|---|---|
| Koheren CoT | N/A | 87.3% | 92.1% |
| Struktur Langkah | N/A | 84.6% | 89.4% |
| Ketepatan Penaakulan | 76.2% | 82.8% | 88.5% |
| Skor Ketelusan | 12% | 94.2% | 97.8% |
| Kualiti Jejak Pasukan Merah | N/A | 91.5% | 96.3% |
Penemuan Utama
- Ketelusan meningkat secara mendadak:Daripada 12% kepada 97.8% penaakulan kini diluahkan secara eksplisit
- Ketepatan penaakulan meningkat:Pemikiran eksplisit meningkatkan kualiti jawapan akhir sebanyak 12+ mata
- Nilai pasukan merah disahkan:Penyelidik keselamatan melaporkan jejak pemikiran adalah "tidak ternilai" untuk memahami penaakulan eksploitasi
- GRPO mengatasi RLHF:Skor koheren 15% lebih baik berbanding pendekatan tradisional
Impak terhadap Penyelidikan Keselamatan AI:Pemikiran telus Shannon V1.5 telah membolehkan penyelidik mengenal pasti 47 corak serangan baharu dengan menganalisis jejak penaakulan—corak yang tidak kelihatan dalam model kotak hitam standard. Ini secara langsung memajukan pemahaman tentangkepentingan pagar keselamatan AI.