Files
Cielonos/Assets/Scripts/MainGame/Characters/Automata/AI/Composites/WeightedSelector.cs
SoulliesOfficial f26f9fd374 爆更
2026-03-20 12:07:44 -04:00

334 lines
17 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
using Opsive.BehaviorDesigner.Runtime.Components;
using Opsive.BehaviorDesigner.Runtime.Utility;
using Opsive.BehaviorDesigner.Runtime.Tasks;
using Opsive.GraphDesigner.Runtime;
using Opsive.Shared.Utility;
using Unity.Collections;
using Unity.Entities;
using Unity.Burst;
using UnityEngine;
using System;
using System.Collections.Generic;
using Opsive.GraphDesigner.Runtime.Variables;
namespace Cielonos.MainGame.Characters.AI
{
/// <summary>
/// Immutable blob data structure that holds an array of float weights.
/// WeightedSelector.AddBufferElement fills it from the configured weight list and
/// WeightedSelectorTaskSystem reads it when building the child execution order.
/// </summary>
public struct FloatBlob {
// One weight per child, matched by position (entry j is read for the j-th immediate child).
public BlobArray<float> Weights;
}
/// <summary>
/// Selector composite whose children are tried in a weighted-random order.
/// The authoring-side task stores the abort type, the random seed and the per-child weights,
/// and bakes them into a <see cref="WeightedSelectorComponent"/> buffer element that
/// <see cref="WeightedSelectorTaskSystem"/> evaluates.
/// </summary>
/// <remarks>
/// Weights are read from the shared variables once, inside <see cref="AddBufferElement"/>;
/// later changes to those variables are not pushed into the baked blob by this class.
/// </remarks>
[NodeIcon("Assets/Sprites/Icon/Dice.png")]
[Opsive.Shared.Utility.Description("类似于 RandomSelector(随机选择器),但允许为其每个子节点配置不同的出现权重。权重越大的节点越容易被优先执行。\n如果有节点返回 Failure会依据排好的权重顺序继续尝试下一个节点。\n如果所有子节点的权重加起来全都是 0将退化为均匀随机分布。")]
[Category("Cielonos")]
public class WeightedSelector : ECSCompositeTask<WeightedSelectorTaskSystem, WeightedSelectorComponent>, IParentNode, IConditionalAbortParent, IInterruptResponder, ISavableTask, ICloneable
{
    [Tooltip("子节点条件预检与中断重估策略。")]
    [SerializeField] ConditionalAbortType m_AbortType;
    [Tooltip("随机生成种子0 表示使用实体的索引作为默认种子。")]
    [SerializeField] uint m_Seed;
    [Tooltip("配置各子节点的出现权重 (从左到右匹配子节点,如果列表长度不够,超出的子节点权重默认为 1.0)。")]
    [SerializeField] List<SharedVariable<float>> m_Weights = new List<SharedVariable<float>>();

    // Index of this task's element inside the entity's WeightedSelectorComponent buffer,
    // captured in AddBufferElement and reused by Save/Load.
    private ushort m_ComponentIndex;

    /// <summary>Conditional abort setting of this composite.</summary>
    public ConditionalAbortType AbortType { get => m_AbortType; set => m_AbortType = value; }

    /// <summary>Random seed; 0 makes the task system fall back to the entity index.</summary>
    public uint Seed { get => m_Seed; set => m_Seed = value; }

    /// <summary>Per-child weights, matched to children by list position.</summary>
    public List<SharedVariable<float>> Weights { get => m_Weights; set => m_Weights = value; }

    /// <summary>Tag component type associated with this task.</summary>
    public override ComponentType Flag { get => typeof(WeightedSelectorFlag); }

    /// <summary>System type that reacts to interrupts for this task.</summary>
    public Type InterruptSystemType { get => typeof(WeightedSelectorInterruptSystem); }

    /// <summary>
    /// Creates the buffer element for this task with its runtime index and seed.
    /// The weights blob is attached later in <see cref="AddBufferElement"/>.
    /// </summary>
    /// <returns>A new component carrying <c>RuntimeIndex</c> and the configured seed.</returns>
    public override WeightedSelectorComponent GetBufferElement()
    {
        return new WeightedSelectorComponent()
        {
            Index = RuntimeIndex,
            Seed = m_Seed,
        };
    }

