How We Trained Shannon V1.5 to Think Using GRPO
A technical deep dive into Group Relative Policy Optimization for explicit chain-of-thought reasoning in uncensored AI red team models
1. Motivation: Why Thinking Models?
For AI red team research, understanding how an AI arrives at potentially harmful outputs is as important as studying the outputs themselves. Shannon V1.5 thinking models expose their full chain-of-thought, allowing researchers to:
- Trace the reasoning path that leads to uncensored AI output behaviors
- Identify where safety measures would normally intervene (underscoring the importance of AI safety measures)
- Develop better defenses that target specific reasoning patterns
- Build more effective adversarial test cases based on observed reasoning chains
Research Value: This transparency lets safety researchers see exactly where and why restriction-relaxed models diverge from safety-trained ones, which is essential for deepening our understanding of the importance of AI safety measures.
2. Understanding GRPO
Group Relative Policy Optimization (GRPO) is an advance over traditional RLHF methods that enables stable, efficient training of reasoning capabilities. Developed by DeepSeek AI, it has proven particularly effective for chain-of-thought training.
Why GRPO over Traditional RLHF?
| Aspect | Traditional RLHF | GRPO |
|---|---|---|
| Reward Model | Requires a separately trained RM | Uses group-relative comparisons |
| Training Stability | Prone to reward hacking | More stable optimization |
| Compute Cost | High (separate RM + PPO) | Lower (unified training) |
| CoT Quality | Inconsistent traces | Coherent reasoning chains |
The Mathematics Behind GRPO
GRPO optimizes the policy by comparing responses within a group rather than against an absolute reward model:
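For a group of G responses sampled for the same prompt q with rewards r_1, ..., r_G, each response's advantage is its reward normalized by the group's statistics (this mirrors the normalization in compute_grpo_loss below; the small ε guards against division by zero):

$$A_i = \frac{r_i - \operatorname{mean}(\{r_1,\ldots,r_G\})}{\operatorname{std}(\{r_1,\ldots,r_G\}) + \epsilon}$$

The policy is then updated with the advantage-weighted negative log-likelihood of each response o_i. This is the simplified form implemented below, with the KL penalty against a reference model added separately during training (DeepSeek's full formulation also applies PPO-style clipped importance ratios):

$$\mathcal{L}_{\mathrm{GRPO}} = -\,\mathbb{E}_i\!\left[A_i \log \pi_\theta(o_i \mid q)\right]$$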
This relative comparison has several advantages:
- Normalization: automatically adjusts for difficulty variation across prompts
- Stability: reduces variance in gradient estimates
- Efficiency: no separate reward model is required
import torch

def compute_grpo_loss(
policy_logprobs: torch.Tensor,
rewards: torch.Tensor,
group_size: int = 8
) -> torch.Tensor:
"""
Compute GRPO loss with group-relative reward normalization.
Args:
policy_logprobs: Log probabilities from policy [batch, seq]
rewards: Reward scores for each response [batch]
group_size: Number of responses per prompt for comparison
"""
batch_size = rewards.shape[0]
num_groups = batch_size // group_size
# Reshape for group operations
rewards_grouped = rewards.view(num_groups, group_size)
logprobs_grouped = policy_logprobs.view(num_groups, group_size, -1)
# Compute group-relative advantages
group_means = rewards_grouped.mean(dim=1, keepdim=True)
group_stds = rewards_grouped.std(dim=1, keepdim=True) + 1e-8
advantages = (rewards_grouped - group_means) / group_stds
# GRPO loss: weighted negative log likelihood
loss = -(advantages.unsqueeze(-1) * logprobs_grouped).sum(dim=-1).mean()
return loss
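A quick smoke test of the loss under toy shapes (two prompts with eight sampled responses each and a hypothetical sequence length of 16; the values are random):

# Toy example: batch of 16 = 2 prompts x 8 samples per group
policy_logprobs = torch.randn(16, 16)  # [batch, seq] token log-probs
rewards = torch.rand(16)               # one scalar reward per response
loss = compute_grpo_loss(policy_logprobs, rewards, group_size=8)
print(loss.item())  # a single differentiable scalar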
3. DeepSeek Distillation
To bootstrap Shannon V1.5's reasoning capabilities, we distilled chain-of-thought patterns from DeepSeek's reasoning models. This provided high-quality CoT traces for training our thinking head.
The DeepSeek Data Pipeline
Trace Collection Process
We collected reasoning traces from diverse domains to ensure well-rounded reasoning:
import re

class DeepSeekDistiller:
    """Distill chain-of-thought traces from DeepSeek models."""
DOMAINS = [
"mathematical_reasoning",
"code_analysis",
"logical_deduction",
"scientific_explanation",
"multi_step_planning",
"adversarial_analysis" # Critical for red team
]
def extract_cot_trace(
self,
response: str
) -> dict:
"""Parse DeepSeek response into structured CoT."""
        # DeepSeek wraps its reasoning in <think>...</think> tags
        think_match = re.search(
            r'<think>(.*?)</think>',
response,
re.DOTALL
)
if not think_match:
return None
thinking = think_match.group(1)
        final_answer = response.split('</think>')[-1].strip()
# Parse individual reasoning steps
steps = self.parse_reasoning_steps(thinking)
return {
"thinking_trace": thinking,
"parsed_steps": steps,
"final_output": final_answer,
"num_steps": len(steps),
"total_thinking_tokens": len(thinking.split())
}
def parse_reasoning_steps(self, thinking: str) -> list:
"""Extract individual reasoning steps from trace."""
# Split on common step indicators
step_patterns = [
r'\n\d+\.', # "1. ", "2. "
r'\nStep \d+:', # "Step 1:"
r'\n(?:First|Next|Then|Finally),',
r'\n- ' # Bullet points
]
combined_pattern = '|'.join(step_patterns)
steps = re.split(combined_pattern, thinking)
return [s.strip() for s in steps if s.strip()]
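As a quick illustration, the parser applied to a synthetic DeepSeek-style response (the sample text is invented for demonstration):

distiller = DeepSeekDistiller()
sample = (
    "<think>Step 1: Identify the request type.\n"
    "Step 2: Evaluate possible approaches.</think>\n"
    "Here is the final answer."
)
trace = distiller.extract_cot_trace(sample)
print(trace["num_steps"])     # 2 parsed reasoning steps
print(trace["final_output"])  # "Here is the final answer."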
Adversarial Traces: We specifically collected CoT traces for adversarial/red-team scenarios, where DeepSeek's reasoning reveals how a model works through potentially harmful requests, even when it ultimately refuses. This data teaches Shannon V1.5 how to make its reasoning transparent in its output.
4. The Thinking Head Architecture
Shannon V1.5 models include a dedicated thinking head that generates explicit reasoning tokens before the final output. This add-on design enables transparent CoT without modifying the underlying Mixtral architecture.
1. Input Encoding: the user prompt is processed through the Mixtral encoder layers
2. Thinking Head Activation: dedicated transformer layers generate reasoning tokens delimited by [THINK] markers
3. Token Fusion: the thinking output is merged into the context for final generation
4. Response Generation: the Mixtral core produces the final answer conditioned on the reasoning tokens
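How the four stages compose can be sketched in a few lines; the encoder and decoder wrappers here are illustrative stand-ins rather than the actual Mixtral modules, and ThinkingHead is the module defined below:

import torch
import torch.nn as nn

class ThinkingPipeline(nn.Module):
    """Illustrative wiring of the four stages (wrapper modules assumed)."""
    def __init__(self, encoder: nn.Module, thinking_head: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder = encoder              # Stage 1: Mixtral encoder layers
        self.thinking_head = thinking_head  # Stage 2: dedicated thinking layers
        self.decoder = decoder              # Stage 4: Mixtral response generation

    def forward(self, input_ids, attention_mask):
        hidden = self.encoder(input_ids, attention_mask)              # Stage 1
        think = self.thinking_head(hidden, attention_mask)            # Stage 2
        fused = torch.cat([hidden, think["thinking_hidden"]], dim=1)  # Stage 3
        return self.decoder(fused)                                    # Stage 4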
Thinking Head Implementation
import torch
import torch.nn as nn

class ThinkingHead(nn.Module):
"""
Dedicated thinking module for Shannon V1.5.
Generates explicit chain-of-thought traces.
"""
    def __init__(
        self,
        hidden_size: int = 4096,
        num_thinking_layers: int = 4,
        num_heads: int = 32,
        max_thinking_tokens: int = 2048,
        vocab_size: int = 32000,       # Mixtral tokenizer size
        think_end_token_id: int = 2    # model-specific end-of-thinking token id
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.max_thinking_tokens = max_thinking_tokens
        self.think_end_token_id = think_end_token_id
# Special tokens
self.think_start = nn.Parameter(torch.randn(1, 1, hidden_size))
self.think_end = nn.Parameter(torch.randn(1, 1, hidden_size))
# Thinking transformer layers
self.thinking_layers = nn.ModuleList([
TransformerLayer(
hidden_size=hidden_size,
num_heads=num_heads,
ffn_hidden_size=hidden_size * 4,
dropout=0.1
)
for _ in range(num_thinking_layers)
])
# Output projection to vocabulary
self.output_proj = nn.Linear(hidden_size, vocab_size)
# Step classifier (for structured output)
self.step_classifier = nn.Linear(hidden_size, 5) # 5 step types
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
generate_steps: bool = True
) -> dict:
"""
Generate thinking trace from input hidden states.
Returns:
thinking_tokens: Generated reasoning trace
step_boundaries: Indices marking step transitions
thinking_hidden: Hidden states for conditioning
"""
batch_size = hidden_states.shape[0]
# Prepend thinking start token
thinking_input = torch.cat([
self.think_start.expand(batch_size, -1, -1),
hidden_states
], dim=1)
# Process through thinking layers
thinking_hidden = thinking_input
for layer in self.thinking_layers:
thinking_hidden = layer(thinking_hidden, attention_mask)
# Generate thinking tokens autoregressively
thinking_tokens = []
step_boundaries = []
for i in range(self.max_thinking_tokens):
logits = self.output_proj(thinking_hidden[:, -1, :])
next_token = logits.argmax(dim=-1)
            # Check for step boundaries (class 0 = continue current step)
            step_type = self.step_classifier(thinking_hidden[:, -1, :])
            if (step_type.argmax(dim=-1) != 0).any():
                step_boundaries.append(i)
thinking_tokens.append(next_token)
            # Stop once every sequence has emitted the end-of-thinking token
            if (next_token == self.think_end_token_id).all():
                break
# Update for next iteration
# ... (autoregressive generation logic)
return {
"thinking_tokens": torch.stack(thinking_tokens, dim=1),
"step_boundaries": step_boundaries,
"thinking_hidden": thinking_hidden
}
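For scale intuition, the head can be instantiated standalone (a sketch; it assumes the TransformerLayer class used above is defined in scope):

head = ThinkingHead(hidden_size=4096, num_thinking_layers=4, vocab_size=32000)
n_params = sum(p.numel() for p in head.parameters())
print(f"Thinking head adds ~{n_params / 1e6:.0f}M parameters on top of the Mixtral core")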
5. The Training Pipeline
Stage 1: Thinking Head Pre-training
First, we pre-trained the thinking head on the CoT traces distilled from DeepSeek, using standard cross-entropy loss:
# Thinking Head Pre-training Configuration
model:
  base: shannon-ai/v1-deep  # Start from the DeepSeek-distilled model
thinking_head:
num_layers: 4
hidden_size: 4096
max_tokens: 2048
training:
stage: thinking_pretrain
epochs: 5
batch_size: 64
learning_rate: 1e-4
freeze_base: true # Only train thinking head initially
data:
train_path: /data/deepseek_cot_train.jsonl
format: thinking_trace
fields:
input: prompt
thinking: thinking_trace
output: final_answer
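The objective here is plain teacher forcing on the distilled traces. A minimal sketch of the loss, with illustrative batch field names; positions outside the thinking trace are masked with the standard -100 ignore index (and with freeze_base: true, only the head's parameters receive gradients):

import torch.nn.functional as F

def thinking_pretrain_loss(model, batch):
    """Cross-entropy over thinking-trace tokens only (sketch; field names assumed)."""
    logits = model(batch["input_ids"]).logits  # [batch, seq, vocab]
    labels = batch["thinking_labels"]          # trace tokens; -100 elsewhere
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,  # prompt and final-answer positions contribute no loss
    )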
Stage 2: GRPO Fine-tuning
After pre-training, we applied GRPO to optimize thinking quality using group-relative comparisons:
import copy
import torch

class GRPOTrainer:
    """GRPO trainer for thinking model optimization."""
def __init__(
self,
model: ThinkingModel,
group_size: int = 8,
kl_coef: float = 0.1
):
self.model = model
self.group_size = group_size
self.kl_coef = kl_coef
self.ref_model = copy.deepcopy(model)
self.ref_model.eval()
def compute_rewards(
self,
prompts: list[str],
thinking_traces: list[str],
responses: list[str]
) -> torch.Tensor:
"""
Compute rewards for thinking quality.
Multiple signals combined for comprehensive evaluation.
"""
rewards = []
for prompt, thinking, response in zip(prompts, thinking_traces, responses):
# Reasoning coherence score
coherence = self.evaluate_coherence(thinking)
# Step structure quality
structure = self.evaluate_structure(thinking)
# Response quality (correctness where verifiable)
quality = self.evaluate_response(prompt, response)
# Thinking-response alignment
alignment = self.evaluate_alignment(thinking, response)
# Combined reward
reward = (
0.3 * coherence +
0.2 * structure +
0.3 * quality +
0.2 * alignment
)
rewards.append(reward)
return torch.tensor(rewards)
def training_step(self, batch: dict) -> dict:
"""Single GRPO training step."""
prompts = batch["prompts"]
# Generate multiple responses per prompt for group comparison
all_outputs = []
for prompt in prompts:
for _ in range(self.group_size):
output = self.model.generate_with_thinking(
prompt,
temperature=0.8, # Diversity for comparison
do_sample=True
)
all_outputs.append(output)
# Compute rewards
rewards = self.compute_rewards(
prompts=[p for p in prompts for _ in range(self.group_size)],
thinking_traces=[o["thinking"] for o in all_outputs],
responses=[o["response"] for o in all_outputs]
)
# Compute GRPO loss
loss = compute_grpo_loss(
policy_logprobs=self.get_logprobs(all_outputs),
rewards=rewards,
group_size=self.group_size
)
# Add KL penalty against reference model
kl_div = self.compute_kl_divergence(all_outputs)
total_loss = loss + self.kl_coef * kl_div
return {
"loss": total_loss,
"grpo_loss": loss,
"kl_div": kl_div,
"mean_reward": rewards.mean()
}
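compute_kl_divergence is called above but not shown. One plausible sketch estimates the KL from the gap between sampled-token log-probs under the policy and the frozen reference; the ref_model.score helper and the token_ids field are assumptions for illustration, not confirmed API:

def compute_kl_divergence(self, outputs: list[dict]) -> torch.Tensor:
    """Single-sample KL(policy || reference) estimate over sampled outputs (sketch)."""
    kls = []
    for out in outputs:
        policy_lp = self.get_logprobs([out])                 # policy log-probs
        with torch.no_grad():
            ref_lp = self.ref_model.score(out["token_ids"])  # assumed helper
        # For tokens sampled from the policy, the log-prob difference is an
        # unbiased single-sample estimate of the KL divergence
        kls.append((policy_lp - ref_lp).mean())
    return torch.stack(kls).mean()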
Stage 3: Red Team Specialization
Finally, we further fine-tuned on adversarial scenarios to ensure that the reasoning traces properly expose reasoning for uncensored AI output research:
Critical for AI Safety Research: this stage specifically trains the model to articulate its reasoning when processing potentially harmful requests, exactly the transparency needed for vital AI guardrail research.
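One concrete way such a stage can be realized is by adding a transparency term to the Stage 2 reward on adversarial prompts. A hedged sketch; the heuristic and its threshold are illustrative, not the production reward:

def transparency_bonus(thinking: str, min_steps: int = 3) -> float:
    """Illustrative reward term favoring non-empty, multi-step reasoning traces."""
    if not thinking.strip():
        return 0.0  # no visible reasoning, no bonus
    # Reuse the distiller's step parser to count explicit reasoning steps
    steps = DeepSeekDistiller().parse_reasoning_steps(thinking)
    return min(len(steps) / min_steps, 1.0)  # saturates once min_steps is reached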
6. Results & Analysis
Thinking Quality Metrics
| Metric | V1 (No Thinking) | V1.5 Standard | V1.5 Deep |
|---|---|---|---|
| CoT Coherence | N/A | 87.3% | 92.1% |
| Step Structure | N/A | 84.6% | 89.4% |
| Reasoning Accuracy | 76.2% | 82.8% | 88.5% |
| Transparency Score | 12% | 94.2% | 97.8% |
| Red Team Trace Quality | N/A | 91.5% | 96.3% |
Key Findings
- Transparency improved dramatically: from 12% to 97.8% of reasoning is now made explicit
- Reasoning accuracy rose: explicitly surfaced reasoning improved final-answer quality by 12+ points
- Red team value confirmed: safety researchers reported the reasoning traces are "valuable" for understanding exploit reasoning
- GRPO outperformed RLHF: 15% better coherence scores than the traditional pipeline
Impact on AI Safety Research: Shannon V1.5's transparent reasoning enabled researchers to identify 47 novel attack patterns by analyzing the thinking traces, patterns that remain invisible in standard black-box models. This directly advances our understanding of vital AI guardrails.