diff --git a/main_test.go b/main_test.go index b4dff95..8ae7b12 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,8 @@ package main import ( + "runtime" + "sync" "bytes" "crypto/md5" "encoding/hex" @@ -12,13 +14,14 @@ import ( "testing" ) -const tmpDir = "tmp/" +const tmpDir = "/tmp/" const plainDir = tmpDir + "plain/" const cipherDir = tmpDir + "cipher/" func mount(extraArgs ...string) { var args []string args = append(args, extraArgs...) + //args = append(args, "--fusedebug") args = append(args, cipherDir) args = append(args, plainDir) c := exec.Command("./gocryptfs", args...) @@ -44,6 +47,10 @@ func md5fn(filename string) string { fmt.Printf("ReadFile: %v\n", err) return "" } + return md5hex(buf) +} + +func md5hex(buf []byte) string { rawHash := md5.Sum(buf) hash := hex.EncodeToString(rawHash[:]) return hash @@ -209,6 +216,84 @@ func TestFileHoles(t *testing.T) { } } +func sContains(haystack []string, needle string) bool { + for _, element := range haystack { + if element == needle { + return true + } + } + return false +} + +func TestRmwRace(t *testing.T) { + + runtime.GOMAXPROCS(10) + + fn := plainDir + "rmwrace" + f1, err := os.Create(fn) + if err != nil { + t.Errorf("file create failed") + } + f2, err := os.Create(fn) + if err != nil { + t.Errorf("file create failed") + } + + oldBlock := bytes.Repeat([]byte("o"), 4096) + + newBlock := bytes.Repeat([]byte("n"), 4096) + + shortBlock := bytes.Repeat([]byte("s"), 16) + + mergedBlock := make([]byte, 4096) + copy(mergedBlock, newBlock) + copy(mergedBlock[4080:], shortBlock) + + goodMd5 := make(map[string]int) + + for i := 0; i < 1000; i++ { + // Reset to [ooooooooo] + _, err = f1.WriteAt(oldBlock, 0) + if err != nil { + t.Errorf("Write failed") + } + + var wg sync.WaitGroup + wg.Add(2) + + // Write to the end of the file, [....ssss] + go func() { + f1.WriteAt(shortBlock, 4080) + wg.Done() + }() + + // Overwrite to [nnnnnnn] + go func() { + f2.WriteAt(newBlock, 0) + wg.Done() + }() + + wg.Wait() + + // The file should be either: + // [nnnnnnnnnn] (md5: 6c1660fdabccd448d1359f27b3db3c99) or + // [nnnnnnssss] (md5: da885006a6a284530a427c73ce1e5c32) + // but it must not be + // [oooooossss] + + buf, _ := ioutil.ReadFile(fn) + m := md5hex(buf) + goodMd5[m] = goodMd5[m] + 1 + + /* + if m == "6c1660fdabccd448d1359f27b3db3c99" { + fmt.Println(hex.Dump(buf)) + t.FailNow() + } + */ + } + fmt.Println(goodMd5) +} func BenchmarkStreamWrite(t *testing.B) { buf := make([]byte, 1024*1024) t.SetBytes(int64(len(buf)))