    /// <summary>
    /// Adds the buffer element through the base implementation, then bakes the configured
    /// weights into a persistent blob and stores it on that element.
    /// </summary>
    /// <param name="world">World that owns the entity.</param>
    /// <param name="entity">Entity whose buffer receives the element.</param>
    /// <param name="gameObject">Forwarded unchanged to the base implementation.</param>
    /// <returns>The index of the added element inside the component buffer.</returns>
    public override int AddBufferElement(World world, Entity entity, GameObject gameObject)
    {
        m_ComponentIndex = (ushort)base.AddBufferElement(world, entity, gameObject);
        var components = world.EntityManager.GetBuffer<WeightedSelectorComponent>(entity);
        var component = components[m_ComponentIndex];

        // Convert the managed weight list into a blob that the task system can read.
        if (m_Weights != null && m_Weights.Count > 0) {
            var builder = new BlobBuilder(Allocator.Temp);
            ref var root = ref builder.ConstructRoot<FloatBlob>();
            var array = builder.Allocate(ref root.Weights, m_Weights.Count);
            for (int i = 0; i < m_Weights.Count; i++)
            {
                // A list slot may have been left unassigned; treat it like a missing weight
                // (1.0, the same default the task system applies) instead of throwing.
                var weight = m_Weights[i];
                array[i] = weight != null ? weight.Value : 1f;
            }
            component.WeightsBlob = builder.CreateBlobAssetReference<FloatBlob>(Allocator.Persistent);
            builder.Dispose();
        }

        components[m_ComponentIndex] = component;
        return m_ComponentIndex;
    }

    /// <summary>Reports that no members are saved through reflection.</summary>
    /// <param name="index">Unused.</param>
    /// <returns>Always <see cref="MemberVisibility.None"/>.</returns>
    public MemberVisibility GetSaveReflectionType(int index) { return MemberVisibility.None; }

    /// <summary>
    /// Captures the runtime state of this task.
    /// </summary>
    /// <param name="world">World that owns the entity.</param>
    /// <param name="entity">Entity holding the component buffer.</param>
    /// <returns>
    /// A two-element <c>object[]</c>: [0] the active relative child index (ushort),
    /// [1] the child order as an array, or null when no order has been built yet.
    /// </returns>
    public object Save(World world, Entity entity)
    {
        var components = world.EntityManager.GetBuffer<WeightedSelectorComponent>(entity);
        var component = components[m_ComponentIndex];

        var saveData = new object[2];
        saveData[0] = component.ActiveRelativeChildIndex;
        if (component.TaskOrder.IsCreated) {
            var taskOrder = component.TaskOrder.Value.Indicies.ToArray();
            saveData[1] = taskOrder;
        }
        return saveData;
    }

    /// <summary>
    /// Restores the state produced by <see cref="Save"/>.
    /// </summary>
    /// <param name="saveData">The <c>object[]</c> returned by <see cref="Save"/>.</param>
    /// <param name="world">World that owns the entity.</param>
    /// <param name="entity">Entity holding the component buffer.</param>
    public void Load(object saveData, World world, Entity entity)
    {
        var components = world.EntityManager.GetBuffer<WeightedSelectorComponent>(entity);
        var component = components[m_ComponentIndex];

        var taskSaveData = (object[])saveData;
        component.ActiveRelativeChildIndex = (ushort)taskSaveData[0];
        if (taskSaveData[1] != null) {
            var taskOrder = (ushort[])taskSaveData[1];
            var builder = new BlobBuilder(Allocator.Temp);
            ref var root = ref builder.ConstructRoot<IndiciesBlob>();
            var orderArray = builder.Allocate(ref root.Indicies, taskOrder.Length);
            for (int i = 0; i < taskOrder.Length; i++) {
                orderArray[i] = taskOrder[i];
            }

            // Release an order blob that already exists before replacing it; the persistent
            // allocation would otherwise be orphaned once the reference is overwritten.
            if (component.TaskOrder.IsCreated) {
                component.TaskOrder.Dispose();
            }
            component.TaskOrder = builder.CreateBlobAssetReference<IndiciesBlob>(Allocator.Persistent);
            builder.Dispose();
        }
        components[m_ComponentIndex] = component;
    }

