// Paul Bird
// Upload RunAutomata.cs
// 72c0bcf verified
// raw
// history blame
// 9.76 kB
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.Sentis;
using System.IO;
using Lays = Unity.Sentis.Layers;
/*
* Neural Cellular Automata Inference Code
* =======================================
*
* Put this script on the Main Camera
* Create an image or quad in the scene.
* Assign an unlit transparent material to the image/quad.
* Drag the same material into the outputMaterial field.
* Add the *.sentis files to the Assets/StreamingAssets folder
*
*/
/// <summary>
/// Runs a Neural Cellular Automata model with Unity Sentis once per frame and renders the
/// visible part of the automaton state into <see cref="outputMaterial"/>.
/// </summary>
/// <remarks>
/// Each frame two workers run: the loaded <c>*.sentis</c> model predicts a state update, and a
/// small graph built in <see cref="CreateProcessingModel"/> applies that update stochastically,
/// zeroes dead cells and extracts the output image (plus an optional block-averaged alpha image).
/// </remarks>
public class RunAutomata : MonoBehaviour
{
    //Change this to load a different model:
    public AutomataNames automataName = AutomataNames.Poop;

    //Reduce this to make it run slower
    [Range(0f, 1f)]
    public float stepSize = 1.0f;

    const BackendType backend = BackendType.GPUCompute;

    //Drag your unlit transparent material here for drawing the output
    public Material outputMaterial;

    //optional material for average alpha
    public Material avgAlphaMaterial;

    public enum AutomataNames { Lizard, Turtle, Poop };

    //Model parameters
    // Side length (in cells) of the visible image; the output image is cropped to this size.
    const int trainedResolution = 40;
    // Border (in cells) added on every side of the visible image; cropped off before display.
    const int trainedPool = 16;
    // Blocks per side of the averaged-alpha output (block size = trainedResolution / alphaBlocks).
    const int alphaBlocks = 4;

    // trainedResolution + 2 * trainedPool.
    int m_paddedImageSize;
    // Number of state channels, read from the loaded model's input shape.
    int m_trainedHiddenStates;

    //Workers to run the networks
    private IWorker m_WorkerStateUpdate;
    private IWorker m_WorkerClip;

    // Current automaton state with shape (1, m_trainedHiddenStates, padded size, padded size).
    private TensorFloat m_currentStateTensor;
    private RenderTexture m_currentStateTexture;
    private RenderTexture m_currentBlockAlphaStateTexture;

    Ops m_ops;
    ITensorAllocator m_allocator;

    void Start()
    {
        m_allocator = new TensorCachingAllocator();
        m_ops = WorkerFactory.CreateOps(backend, m_allocator);
        Application.targetFrameRate = 60;

        LoadAutomataModel();
        CreateProcessingModel();
        SetupState();
        SetupTextures();

        // Seed a single cell in the middle of the padded grid.
        DrawDotAt(m_paddedImageSize / 2, m_paddedImageSize / 2);
    }

    /// <summary>
    /// Loads the state-update model selected by <see cref="automataName"/> from StreamingAssets,
    /// reads its state channel count and creates its worker.
    /// </summary>
    /// <exception cref="System.ArgumentOutOfRangeException">
    /// <see cref="automataName"/> is not a known <see cref="AutomataNames"/> value.
    /// </exception>
    void LoadAutomataModel()
    {
        string fileName = automataName switch
        {
            AutomataNames.Lizard => "lizard.sentis",
            AutomataNames.Turtle => "turtle.sentis",
            AutomataNames.Poop => "poop.sentis",
            _ => throw new System.ArgumentOutOfRangeException(nameof(automataName), automataName, "Unknown automata model."),
        };
        Model stateUpdateModel = ModelLoader.Load(Application.streamingAssetsPath + "/" + fileName);

        // The update model is fed channels-last data (see the transposes in DoInference), so the
        // state channel count is the last input dimension.
        m_trainedHiddenStates = stateUpdateModel.inputs[0].shape[3].value;
        m_paddedImageSize = trainedResolution + trainedPool * 2;
        m_WorkerStateUpdate = WorkerFactory.CreateWorker(backend, stateUpdateModel, false);
    }

