Skip to content

Commit 3c9cb85

Browse files
committed
more work DecOverflowWhenComparing
1 parent e4f8d7c commit 3c9cb85

4 files changed

Lines changed: 140 additions & 23 deletions

File tree

cpp/src/security/DecOverflowWhenComparing/DecOverflowWhenComparing.c

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,88 @@
33
#include <stdlib.h>
44
#include <string.h>
55

6+
size_t min(size_t a, size_t b) {
7+
return a < b ? a : b;
8+
}
9+
10+
void MlasReorderInputNhwc(
11+
const float* S,
12+
float* D,
13+
size_t InputChannels,
14+
size_t RowCount,
15+
size_t FullRowCount
16+
)
17+
{
18+
const size_t BlockSize = FullRowCount % 123;
19+
20+
//
21+
// Iterate over batches of the input size to improve locality.
22+
//
23+
24+
for (size_t OuterRowCountRemaining = RowCount; OuterRowCountRemaining > 0; ) {
25+
26+
size_t OuterRowCountBatch = 32;
27+
28+
const size_t OuterRowCountThisIteration = min(OuterRowCountRemaining, OuterRowCountBatch);
29+
OuterRowCountRemaining -= OuterRowCountThisIteration;
30+
31+
//
32+
// Iterate over BlockSize batches of the input channels.
33+
//
34+
35+
const float* s = S;
36+
float* d = D;
37+
38+
for (size_t i = InputChannels; i > 0;) {
39+
40+
const size_t InputChannelsThisIteration = min(i, BlockSize);
41+
i -= InputChannelsThisIteration;
42+
43+
const float* ss = s;
44+
float* dd = d;
45+
size_t InnerRowCountRemaining = OuterRowCountThisIteration;
46+
47+
if (InputChannelsThisIteration == BlockSize) {
48+
49+
if (BlockSize == 8) {
50+
51+
while (InnerRowCountRemaining-- > 0) {
52+
ss += InputChannels;
53+
dd += 8;
54+
}
55+
56+
} else {
57+
58+
while (InnerRowCountRemaining-- > 0) {
59+
ss += InputChannels;
60+
dd += 16;
61+
}
62+
}
63+
64+
} else {
65+
66+
size_t BlockPadding = BlockSize - InputChannelsThisIteration;
67+
68+
while (InnerRowCountRemaining-- > 0) {
69+
ss += InputChannels;
70+
dd += BlockSize;
71+
}
72+
}
73+
74+
s += InputChannelsThisIteration;
75+
d += BlockSize * FullRowCount;
76+
}
77+
78+
S += InputChannels * OuterRowCountThisIteration;
79+
D += BlockSize * OuterRowCountThisIteration;
80+
}
81+
}
82+
683
// from https://github.com/apple-oss-distributions/Libinfo/blob/9fce29e5c5edc15d3ecea55116ca17d3f6350603/lookup.subproj/mdns_module.c#L1033C1-L1079C2
784
char* _mdns_parse_domain_name(const uint8_t *data, uint32_t datalen)
885
{
986
int i = 0, j = 0;
10-
uint32_t len;
87+
// uint32_t len;
1188
uint32_t domainlen = 0;
1289
char *domain = NULL;
1390

@@ -19,8 +96,8 @@ char* _mdns_parse_domain_name(const uint8_t *data, uint32_t datalen)
1996
*/
2097
while (datalen-- > 0)
2198
{
22-
printf("%d\n", len);
23-
len = data[i++];
99+
// printf("%d\n", len);
100+
uint32_t len = data[i++];
24101
domainlen += (len + 1);
25102
domain = reallocf(domain, domainlen);
26103

cpp/src/security/DecOverflowWhenComparing/DecOverflowWhenComparing.ql

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,33 @@
1212

1313
import cpp
1414
import semmle.code.cpp.ir.IR
15+
import semmle.code.cpp.rangeanalysis.SimpleRangeAnalysis
1516

16-
from Variable var, VariableAccess varAcc, DecrementOperation dec,
17+
/**
18+
* Find CFG paths from start to end that do not cross over node that is var's lvalue access
19+
* TODO: there must be an API for that...
20+
*/
21+
predicate successorGuarded(ControlFlowNode start, ControlFlowNode end, Variable var) {
22+
start = end
23+
or
24+
exists(ControlFlowNode interm |
25+
start.getASuccessor() = interm and
26+
27+
// break the path if variable is overwritten
28+
not (
29+
interm = var.getAnAccess() and
30+
interm.(VariableAccess).isLValue()
31+
) and
32+
33+
(
34+
interm.getASuccessor() = end
35+
or
36+
successorGuarded(interm, end, var)
37+
)
38+
)
39+
}
40+
41+
from Variable var, VariableAccess varAcc, PostfixDecrExpr dec,
1742
VariableAccess varAccAfterOverflow, ComparisonOperation cmp
1843
where
1944
// get unsigned variable that is decremented
@@ -28,19 +53,9 @@ where
2853
cmp.getAnOperand() instanceof Zero and
2954

3055
// only if the variable is used after the comparison
31-
cmp.getASuccessor+() = varAccAfterOverflow and
56+
successorGuarded(cmp, varAccAfterOverflow, var) and
3257
cmp.getAnOperand().getAChild*() != varAccAfterOverflow and
3358

34-
// skip if the variable is overwritten
35-
// TODO: handle loops correctly
36-
// not exists(VariableAccess varAccLV | varAccLV.isUsedAsLValue() |
37-
// varAccLV = var.getAnAccess() and
38-
// varAccLV != varAcc and
39-
// varAccLV != varAccAfterOverflow and
40-
// cmp.getASuccessor+() = varAccLV and
41-
// varAccAfterOverflow.getAPredecessor+() = varAccLV
42-
// ) and
43-
4459
// var-- > 0 (0 < var--) then accesses only in false branch
4560
// var-- >= 0 then accesses in all branches
4661
// var-- == 0 then accesses in all branches
@@ -54,4 +69,13 @@ where
5469
cmp.getAFalseSuccessor().getASuccessor*() = varAccAfterOverflow
5570
else
5671
any()
57-
select cmp, varAccAfterOverflow
72+
73+
and
74+
75+
// only if var may possibly be zero during comparison
76+
lowerBound(varAcc) = 0
77+
78+
// skip tests etc
79+
and not dec.getFile().getAbsolutePath().toLowerCase().matches(["%test%", "%vendor%", "%third_party%"])
80+
81+
select dec, "Unsigned decrementation in comparison ($@) - $@", cmp, cmp.toString(), varAccAfterOverflow, varAccAfterOverflow.toString()

cpp/test/include/libc/stdint.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef USE_HEADERS
2+
3+
#ifndef HEADER_STDINT_STUB_H
4+
#define HEADER_STDINT_STUB_H
5+
6+
typedef unsigned char uint8_t;
7+
typedef unsigned short uint16_t;
8+
typedef unsigned int uint32_t;
9+
10+
#endif
11+
#else // --- else USE_HEADERS
12+
13+
#include <stdint.h>
14+
15+
#endif // --- end USE_HEADERS

cpp/test/query-tests/security/DecOverflowWhenComparing/DecOverflowWhenComparing.c

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
#ifndef USE_HEADERS
21
#include "../../../include/libc/string_stubs.h"
2+
#include "../../../include/libc/stdlib.h"
3+
#include "../../../include/libc/unistd.h"
4+
#include "../../../include/libc/stdint.h"
5+
6+
// #include <stdio.h>
7+
// #include <stdint.h>
8+
// #include <stdlib.h>
9+
// #include <string.h>
310

4-
#else
5-
#include <stdio.h>
6-
#include <stdint.h>
7-
#include <stdlib.h>
8-
#include <string.h>
9-
#endif
1011

1112
// from https://github.com/apple-oss-distributions/Libinfo/blob/9fce29e5c5edc15d3ecea55116ca17d3f6350603/lookup.subproj/mdns_module.c#L1033C1-L1079C2
1213
char* _mdns_parse_domain_name(const uint8_t *data, uint32_t datalen)

0 commit comments

Comments
 (0)