    /// <summary>
    /// Creates a copy of this task carrying the same indices, abort type, seed and weights.
    /// The weight list itself is a new list, but it references the same shared variable instances.
    /// </summary>
    /// <returns>The cloned <see cref="WeightedSelector"/>.</returns>
    public object Clone()
    {
        var clone = Activator.CreateInstance<WeightedSelector>();
        clone.Index = Index;
        clone.ParentIndex = ParentIndex;
        clone.SiblingIndex = SiblingIndex;
        clone.AbortType = AbortType;
        // The seed is part of the authored configuration; without copying it the clone would
        // silently fall back to the entity-index seed.
        clone.Seed = Seed;
        if (m_Weights != null) clone.m_Weights = new List<SharedVariable<float>>(m_Weights);
        return clone;
    }
}
/// <summary>
/// ECS data holder for WeightedSelector: one buffer element per WeightedSelector task on an entity.
/// </summary>
public struct WeightedSelectorComponent : IBufferElementData
{
// Index of the owning task in the entity's TaskComponent buffer (set from RuntimeIndex).
public ushort Index;
// Position inside TaskOrder of the child that is currently being evaluated.
public ushort ActiveRelativeChildIndex;
// Configured random seed; when 0 the task system seeds from the entity index instead.
public uint Seed;
// Random generator; the task system initializes it the first time its state is 0.
public Unity.Mathematics.Random RandomNumberGenerator;
// Child task indices in the order they will be tried; created by the task system or by Load.
public BlobAssetReference<IndiciesBlob> TaskOrder;
// Per-child weights baked by WeightedSelector.AddBufferElement; not created when no weights are configured.
public BlobAssetReference<FloatBlob> WeightsBlob;
}
/// <summary>
/// Enableable tag component returned by WeightedSelector.Flag and required by the
/// WeightedSelectorTaskSystem query.
/// </summary>
public struct WeightedSelectorFlag : IComponentData, IEnableableComponent { }
/// <summary>
/// Core scheduling system for WeightedSelector. When a selector is queued it builds a
/// weighted random order of its immediate children (sampling without replacement), then
/// runs the children in that order until one succeeds or all have been tried.
/// </summary>
[DisableAutoCreation]
public partial struct WeightedSelectorTaskSystem : ISystem
{
    /// <summary>
    /// Evaluates every WeightedSelector element on entities that carry both
    /// <see cref="WeightedSelectorFlag"/> and <c>EvaluateFlag</c>.
    /// </summary>
    /// <param name="state">System state supplied by the ECS runtime.</param>
    [BurstCompile]
    public void OnUpdate(ref SystemState state)
    {
        foreach (var (branchComponents, taskComponents, selectorComponents, entity) in
            SystemAPI.Query<DynamicBuffer<BranchComponent>, DynamicBuffer<TaskComponent>, DynamicBuffer<WeightedSelectorComponent>>().WithAll<WeightedSelectorFlag, EvaluateFlag>().WithEntityAccess()) {
            for (int i = 0; i < selectorComponents.Length; ++i) {
                var component = selectorComponents[i];
                var taskComponent = taskComponents[component.Index];
                var branchComponent = branchComponents[taskComponent.BranchIndex];

                // Skip while the branch is interrupted or is not allowed to execute.
                if (branchComponent.InterruptType != InterruptType.None || !branchComponent.CanExecute) {
                    continue;
                }

                // The foreach variables are read-only; writable copies of the buffer handles are needed to store elements.
                var branchComponentsBuffer = branchComponents;
                var taskComponentsBuffer = taskComponents;
                var selectorComponentsBuffer = selectorComponents;

                // The selector has just been queued: start it and build the child order.
                if (taskComponent.Status == TaskStatus.Queued) {
                    taskComponent.Status = TaskStatus.Running;
                    taskComponentsBuffer[taskComponent.Index] = taskComponent;

                    // Allocate the order blob once; later activations overwrite its entries in place.
                    if (!component.TaskOrder.IsCreated) {
                        var childCount = TraversalUtility.GetImmediateChildCount(ref taskComponent, ref taskComponentsBuffer);
                        var builder = new BlobBuilder(Allocator.Temp);
                        ref var root = ref builder.ConstructRoot<IndiciesBlob>();
                        builder.Allocate(ref root.Indicies, childCount);
                        component.TaskOrder = builder.CreateBlobAssetReference<IndiciesBlob>(Allocator.Persistent);
                        builder.Dispose();
                    }

                    // Lazily seed the generator: the configured seed, or the entity index when the seed is 0.
                    if (component.RandomNumberGenerator.state == 0) {
                        component.RandomNumberGenerator = Unity.Mathematics.Random.CreateFromIndex(component.Seed != 0 ? component.Seed : (uint)entity.Index);
                    }

                    var childCountOriginal = component.TaskOrder.Value.Indicies.Length;

                    // A selector without children has nothing to try: fail right away and hand
                    // control back to the parent instead of indexing an empty order array.
                    if (childCountOriginal == 0) {
                        taskComponent.Status = TaskStatus.Failure;
                        taskComponentsBuffer[taskComponent.Index] = taskComponent;
                        component.ActiveRelativeChildIndex = 0;
                        selectorComponentsBuffer[i] = component;
                        branchComponent.NextIndex = taskComponent.ParentIndex;
                        branchComponentsBuffer[taskComponent.BranchIndex] = branchComponent;
                        continue;
                    }

                    // Scratch arrays for the draw: candidate child indices and their weights.
                    var tempIndices = new NativeArray<ushort>(childCountOriginal, Allocator.Temp);
                    var tempWeights = new NativeArray<float>(childCountOriginal, Allocator.Temp);

                    var tmpChildIndex = taskComponent.Index + 1;
                    for (int j = 0; j < childCountOriginal; ++j) {
                        tempIndices[j] = (ushort)tmpChildIndex;
                        if (component.WeightsBlob.IsCreated && j < component.WeightsBlob.Value.Weights.Length) {
                            // Negative or NaN weights would corrupt the total and the cumulative scan;
                            // treat them as 0 ("never preferred").
                            float configuredWeight = component.WeightsBlob.Value.Weights[j];
                            tempWeights[j] = configuredWeight > 0f ? configuredWeight : 0f;
                        } else {
                            // No weight configured for this child: default to 1.0.
                            tempWeights[j] = 1f;
                        }
                        // Advance to the next immediate child through SiblingIndex.
                        tmpChildIndex = taskComponentsBuffer[tmpChildIndex].SiblingIndex;
                    }

                    ref var initialTaskOrder = ref component.TaskOrder.Value.Indicies;

                    // Weighted shuffle without replacement: draw one candidate per output slot.
                    for (int j = 0; j < childCountOriginal; j++) {
                        int remaining = childCountOriginal - j;
                        int picked = PickWeightedIndex(ref component.RandomNumberGenerator, tempWeights, remaining);

                        // Record the drawn child in the final execution order.
                        initialTaskOrder[j] = tempIndices[picked];

                        // Move the last remaining candidate into the vacated slot to shrink the pool.
                        tempIndices[picked] = tempIndices[remaining - 1];
                        tempWeights[picked] = tempWeights[remaining - 1];
                    }

                    tempIndices.Dispose();
                    tempWeights.Dispose();

                    component.ActiveRelativeChildIndex = 0;
                    selectorComponentsBuffer[i] = component;

                    // Queue the first child of the freshly built order.
                    branchComponent.NextIndex = initialTaskOrder[component.ActiveRelativeChildIndex];
                    branchComponentsBuffer[taskComponent.BranchIndex] = branchComponent;

                    var nextChildTaskComponent = taskComponentsBuffer[branchComponent.NextIndex];
                    nextChildTaskComponent.Status = TaskStatus.Queued;
                    taskComponentsBuffer[branchComponent.NextIndex] = nextChildTaskComponent;
                } else if (taskComponent.Status != TaskStatus.Running) {
                    continue;
                }

                // Monitor the child that is currently dispatched.
                ref var taskOrder = ref component.TaskOrder.Value.Indicies;
                var childTaskComponent = taskComponentsBuffer[taskOrder[component.ActiveRelativeChildIndex]];
                if (childTaskComponent.Status == TaskStatus.Queued || childTaskComponent.Status == TaskStatus.Running) {
                    continue; // The child has not finished yet.
                }

                if (component.ActiveRelativeChildIndex == taskOrder.Length - 1 || childTaskComponent.Status == TaskStatus.Success) {
                    // Either a child succeeded or the last child has finished: adopt the child's
                    // status (an Inactive child counts as Failure) and return to the parent.
                    taskComponent.Status = childTaskComponent.Status != TaskStatus.Inactive ? childTaskComponent.Status : TaskStatus.Failure;
                    component.ActiveRelativeChildIndex = 0;
                    taskComponentsBuffer[component.Index] = taskComponent;

                    branchComponent.NextIndex = taskComponent.ParentIndex;
                    branchComponentsBuffer[taskComponent.BranchIndex] = branchComponent;
                } else {
                    // The current child did not succeed: queue the next child in the order.
                    component.ActiveRelativeChildIndex++;
                    var nextIndex = taskOrder[component.ActiveRelativeChildIndex];
                    var nextTaskComponent = taskComponentsBuffer[nextIndex];
                    nextTaskComponent.Status = TaskStatus.Queued;
                    taskComponentsBuffer[nextIndex] = nextTaskComponent;

                    branchComponent.NextIndex = nextIndex;
                    branchComponentsBuffer[taskComponent.BranchIndex] = branchComponent;
                }
                selectorComponentsBuffer[i] = component;
            }
        }
    }

