2 // Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
7 #include "compiler/OutputGLSLBase.h"
8 #include "compiler/debug.h"
// Builds the GLSL spelling of a type's name and returns it as a TString.
// NOTE(review): this listing omits some original lines (the embedded line
// numbers skip), including the declaration of `out` and the condition that
// guards the first getNominalSize() write below.
12 TString getTypeName(const TType& type)
18 out << type.getNominalSize();
// Vectors: a prefix chosen by the basic type, followed by the nominal size.
20 else if (type.isVector())
22 switch (type.getBasicType())
24 case EbtFloat: out << "vec"; break;
25 case EbtInt: out << "ivec"; break;
26 case EbtBool: out << "bvec"; break;
27 default: UNREACHABLE(); break;
29 out << type.getNominalSize();
// Structs are spelled with their declared type name; the basic-type string
// below is presumably the fallback for all remaining types (the separating
// `else` is not visible here — confirm).
33 if (type.getBasicType() == EbtStruct)
34 out << type.getTypeName();
36 out << type.getBasicString();
38 return TString(out.c_str());
// Returns the array suffix for a type: "[" + type.getArraySize() + "]".
// The caller must pass an array type; this is asserted.
41 TString arrayBrackets(const TType& type)
43 ASSERT(type.isArray());
45 out << "[" << type.getArraySize() << "]";
46 return TString(out.c_str());
// Decides whether a node counts as a single statement, based on its kind.
49 bool isSingleStatement(TIntermNode* node) {
// Aggregates qualify unless they are a function (EOpFunction) or a
// sequence (EOpSequence).
50 if (const TIntermAggregate* aggregate = node->getAsAggregate())
52 return (aggregate->getOp() != EOpFunction) &&
53 (aggregate->getOp() != EOpSequence);
// Selections qualify only when they use the ternary operator.
55 else if (const TIntermSelection* selection = node->getAsSelectionNode())
57 // Ternary operators are usually part of an assignment operator.
58 // This handles those rare cases in which they are all by themselves.
59 return selection->usesTernaryOperator();
// NOTE(review): the value returned for loop nodes, and for every other node
// kind, is on lines missing from this listing.
61 else if (node->getAsLoopNode())
// Constructor. Initializes the TIntermTraverser base with all three flags
// set to true (presumably pre-, in- and post-visit — confirm against the
// TIntermTraverser declaration) and starts with mDeclaringVariables false.
// NOTE(review): the initializer between these lines (original line 71) is
// missing from this listing; presumably it stores objSink — confirm.
69 TOutputGLSLBase::TOutputGLSLBase(TInfoSinkBase& objSink)
70 : TIntermTraverser(true, true, true),
72 mDeclaringVariables(false)
// Selects one of three strings according to the current visit phase:
// preStr for PreVisit, inStr for InVisit, postStr for PostVisit.
// A NULL string means the corresponding phase is skipped.
// NOTE(review): the statements guarded by each condition are missing from
// this listing; presumably each writes the selected string to `out`.
76 void TOutputGLSLBase::writeTriplet(Visit visit, const char* preStr, const char* inStr, const char* postStr)
78 TInfoSinkBase& out = objSink();
79 if (visit == PreVisit && preStr)
83 else if (visit == InVisit && inStr)
87 else if (visit == PostVisit && postStr)
// Writes the type portion of a variable declaration: an optional qualifier,
// an inline struct definition the first time a struct type is seen, and the
// type name.
93 void TOutputGLSLBase::writeVariableType(const TType& type)
95 TInfoSinkBase& out = objSink();
96 TQualifier qualifier = type.getQualifier();
97 // TODO(alokp): Validate qualifier for variable declarations.
// EvqTemporary and EvqGlobal are the only qualifiers not written out.
98 if ((qualifier != EvqTemporary) && (qualifier != EvqGlobal))
99 out << type.getQualifierString() << " ";
100 // Declare the struct if we have not done so already.
// mDeclaredStructs is keyed by the struct's type name.
101 if ((type.getBasicType() == EbtStruct) &&
102 (mDeclaredStructs.find(type.getTypeName()) == mDeclaredStructs.end()))
104 out << "struct " << type.getTypeName() << "{\n";
105 const TTypeList* structure = type.getStruct();
106 ASSERT(structure != NULL);
// One entry per field: type name, field name, then array brackets if any.
107 for (size_t i = 0; i < structure->size(); ++i)
109 const TType* fieldType = (*structure)[i].type;
110 ASSERT(fieldType != NULL);
// NOTE(review): what happens when writeVariablePrecision() returns true is
// on a line missing from this listing.
111 if (writeVariablePrecision(fieldType->getPrecision()))
113 out << getTypeName(*fieldType) << " " << fieldType->getFieldName();
114 if (fieldType->isArray())
115 out << arrayBrackets(*fieldType);
// Remember the struct so later uses of the same type name skip this branch.
119 mDeclaredStructs.insert(type.getTypeName());
// NOTE(review): lines are missing before this point, so it is not visible
// whether the remaining statements sit in an else branch of the struct check.
123 if (writeVariablePrecision(type.getPrecision()))
125 out << getTypeName(type);
// Writes a function's parameter list. Each entry of args is read as a
// symbol node: its type goes through writeVariableType(), and array
// brackets are emitted for it.
// NOTE(review): several lines are missing from this listing, including any
// null check on `arg`, the write of `name`, and the condition (if any)
// guarding the arrayBrackets() call.
129 void TOutputGLSLBase::writeFunctionParameters(const TIntermSequence& args)
131 TInfoSinkBase& out = objSink();
132 for (TIntermSequence::const_iterator iter = args.begin();
133 iter != args.end(); ++iter)
135 const TIntermSymbol* arg = (*iter)->getAsSymbolNode();
138 const TType& type = arg->getType();
139 writeVariableType(type);
141 const TString& name = arg->getSymbol();
145 out << arrayBrackets(type);
147 // Put a comma if this is not the last argument.
148 if (iter != args.end() - 1)
// Writes the constant value(s) of the given type, reading them from the
// array that pConstUnion points into.
// The recursive call below assigns the result back to pConstUnion, so the
// return value is used as the read position after the written constants.
// NOTE(review): the return statement itself is missing from this listing.
153 const ConstantUnion* TOutputGLSLBase::writeConstantUnion(const TType& type,
154 const ConstantUnion* pConstUnion)
156 TInfoSinkBase& out = objSink();
// Structs: "TypeName(" then each field written recursively, separated by ", ".
158 if (type.getBasicType() == EbtStruct)
160 out << type.getTypeName() << "(";
161 const TTypeList* structure = type.getStruct();
162 ASSERT(structure != NULL);
163 for (size_t i = 0; i < structure->size(); ++i)
165 const TType* fieldType = (*structure)[i].type;
166 ASSERT(fieldType != NULL);
167 pConstUnion = writeConstantUnion(*fieldType, pConstUnion);
168 if (i != structure->size() - 1) out << ", ";
// Non-struct path (the separating `else` is not visible in this listing).
// The value list is wrapped in "typeName(...)" only when the type holds more
// than one component.
174 int size = type.getObjectSize();
175 bool writeType = size > 1;
176 if (writeType) out << getTypeName(type) << "(";
// Each component is written according to its own stored type; pConstUnion
// advances by one element per iteration.
177 for (int i = 0; i < size; ++i, ++pConstUnion)
179 switch (pConstUnion->getType())
181 case EbtFloat: out << pConstUnion->getFConst(); break;
182 case EbtInt: out << pConstUnion->getIConst(); break;
183 case EbtBool: out << pConstUnion->getBConst(); break;
184 default: UNREACHABLE();
186 if (i != size - 1) out << ", ";
188 if (writeType) out << ")";
// Writes a symbol reference. When mLoopUnroll reports that the symbol must
// be replaced, the loop index value is written; the symbol's own name is
// written below (presumably in the else branch, which is not visible in
// this listing).
193 void TOutputGLSLBase::visitSymbol(TIntermSymbol* node)
195 TInfoSinkBase& out = objSink();
196 if (mLoopUnroll.NeedsToReplaceSymbolWithValue(node))
197 out << mLoopUnroll.GetLoopIndexValue(node);
199 out << node->getSymbol();
// The array size suffix is appended only while a declaration is being written.
201 if (mDeclaringVariables && node->getType().isArray())
202 out << arrayBrackets(node->getType());
// Writes a constant node by forwarding its type and constant array to
// writeConstantUnion(); the returned pointer is ignored.
205 void TOutputGLSLBase::visitConstantUnion(TIntermConstantUnion* node)
207 writeConstantUnion(node->getType(), node->getUnionArrayPointer());
// Emits the text for a binary node, chosen by its operator and the current
// visit phase. Returns visitChildren: true by default, set to false by the
// branches that should stop the traverser from visiting the operands.
// NOTE(review): this listing omits several original lines, including some
// case labels.
210 bool TOutputGLSLBase::visitBinary(Visit visit, TIntermBinary* node)
212 bool visitChildren = true;
213 TInfoSinkBase& out = objSink();
214 switch (node->getOp())
// NOTE(review): the case label for the branch below is missing from this
// listing; per the comment it handles the initialize operator.
217 if (visit == InVisit)
220 // RHS of initialize is not being declared.
221 mDeclaringVariables = false;
// Assignments are written fully parenthesized: "(lhs op rhs)".
224 case EOpAssign: writeTriplet(visit, "(", " = ", ")"); break;
225 case EOpAddAssign: writeTriplet(visit, "(", " += ", ")"); break;
226 case EOpSubAssign: writeTriplet(visit, "(", " -= ", ")"); break;
227 case EOpDivAssign: writeTriplet(visit, "(", " /= ", ")"); break;
228 // Notice the fall-through.
// All the multiply-assign variants listed here share the "*=" spelling.
230 case EOpVectorTimesMatrixAssign:
231 case EOpVectorTimesScalarAssign:
232 case EOpMatrixTimesScalarAssign:
233 case EOpMatrixTimesMatrixAssign:
234 writeTriplet(visit, "(", " *= ", ")");
// Indexing: nothing before the base, "[" between operands, "]" after.
238 case EOpIndexIndirect:
239 writeTriplet(visit, NULL, "[", "]");
// Struct field access: the field name comes from the node's own type, and
// the operands are not traversed further (visitChildren = false).
241 case EOpIndexDirectStruct:
242 if (visit == InVisit)
245 // TODO(alokp): ASSERT
246 out << node->getType().getFieldName();
247 visitChildren = false;
// Swizzle: the right operand is read as an aggregate of single int
// constants; each value 0..3 is written as x, y, z or w.
250 case EOpVectorSwizzle:
251 if (visit == InVisit)
254 TIntermAggregate* rightChild = node->getRight()->getAsAggregate();
255 TIntermSequence& sequence = rightChild->getSequence();
256 for (TIntermSequence::iterator sit = sequence.begin(); sit != sequence.end(); ++sit)
258 TIntermConstantUnion* element = (*sit)->getAsConstantUnion();
259 ASSERT(element->getBasicType() == EbtInt);
260 ASSERT(element->getNominalSize() == 1);
261 const ConstantUnion& data = element->getUnionArrayPointer()[0];
262 ASSERT(data.getType() == EbtInt);
263 switch (data.getIConst())
265 case 0: out << "x"; break;
266 case 1: out << "y"; break;
267 case 2: out << "z"; break;
268 case 3: out << "w"; break;
269 default: UNREACHABLE(); break;
// The swizzle components were written above, so the operands are skipped.
272 visitChildren = false;
// Arithmetic and comparison operators, all written fully parenthesized.
276 case EOpAdd: writeTriplet(visit, "(", " + ", ")"); break;
277 case EOpSub: writeTriplet(visit, "(", " - ", ")"); break;
278 case EOpMul: writeTriplet(visit, "(", " * ", ")"); break;
279 case EOpDiv: writeTriplet(visit, "(", " / ", ")"); break;
// The binary form of EOpMod is not supported here.
280 case EOpMod: UNIMPLEMENTED(); break;
281 case EOpEqual: writeTriplet(visit, "(", " == ", ")"); break;
282 case EOpNotEqual: writeTriplet(visit, "(", " != ", ")"); break;
283 case EOpLessThan: writeTriplet(visit, "(", " < ", ")"); break;
284 case EOpGreaterThan: writeTriplet(visit, "(", " > ", ")"); break;
285 case EOpLessThanEqual: writeTriplet(visit, "(", " <= ", ")"); break;
286 case EOpGreaterThanEqual: writeTriplet(visit, "(", " >= ", ")"); break;
288 // Notice the fall-through.
// All the vector/matrix multiply variants share the plain "*" spelling.
289 case EOpVectorTimesScalar:
290 case EOpVectorTimesMatrix:
291 case EOpMatrixTimesVector:
292 case EOpMatrixTimesScalar:
293 case EOpMatrixTimesMatrix:
294 writeTriplet(visit, "(", " * ", ")");
297 case EOpLogicalOr: writeTriplet(visit, "(", " || ", ")"); break;
298 case EOpLogicalXor: writeTriplet(visit, "(", " ^^ ", ")"); break;
299 case EOpLogicalAnd: writeTriplet(visit, "(", " && ", ")"); break;
300 default: UNREACHABLE(); break;
303 return visitChildren;
// Emits the text for a unary node. Every operator maps to a prefix and a
// suffix passed to writeTriplet(); the in-visit string is always NULL.
// NOTE(review): the return statement is missing from this listing.
306 bool TOutputGLSLBase::visitUnary(Visit visit, TIntermUnary* node)
308 switch (node->getOp())
310 case EOpNegative: writeTriplet(visit, "(-", NULL, ")"); break;
311 case EOpVectorLogicalNot: writeTriplet(visit, "not(", NULL, ")"); break;
312 case EOpLogicalNot: writeTriplet(visit, "(!", NULL, ")"); break;
// Increment/decrement: the operator sits inside the parentheses, after the
// operand for the post forms and before it for the pre forms.
314 case EOpPostIncrement: writeTriplet(visit, "(", NULL, "++)"); break;
315 case EOpPostDecrement: writeTriplet(visit, "(", NULL, "--)"); break;
316 case EOpPreIncrement: writeTriplet(visit, "(++", NULL, ")"); break;
317 case EOpPreDecrement: writeTriplet(visit, "(--", NULL, ")"); break;
// Conversions are written as constructor calls; the constructor name is
// picked from the operand's nominal size (1 to 4).
319 case EOpConvIntToBool:
320 case EOpConvFloatToBool:
321 switch (node->getOperand()->getType().getNominalSize())
323 case 1: writeTriplet(visit, "bool(", NULL, ")"); break;
324 case 2: writeTriplet(visit, "bvec2(", NULL, ")"); break;
325 case 3: writeTriplet(visit, "bvec3(", NULL, ")"); break;
326 case 4: writeTriplet(visit, "bvec4(", NULL, ")"); break;
327 default: UNREACHABLE();
330 case EOpConvBoolToFloat:
331 case EOpConvIntToFloat:
332 switch (node->getOperand()->getType().getNominalSize())
334 case 1: writeTriplet(visit, "float(", NULL, ")"); break;
335 case 2: writeTriplet(visit, "vec2(", NULL, ")"); break;
336 case 3: writeTriplet(visit, "vec3(", NULL, ")"); break;
337 case 4: writeTriplet(visit, "vec4(", NULL, ")"); break;
338 default: UNREACHABLE();
341 case EOpConvFloatToInt:
342 case EOpConvBoolToInt:
343 switch (node->getOperand()->getType().getNominalSize())
345 case 1: writeTriplet(visit, "int(", NULL, ")"); break;
346 case 2: writeTriplet(visit, "ivec2(", NULL, ")"); break;
347 case 3: writeTriplet(visit, "ivec3(", NULL, ")"); break;
348 case 4: writeTriplet(visit, "ivec4(", NULL, ")"); break;
349 default: UNREACHABLE();
// The remaining operators are written as single-argument function calls.
353 case EOpRadians: writeTriplet(visit, "radians(", NULL, ")"); break;
354 case EOpDegrees: writeTriplet(visit, "degrees(", NULL, ")"); break;
355 case EOpSin: writeTriplet(visit, "sin(", NULL, ")"); break;
356 case EOpCos: writeTriplet(visit, "cos(", NULL, ")"); break;
357 case EOpTan: writeTriplet(visit, "tan(", NULL, ")"); break;
358 case EOpAsin: writeTriplet(visit, "asin(", NULL, ")"); break;
359 case EOpAcos: writeTriplet(visit, "acos(", NULL, ")"); break;
360 case EOpAtan: writeTriplet(visit, "atan(", NULL, ")"); break;
362 case EOpExp: writeTriplet(visit, "exp(", NULL, ")"); break;
363 case EOpLog: writeTriplet(visit, "log(", NULL, ")"); break;
364 case EOpExp2: writeTriplet(visit, "exp2(", NULL, ")"); break;
365 case EOpLog2: writeTriplet(visit, "log2(", NULL, ")"); break;
366 case EOpSqrt: writeTriplet(visit, "sqrt(", NULL, ")"); break;
367 case EOpInverseSqrt: writeTriplet(visit, "inversesqrt(", NULL, ")"); break;
369 case EOpAbs: writeTriplet(visit, "abs(", NULL, ")"); break;
370 case EOpSign: writeTriplet(visit, "sign(", NULL, ")"); break;
371 case EOpFloor: writeTriplet(visit, "floor(", NULL, ")"); break;
372 case EOpCeil: writeTriplet(visit, "ceil(", NULL, ")"); break;
373 case EOpFract: writeTriplet(visit, "fract(", NULL, ")"); break;
375 case EOpLength: writeTriplet(visit, "length(", NULL, ")"); break;
376 case EOpNormalize: writeTriplet(visit, "normalize(", NULL, ")"); break;
378 case EOpDFdx: writeTriplet(visit, "dFdx(", NULL, ")"); break;
379 case EOpDFdy: writeTriplet(visit, "dFdy(", NULL, ")"); break;
380 case EOpFwidth: writeTriplet(visit, "fwidth(", NULL, ")"); break;
382 case EOpAny: writeTriplet(visit, "any(", NULL, ")"); break;
383 case EOpAll: writeTriplet(visit, "all(", NULL, ")"); break;
385 default: UNREACHABLE(); break;
// Emits a selection node, either as a ternary expression or in a second
// form whose branches are written through visitCodeBlock(). The sub-nodes
// are traversed explicitly here.
// NOTE(review): the lines that write the surrounding punctuation/keywords
// and the return statement are missing from this listing.
391 bool TOutputGLSLBase::visitSelection(Visit visit, TIntermSelection* node)
393 TInfoSinkBase& out = objSink();
395 if (node->usesTernaryOperator())
397 // Notice two brackets at the beginning and end. The outer ones
398 // encapsulate the whole ternary expression. This preserves the
399 // order of precedence when ternary expressions are used in a
400 // compound expression, i.e., c = 2 * (a < b ? 1 : 2).
// Ternary form: condition, true value and false value are traversed in order.
402 node->getCondition()->traverse(this);
404 node->getTrueBlock()->traverse(this);
406 node->getFalseBlock()->traverse(this);
// Non-ternary form: the condition is traversed, then each branch goes
// through visitCodeBlock().
412 node->getCondition()->traverse(this);
416 visitCodeBlock(node->getTrueBlock());
// The false block is optional and only written when present.
418 if (node->getFalseBlock())
421 visitCodeBlock(node->getFalseBlock());
// Emits the text for an aggregate node (sequences, function prototypes and
// definitions, calls, parameter lists, declarations, constructors and
// multi-argument built-ins). Returns visitChildren: true by default, set to
// false by the branches that traverse or write their children themselves.
// NOTE(review): this listing omits several original lines, including the
// case labels of the first three branches.
428 bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate* node)
430 bool visitChildren = true;
431 TInfoSinkBase& out = objSink();
432 switch (node->getOp())
// Sequence branch (case label not visible here): children are traversed
// manually so that each statement can be checked with isSingleStatement().
435 // Scope the sequences except when at the global scope.
436 if (depth > 0) out << "{\n";
439 const TIntermSequence& sequence = node->getSequence();
440 for (TIntermSequence::const_iterator iter = sequence.begin();
441 iter != sequence.end(); ++iter)
// Note: this local `node` shadows the function parameter of the same name.
443 TIntermNode* node = *iter;
444 ASSERT(node != NULL);
445 node->traverse(this);
// NOTE(review): the statement guarded by this check is missing from this
// listing.
447 if (isSingleStatement(node))
452 // Scope the sequences except when at the global scope.
453 if (depth > 0) out << "}\n";
454 visitChildren = false;
// Declaration-only branch: the return type is written via getTypeName() and
// the name is written as stored (no unmangling here).
458 // Function declaration.
459 ASSERT(visit == PreVisit);
460 TString returnType = getTypeName(node->getType());
461 out << returnType << " " << node->getName();
464 writeFunctionParameters(node->getSequence());
467 visitChildren = false;
// Definition branch: the return type goes through writeVariableType() and
// the name is unmangled before being written.
471 // Function definition.
472 ASSERT(visit == PreVisit);
473 writeVariableType(node->getType());
474 out << " " << TFunction::unmangleName(node->getName());
477 // Function definition node contains one or two children nodes
478 // representing function parameters and function body. The latter
479 // is not present in case of empty function bodies.
480 const TIntermSequence& sequence = node->getSequence();
481 ASSERT((sequence.size() == 1) || (sequence.size() == 2));
482 TIntermSequence::const_iterator seqIter = sequence.begin();
484 // Traverse function parameters.
485 TIntermAggregate* params = (*seqIter)->getAsAggregate();
486 ASSERT(params != NULL);
487 ASSERT(params->getOp() == EOpParameters);
488 params->traverse(this);
490 // Traverse function body.
// `body` is NULL when the sequence has no second child; visitCodeBlock() is
// still called with that NULL.
491 TIntermAggregate* body = ++seqIter != sequence.end() ?
492 (*seqIter)->getAsAggregate() : NULL;
493 visitCodeBlock(body);
496 // Fully processed; no need to visit children.
497 visitChildren = false;
// Function call: the unmangled name and "(" are written before the
// arguments. What is written at InVisit and afterwards is missing from this
// listing.
500 case EOpFunctionCall:
502 if (visit == PreVisit)
504 TString functionName = TFunction::unmangleName(node->getName());
505 out << functionName << "(";
507 else if (visit == InVisit)
516 case EOpParameters: {
517 // Function parameters.
518 ASSERT(visit == PreVisit);
520 writeFunctionParameters(node->getSequence());
522 visitChildren = false;
525 case EOpDeclaration: {
526 // Variable declaration.
// The type is written once, taken from the first item of the sequence.
// mDeclaringVariables is set while the declared items are visited; the
// visitSymbol() code in this file reads it to append array brackets.
527 if (visit == PreVisit)
529 const TIntermSequence& sequence = node->getSequence();
530 const TIntermTyped* variable = sequence.front()->getAsTyped();
531 writeVariableType(variable->getType());
533 mDeclaringVariables = true;
// Re-armed between items; the visitBinary() code in this file clears the
// flag at InVisit of the branch commented "RHS of initialize is not being
// declared".
535 else if (visit == InVisit)
538 mDeclaringVariables = true;
// Presumably the remaining (post-visit) phase — the `else` is not visible.
542 mDeclaringVariables = false;
// Constructors: scalar forms pass NULL as the in-visit string, vector and
// matrix forms pass ", ".
546 case EOpConstructFloat: writeTriplet(visit, "float(", NULL, ")"); break;
547 case EOpConstructVec2: writeTriplet(visit, "vec2(", ", ", ")"); break;
548 case EOpConstructVec3: writeTriplet(visit, "vec3(", ", ", ")"); break;
549 case EOpConstructVec4: writeTriplet(visit, "vec4(", ", ", ")"); break;
550 case EOpConstructBool: writeTriplet(visit, "bool(", NULL, ")"); break;
551 case EOpConstructBVec2: writeTriplet(visit, "bvec2(", ", ", ")"); break;
552 case EOpConstructBVec3: writeTriplet(visit, "bvec3(", ", ", ")"); break;
553 case EOpConstructBVec4: writeTriplet(visit, "bvec4(", ", ", ")"); break;
554 case EOpConstructInt: writeTriplet(visit, "int(", NULL, ")"); break;
555 case EOpConstructIVec2: writeTriplet(visit, "ivec2(", ", ", ")"); break;
556 case EOpConstructIVec3: writeTriplet(visit, "ivec3(", ", ", ")"); break;
557 case EOpConstructIVec4: writeTriplet(visit, "ivec4(", ", ", ")"); break;
558 case EOpConstructMat2: writeTriplet(visit, "mat2(", ", ", ")"); break;
559 case EOpConstructMat3: writeTriplet(visit, "mat3(", ", ", ")"); break;
560 case EOpConstructMat4: writeTriplet(visit, "mat4(", ", ", ")"); break;
// Struct constructors use the struct's type name as the constructor name.
// What is written at InVisit and afterwards is missing from this listing.
561 case EOpConstructStruct:
562 if (visit == PreVisit)
564 const TType& type = node->getType();
565 ASSERT(type.getBasicType() == EbtStruct);
566 out << type.getTypeName() << "(";
568 else if (visit == InVisit)
// In aggregate form the comparison operators are written as function calls
// (lessThan, greaterThan, ...), not as infix operators.
578 case EOpLessThan: writeTriplet(visit, "lessThan(", ", ", ")"); break;
579 case EOpGreaterThan: writeTriplet(visit, "greaterThan(", ", ", ")"); break;
580 case EOpLessThanEqual: writeTriplet(visit, "lessThanEqual(", ", ", ")"); break;
581 case EOpGreaterThanEqual: writeTriplet(visit, "greaterThanEqual(", ", ", ")"); break;
582 case EOpVectorEqual: writeTriplet(visit, "equal(", ", ", ")"); break;
583 case EOpVectorNotEqual: writeTriplet(visit, "notEqual(", ", ", ")"); break;
// The comma operator is written with no surrounding parentheses.
584 case EOpComma: writeTriplet(visit, NULL, ", ", NULL); break;
// Multi-argument built-in functions.
586 case EOpMod: writeTriplet(visit, "mod(", ", ", ")"); break;
587 case EOpPow: writeTriplet(visit, "pow(", ", ", ")"); break;
588 case EOpAtan: writeTriplet(visit, "atan(", ", ", ")"); break;
589 case EOpMin: writeTriplet(visit, "min(", ", ", ")"); break;
590 case EOpMax: writeTriplet(visit, "max(", ", ", ")"); break;
591 case EOpClamp: writeTriplet(visit, "clamp(", ", ", ")"); break;
592 case EOpMix: writeTriplet(visit, "mix(", ", ", ")"); break;
593 case EOpStep: writeTriplet(visit, "step(", ", ", ")"); break;
594 case EOpSmoothStep: writeTriplet(visit, "smoothstep(", ", ", ")"); break;
596 case EOpDistance: writeTriplet(visit, "distance(", ", ", ")"); break;
597 case EOpDot: writeTriplet(visit, "dot(", ", ", ")"); break;
598 case EOpCross: writeTriplet(visit, "cross(", ", ", ")"); break;
599 case EOpFaceForward: writeTriplet(visit, "faceforward(", ", ", ")"); break;
600 case EOpReflect: writeTriplet(visit, "reflect(", ", ", ")"); break;
601 case EOpRefract: writeTriplet(visit, "refract(", ", ", ")"); break;
// As an aggregate, EOpMul is written as matrixCompMult().
602 case EOpMul: writeTriplet(visit, "matrixCompMult(", ", ", ")"); break;
604 default: UNREACHABLE(); break;
606 return visitChildren;
// Emits a loop node. The header parts (init, condition, expression) and the
// body are traversed explicitly in this function.
// NOTE(review): the lines that write the loop keywords and punctuation, and
// the return statement, are missing from this listing.
609 bool TOutputGLSLBase::visitLoop(Visit visit, TIntermLoop* node)
611 TInfoSinkBase& out = objSink();
615 TLoopType loopType = node->getType();
616 if (loopType == ELoopFor) // for loop
// The header parts below are under the no-unroll check; where that block
// closes is not visible in this listing. Condition and expression are only
// traversed when present.
618 if (!node->getUnrollFlag()) {
621 node->getInit()->traverse(this);
624 if (node->getCondition())
625 node->getCondition()->traverse(this);
628 if (node->getExpression())
629 node->getExpression()->traverse(this);
633 else if (loopType == ELoopWhile) // while loop
636 ASSERT(node->getCondition() != NULL);
637 node->getCondition()->traverse(this);
640 else // do-while loop
642 ASSERT(loopType == ELoopDoWhile);
// Body. With the unroll flag set, the loop's index info is pushed onto
// mLoopUnroll and the body is emitted once per pass of the while loop below.
// NOTE(review): whatever advances or pops the mLoopUnroll state is missing
// from this listing.
647 if (node->getUnrollFlag())
649 TLoopIndexInfo indexInfo;
650 mLoopUnroll.FillLoopIndexInfo(node, indexInfo);
651 mLoopUnroll.Push(indexInfo);
652 while (mLoopUnroll.SatisfiesLoopCondition())
654 visitCodeBlock(node->getBody());
// Presumably the non-unrolled path (the `else` is not visible — confirm):
// the body is emitted once.
661 visitCodeBlock(node->getBody());
// For do-while loops the condition is traversed after the body.
665 if (loopType == ELoopDoWhile) // do-while loop
668 ASSERT(node->getCondition() != NULL);
669 node->getCondition()->traverse(this);
674 // No need to visit children. They have been already processed in
// NOTE(review): the rest of the comment above is missing from this listing.
// Emits the keyword for a branch node. Each keyword is passed as the
// pre-visit string only; the in- and post-visit strings are NULL.
// NOTE(review): the return statement is missing from this listing.
679 bool TOutputGLSLBase::visitBranch(Visit visit, TIntermBranch* node)
681 switch (node->getFlowOp())
683 case EOpKill: writeTriplet(visit, "discard", NULL, NULL); break;
684 case EOpBreak: writeTriplet(visit, "break", NULL, NULL); break;
685 case EOpContinue: writeTriplet(visit, "continue", NULL, NULL); break;
// "return " has a trailing space, unlike the other keywords (presumably so
// that a returned expression can follow — confirm).
686 case EOpReturn: writeTriplet(visit, "return ", NULL, NULL); break;
687 default: UNREACHABLE(); break;
// Emits a code block for the given node: the node is traversed and checked
// with isSingleStatement(); an empty "{\n}\n" block is written on another
// path.
// NOTE(review): the conditions separating these paths are missing from this
// listing. A caller in this file (the function-definition branch of
// visitAggregate) passes NULL for an absent body, so the empty-block path is
// presumably the NULL case — confirm.
693 void TOutputGLSLBase::visitCodeBlock(TIntermNode* node) {
694 TInfoSinkBase &out = objSink();
697 node->traverse(this);
698 // Single statements not part of a sequence need to be terminated
// NOTE(review): the rest of the comment above and the statement guarded by
// the check below are missing from this listing.
700 if (isSingleStatement(node))
705 out << "{\n}\n"; // Empty code block.