Skip to content

Commit e861255

Browse files
committed
fix query
1 parent 6494bc3 commit e861255

4 files changed

Lines changed: 18 additions & 132 deletions

File tree

cpp/src/security/DecOverflowWhenComparing/DecOverflowWhenComparing.c

Lines changed: 0 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -3,100 +3,17 @@
33
#include <stdlib.h>
44
#include <string.h>
55

6-
size_t min(size_t a, size_t b) {
7-
return a < b ? a : b;
8-
}
9-
10-
void MlasReorderInputNhwc(
11-
const float* S,
12-
float* D,
13-
size_t InputChannels,
14-
size_t RowCount,
15-
size_t FullRowCount
16-
)
17-
{
18-
const size_t BlockSize = FullRowCount % 123;
19-
20-
//
21-
// Iterate over batches of the input size to improve locality.
22-
//
23-
24-
for (size_t OuterRowCountRemaining = RowCount; OuterRowCountRemaining > 0; ) {
25-
26-
size_t OuterRowCountBatch = 32;
27-
28-
const size_t OuterRowCountThisIteration = min(OuterRowCountRemaining, OuterRowCountBatch);
29-
OuterRowCountRemaining -= OuterRowCountThisIteration;
30-
31-
//
32-
// Iterate over BlockSize batches of the input channels.
33-
//
34-
35-
const float* s = S;
36-
float* d = D;
37-
38-
for (size_t i = InputChannels; i > 0;) {
39-
40-
const size_t InputChannelsThisIteration = min(i, BlockSize);
41-
i -= InputChannelsThisIteration;
42-
43-
const float* ss = s;
44-
float* dd = d;
45-
size_t InnerRowCountRemaining = OuterRowCountThisIteration;
46-
47-
if (InputChannelsThisIteration == BlockSize) {
48-
49-
if (BlockSize == 8) {
50-
51-
while (InnerRowCountRemaining-- > 0) {
52-
ss += InputChannels;
53-
dd += 8;
54-
}
55-
56-
} else {
57-
58-
while (InnerRowCountRemaining-- > 0) {
59-
ss += InputChannels;
60-
dd += 16;
61-
}
62-
}
63-
64-
} else {
65-
66-
size_t BlockPadding = BlockSize - InputChannelsThisIteration;
67-
68-
while (InnerRowCountRemaining-- > 0) {
69-
ss += InputChannels;
70-
dd += BlockSize;
71-
}
72-
}
73-
74-
s += InputChannelsThisIteration;
75-
d += BlockSize * FullRowCount;
76-
}
77-
78-
S += InputChannels * OuterRowCountThisIteration;
79-
D += BlockSize * OuterRowCountThisIteration;
80-
}
81-
}
82-
836
// from https://github.com/apple-oss-distributions/Libinfo/blob/9fce29e5c5edc15d3ecea55116ca17d3f6350603/lookup.subproj/mdns_module.c#L1033C1-L1079C2
847
char* _mdns_parse_domain_name(const uint8_t *data, uint32_t datalen)
858
{
869
int i = 0, j = 0;
87-
// uint32_t len;
8810
uint32_t domainlen = 0;
8911
char *domain = NULL;
9012

9113
if ((data == NULL) || (datalen == 0)) return NULL;
9214

93-
/*
94-
* i: index into input data
95-
* j: index into output string
96-
*/
9715
while (datalen-- > 0)
9816
{
99-
// printf("%d\n", len);
10017
uint32_t len = data[i++];
10118
domainlen += (len + 1);
10219
domain = reallocf(domain, domainlen);
@@ -136,4 +53,3 @@ int main() {
13653
memcpy(data, "\x04quildu\x03xyz\x00", 11);
13754
_mdns_parse_domain_name(data, datalen);
13855
}
139-

cpp/src/security/DecOverflowWhenComparing/DecOverflowWhenComparing.ql

Lines changed: 13 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
* @name Decrementation overflow when comparing
33
* @id tob/cpp/dec-overflow-when-comparing
44
* @description This query finds unsigned integer overflows resulting from unchecked decrementation during comparison.
5-
* @kind graph
5+
* @kind problem
66
* @tags security
77
* @problem.severity error
88
* @precision high
@@ -14,49 +14,24 @@ import cpp
1414
import semmle.code.cpp.ir.IR
1515
import semmle.code.cpp.rangeanalysis.SimpleRangeAnalysis
1616

17-
query predicate nodes(ControlFlowNode node, string key, string value) {
18-
exists(Variable var, PostfixDecrExpr dec |
19-
dec.getOperand() = var.getAnAccess().getExplicitlyConverted() and
20-
var.getUnderlyingType().(IntegralType).isUnsigned() and
21-
successorGuarded(node, _, var) and
22-
key = node.toString() and
23-
value = node.toString() + "-val"
24-
)
25-
}
26-
27-
query predicate edges(ControlFlowNode source, ControlFlowNode target, string key, string value) {
28-
exists(Variable var, PostfixDecrExpr dec, VariableAccess acc |
29-
var.getAnAccess() = acc and
30-
dec.getOperand() = acc.getExplicitlyConverted() and
31-
var.getUnderlyingType().(IntegralType).isUnsigned() and
32-
33-
source.getASuccessor() = target and
34-
35-
key = source.toString() + "-key" and
36-
value = target.toString() + "-val"
37-
)
38-
}
39-
40-
query predicate graphProperties(string key, string value) {
41-
key = "semmle.graphKind" and value = "graph"
17+
/**
18+
* Holds if `node` overwrites `var` (assignment or declaration with initializer).
19+
*/
20+
predicate isDefOf(ControlFlowNode node, Variable var) {
21+
node = var.getAnAccess() and node.(VariableAccess).isLValue()
22+
or
23+
node.(DeclStmt).getADeclaration() = var and exists(var.getInitializer())
4224
}
4325

4426
/**
45-
* Find CFG paths from start to end that do not cross over node that is var's lvalue access
46-
* TODO: there must be an API for that...
27+
* Find CFG paths from start to end that do not cross over a definition of var.
4728
*/
4829
predicate successorGuarded(ControlFlowNode start, ControlFlowNode end, Variable var) {
4930
start = end
5031
or
5132
exists(ControlFlowNode interm |
5233
start.getASuccessor() = interm and
53-
54-
// break the path if variable is overwritten
55-
not (
56-
interm = var.getAnAccess() and
57-
interm.(VariableAccess).isLValue()
58-
) and
59-
34+
not isDefOf(interm, var) and
6035
(
6136
interm.getASuccessor() = end
6237
or
@@ -65,7 +40,7 @@ predicate successorGuarded(ControlFlowNode start, ControlFlowNode end, Variable
6540
)
6641
}
6742

68-
/*
43+
6944
from Variable var, VariableAccess varAcc, PostfixDecrExpr dec,
7045
VariableAccess varAccAfterOverflow, ComparisonOperation cmp
7146
where
@@ -103,9 +78,7 @@ where
10378
// only if var may possibly be zero during comparison
10479
lowerBound(varAcc) = 0
10580

106-
// skip tests etc
107-
and not dec.getFile().getAbsolutePath().toLowerCase().matches(["%test%", "%vendor%", "%third_party%"])
81+
// skip vendor code
82+
and not dec.getFile().getAbsolutePath().toLowerCase().matches(["%vendor%", "%third_party%"])
10883

10984
select dec, "Unsigned decrementation in comparison ($@) - $@", cmp, cmp.toString(), varAccAfterOverflow, varAccAfterOverflow.toString()
110-
111-
*/

cpp/test/query-tests/security/DecOverflowWhenComparing/DecOverflowWhenComparing.c

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,14 @@
77
char* _mdns_parse_domain_name(const uint8_t *data, uint32_t datalen)
88
{
99
int i = 0, j = 0;
10-
uint32_t len;
1110
uint32_t domainlen = 0;
1211
char *domain = NULL;
1312

1413
if ((data == NULL) || (datalen == 0)) return NULL;
1514

16-
/*
17-
* i: index into input data
18-
* j: index into output string
19-
*/
2015
while (datalen-- > 0)
2116
{
22-
len = data[i++];
17+
uint32_t len = data[i++];
2318
domainlen += (len + 1);
2419
domain = reallocf(domain, domainlen);
2520

@@ -53,7 +48,8 @@ char* _mdns_parse_domain_name(const uint8_t *data, uint32_t datalen)
5348
}
5449

5550
int main() {
56-
uint8_t data[128] = {0};
51+
const uint16_t datalen = 128;
52+
uint8_t data[datalen] = {};
5753
memcpy(data, "\x04quildu\x03xyz\x00", 11);
58-
_mdns_parse_domain_name(data, 128);
54+
_mdns_parse_domain_name(data, datalen);
5955
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
| DecOverflowWhenComparing.c:30:31:30:39 | ... -- | Unsigned decrementation in comparison ($@) - $@ | DecOverflowWhenComparing.c:30:26:30:39 | ... != ... | ... != ... | DecOverflowWhenComparing.c:15:9:15:15 | datalen | datalen |

0 commit comments

Comments
 (0)