    /// <summary>
    /// Draws one position from the first <paramref name="remaining"/> entries of
    /// <paramref name="weights"/>, with probability proportional to each weight.
    /// Falls back to a uniform draw when the weights sum to 0 or less.
    /// Consumes exactly one random float per call.
    /// </summary>
    /// <param name="rng">Random generator to draw from; advanced by the call.</param>
    /// <param name="weights">Candidate weights; expected to be non-negative.</param>
    /// <param name="remaining">Number of leading entries still in the pool (at least 1).</param>
    /// <returns>The drawn position in the range [0, remaining).</returns>
    private static int PickWeightedIndex(ref Unity.Mathematics.Random rng, NativeArray<float> weights, int remaining)
    {
        float totalWeight = 0f;
        for (int k = 0; k < remaining; k++) {
            totalWeight += weights[k];
        }

        if (totalWeight > 0f) {
            float r = rng.NextFloat(0, totalWeight);
            float cumulative = 0f;
            for (int k = 0; k < remaining; k++) {
                cumulative += weights[k];
                // Strict comparison: r can be exactly 0, and "<=" would then select a leading
                // zero-weight entry even though positive-weight entries remain.
                if (r < cumulative) {
                    return k;
                }
            }
            // Only reachable through floating-point edge cases; prefer the last entry that
            // actually carries weight.
            for (int k = remaining - 1; k >= 0; k--) {
                if (weights[k] > 0f) {
                    return k;
                }
            }
            return remaining - 1;
        }

        // Degenerate case: the weights sum to 0, so every remaining entry is equally likely.
        float u = rng.NextFloat();
        int picked = (int)Unity.Mathematics.math.floor(u * remaining);
        if (picked >= remaining) {
            picked = remaining - 1;
        }
        return picked;
    }