    /// <summary>
    /// Builds and creates a worker for the graph that turns (state, predicted update) into the next state.
    /// </summary>
    /// <remarks>
    /// Inputs: <c>input0</c> = current state and <c>input1</c> = predicted update, both
    /// (1, hidden states, padded size, padded size), and <c>inputStepSize</c> (1, 1, 1, 1).
    /// Outputs: <c>outputState</c> = next state; <c>outputImage</c> = channels 0..3 of the next
    /// state with the border cropped; <c>avgPoolBlocks</c> = channel 3 of <c>outputImage</c>
    /// averaged over non-overlapping blockSize x blockSize blocks.
    /// </remarks>
    void CreateProcessingModel()
    {
        var model = new Model();
        var input0 = new Model.Input
        {
            name = "input0",
            shape = new SymbolicTensorShape(1, m_trainedHiddenStates, m_paddedImageSize, m_paddedImageSize),
            dataType = DataType.Float
        };
        var input1 = new Model.Input
        {
            name = "input1",
            shape = new SymbolicTensorShape(1, m_trainedHiddenStates, m_paddedImageSize, m_paddedImageSize),
            dataType = DataType.Float
        };
        var inputStepSize = new Model.Input
        {
            name = "inputStepSize",
            shape = new SymbolicTensorShape(1, 1, 1, 1),
            dataType = DataType.Float
        };
        model.inputs.Add(input0);
        model.inputs.Add(input1);
        model.inputs.Add(inputStepSize);

        // Pre-update life mask: a cell is alive when the max of channel 3 over its 3x3
        // neighbourhood (stride 1, padding 1, so size is preserved) exceeds aliveRate.
        model.AddConstant(new Lays.Constant("aliveRate", new TensorFloat(new TensorShape(1, 1, 1, 1), new[] { 0.1f })));
        model.AddConstant(new Lays.Constant("sliceStarts", new int[] { 0, 3, 0, 0 }));
        model.AddConstant(new Lays.Constant("sliceEnds", new[] { 1, 4, m_paddedImageSize, m_paddedImageSize }));
        model.AddLayer(new Lays.Slice("sliceI0", "input0", "sliceStarts", "sliceEnds"));
        model.AddLayer(new Lays.MaxPool("maxpool0", "sliceI0", new[] { 3, 3 }, new[] { 1, 1 }, new[] { 1, 1, 1, 1 }));
        model.AddLayer(new Lays.Greater("pre_life_mask", "maxpool0", "aliveRate")); //INT

        // Stochastic update: scale the prediction by the step size and apply it only where a
        // per-cell uniform sample in [0, 1) is <= fireRate.
        model.AddLayer(new Lays.Mul("input1_stepsize", "input1", "inputStepSize"));
        model.AddLayer(new Lays.RandomUniform("random", new int[] { 1, 1, m_paddedImageSize, m_paddedImageSize }, 0.0f, 1.0f, 0));
        model.AddConstant(new Lays.Constant("fireRate", new TensorFloat(new TensorShape(1, 1, 1, 1), new[] { 0.5f })));
        model.AddLayer(new Lays.LessOrEqual("lessEqualFireRateINT", "random", "fireRate"));
        model.AddLayer(new Lays.Cast("lessEqualFireRate", "lessEqualFireRateINT", DataType.Float));
        model.AddLayer(new Lays.Mul("mul", "input1_stepsize", "lessEqualFireRate"));
        model.AddLayer(new Lays.Add("add", "input0", "mul"));

        // Post-update life mask (same test on the updated state). Cells must be alive both before
        // and after the update to survive; every other cell is zeroed in all channels.
        model.AddLayer(new Lays.Slice("sliceI1", "add", "sliceStarts", "sliceEnds"));
        model.AddLayer(new Lays.MaxPool("maxpool1", "sliceI1", new[] { 3, 3 }, new[] { 1, 1 }, new[] { 1, 1, 1, 1 }));
        model.AddLayer(new Lays.Greater("post_life_mask", "maxpool1", "aliveRate"));
        model.AddLayer(new Lays.And("andINT", "pre_life_mask", "post_life_mask"));
        model.AddLayer(new Lays.Cast("and", "andINT", DataType.Float));
        model.AddLayer(new Lays.Mul("outputState", "add", "and"));

        // Crop the border and keep channels 0..3 for display.
        model.AddConstant(new Lays.Constant("sliceStarts2", new[] { 0, 0, trainedPool, trainedPool }));
        model.AddConstant(new Lays.Constant("sliceEnds2", new[] { 1, 4, m_paddedImageSize - trainedPool, m_paddedImageSize - trainedPool }));
        model.AddLayer(new Lays.Slice("outputImage", "outputState", "sliceStarts2", "sliceEnds2"));

        // Average channel 3 of the visible image over non-overlapping blocks.
        model.AddLayer(new Lays.Slice("outputIC", "outputImage", "sliceStarts", "sliceEnds"));
        int blockSize = trainedResolution / alphaBlocks;
        model.AddLayer(new Lays.AveragePool("avgPoolBlocks", "outputIC", new[] { blockSize, blockSize }, new[] { blockSize, blockSize }, new[] { 1, 1, 1, 1 }));

        model.outputs.Add("outputState");
        model.outputs.Add("outputImage");
        model.outputs.Add("avgPoolBlocks");
        m_WorkerClip = WorkerFactory.CreateWorker(backend, model);
    }

    /// <summary>
    /// Creates an all-zero initial state tensor of shape (1, hidden states, padded size, padded size).
    /// </summary>
    void SetupState()
    {
        float[] data = new float[1 * m_paddedImageSize * m_paddedImageSize * m_trainedHiddenStates];
        m_currentStateTensor = new TensorFloat(new TensorShape(1, m_trainedHiddenStates, m_paddedImageSize, m_paddedImageSize), data);
    }