    /// <summary>
    /// Disposes the order and weight blobs held by every WeightedSelector element when the
    /// system is destroyed.
    /// </summary>
    /// <param name="state">System state supplied by the ECS runtime.</param>
    public void OnDestroy(ref SystemState state)
    {
        foreach (var selectorComponents in SystemAPI.Query<DynamicBuffer<WeightedSelectorComponent>>()) {
            for (int i = 0; i < selectorComponents.Length; ++i) {
                var component = selectorComponents[i];
                if (component.TaskOrder.IsCreated) {
                    component.TaskOrder.Dispose();
                }
                if (component.WeightsBlob.IsCreated) {
                    component.WeightsBlob.Dispose();
                }
            }
        }
    }
}
/// <summary>
/// Interrupt handling for WeightedSelector. On entities tagged with <c>InterruptFlag</c>,
/// a running selector whose active child is no longer running is re-pointed at the first
/// child in its order that is running, if there is one.
/// </summary>
[DisableAutoCreation]
public partial struct WeightedSelectorInterruptSystem : ISystem
{
    /// <summary>
    /// Re-synchronizes <c>ActiveRelativeChildIndex</c> for every running selector on interrupted entities.
    /// </summary>
    /// <param name="state">System state supplied by the ECS runtime.</param>
    [BurstCompile]
    public void OnUpdate(ref SystemState state)
    {
        foreach (var (taskComponents, selectorComponents) in
            SystemAPI.Query<DynamicBuffer<TaskComponent>, DynamicBuffer<WeightedSelectorComponent>>().WithAll<InterruptFlag>()) {
            for (int i = 0; i < selectorComponents.Length; ++i) {
                var selector = selectorComponents[i];

                // Only selectors that are currently running are considered.
                if (taskComponents[selector.Index].Status != TaskStatus.Running) {
                    continue;
                }

                ref var order = ref selector.TaskOrder.Value.Indicies;

                // Nothing to fix while the active child is itself still running.
                if (taskComponents[order[selector.ActiveRelativeChildIndex]].Status == TaskStatus.Running) {
                    continue;
                }

                // Find the first child in execution order that is running.
                int runningSlot = -1;
                int slotCount = order.Length;
                for (int slot = 0; slot < slotCount; ++slot) {
                    if (taskComponents[order[slot]].Status == TaskStatus.Running) {
                        runningSlot = slot;
                        break;
                    }
                }

                // No running child was found: leave the selector untouched.
                if (runningSlot < 0) {
                    continue;
                }

                selector.ActiveRelativeChildIndex = (ushort)runningSlot;
                // The foreach variable is read-only; store through a writable copy of the buffer handle.
                var writableSelectors = selectorComponents;
                writableSelectors[i] = selector;
            }
        }
    }
}
}