    /// <summary>
    /// Creates the output render textures and binds them to their materials.
    /// </summary>
    void SetupTextures()
    {
        m_currentStateTexture = new RenderTexture(trainedResolution, trainedResolution, 0)
        {
            enableRandomWrite = true
        };
        outputMaterial.mainTexture = m_currentStateTexture;

        if (avgAlphaMaterial)
        {
            m_currentBlockAlphaStateTexture = new RenderTexture(alphaBlocks, alphaBlocks, 0)
            {
                enableRandomWrite = true
            };
            // The block-alpha image belongs on the optional alpha material; outputMaterial keeps the full image.
            avgAlphaMaterial.mainTexture = m_currentBlockAlphaStateTexture;
        }
    }

    /// <summary>
    /// Seeds the cell at (x, y) of the padded grid by setting channel 3 and every later channel to 1.
    /// </summary>
    /// <param name="x">Column in [0, padded size).</param>
    /// <param name="y">Row in [0, padded size).</param>
    void DrawDotAt(int x, int y)
    {
        m_currentStateTensor.MakeReadable();
        float[] data = m_currentStateTensor.ToReadOnlyArray();

        int planeSize = m_paddedImageSize * m_paddedImageSize;
        int cellOffset = m_paddedImageSize * y + x;
        // Channel 3 is the one the life masks test; channels 0..2 are left untouched.
        // Iterate over the model's actual channel count rather than assuming 16.
        for (int k = 3; k < m_trainedHiddenStates; k++)
        {
            data[planeSize * k + cellOffset] = 1f;
        }
        Replace(ref m_currentStateTensor, new TensorFloat(m_currentStateTensor.shape, data));
    }

    void Update()
    {
        DoInference();

        if (Input.GetKeyDown(KeyCode.Escape))
        {
            Application.Quit();
        }

        if (Input.GetKeyDown(KeyCode.Space))
        {
            DrawDotAt(UnityEngine.Random.Range(0, m_paddedImageSize), UnityEngine.Random.Range(0, m_paddedImageSize));
        }
    }

    /// <summary>
    /// Disposes the tensor currently held in <paramref name="A"/> (if any) and stores <paramref name="B"/> in its place.
    /// </summary>
    void Replace(ref TensorFloat A, TensorFloat B)
    {
        A?.Dispose();
        A = B;
    }

    /// <summary>
    /// Advances the automaton one step, renders the new state and keeps it for the next frame.
    /// </summary>
    void DoInference()
    {
        using var stepSizeTensor = new TensorFloat(new TensorShape(1, 1, 1, 1), new float[] { stepSize });

        // The update model is fed NHWC data: transpose the NCHW state in and its output back.
        using var currentStateTensorT = m_ops.Transpose(m_currentStateTensor, new int[] { 0, 2, 3, 1 });
        m_WorkerStateUpdate.Execute(currentStateTensorT);
        TensorFloat outputStateT = m_WorkerStateUpdate.PeekOutput() as TensorFloat;
        using var outputState = m_ops.Transpose(outputStateT, new int[] { 0, 3, 1, 2 });

        var inputs = new Dictionary<string, Tensor>() {
            { "input0", m_currentStateTensor }, //float
            { "input1", outputState }, //float
            { "inputStepSize", stepSizeTensor } //float
        };
        m_WorkerClip.Execute(inputs);

        TensorFloat clippedState = m_WorkerClip.PeekOutput("outputState") as TensorFloat;
        TensorFloat outputImage = m_WorkerClip.PeekOutput("outputImage") as TensorFloat;
        TensorFloat blockAvgAlphaState = m_WorkerClip.PeekOutput("avgPoolBlocks") as TensorFloat;

        if (m_currentStateTexture)
        {
            TextureConverter.RenderToTexture(outputImage, m_currentStateTexture);
        }
        if (m_currentBlockAlphaStateTexture)
        {
            TextureConverter.RenderToTexture(blockAvgAlphaState, m_currentBlockAlphaStateTexture);
        }

        // Keep the worker's output as the next state; take ownership of it so it is kept
        // as this component's state rather than left as the worker's peeked output.
        Replace(ref m_currentStateTensor, clippedState);
        m_currentStateTensor.TakeOwnership();
    }

    void OnDestroy()
    {
        // Null-conditional: Start may not have completed (e.g. model load failed).
        m_currentStateTensor?.Dispose();
        m_WorkerStateUpdate?.Dispose();
        m_WorkerClip?.Dispose();
        m_ops?.Dispose();
        m_allocator?.Dispose();

        // Textures created with `new RenderTexture` are not freed automatically.
        ReleaseTexture(ref m_currentStateTexture);
        ReleaseTexture(ref m_currentBlockAlphaStateTexture);
    }

    // Releases the GPU memory of a script-created RenderTexture, destroys the object and clears the field.
    static void ReleaseTexture(ref RenderTexture texture)
    {
        if (texture != null)
        {
            texture.Release();
            Destroy(texture);
            texture = null;
        }
    }
}