Skip to content

Commit

Permalink
new zlib type model for reverse phonemization
Browse files Browse the repository at this point in the history
  • Loading branch information
neurlang authored and Your Name committed Jan 24, 2025
1 parent a688b07 commit c6505fa
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 74 deletions.
8 changes: 4 additions & 4 deletions dicts/dicts.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,16 @@ func (DictGetter) IsOldFormat(magic []byte) bool {
if len(magic) < 2 {
return false
}
// GZIP
return magic[0] == 0x1F && magic[1] == 0x8B
// LZW
return (magic[0] == 0x1F && magic[1] == 0x9D) || (magic[0] == 0x1F && magic[1] == 0xA0)
}

// IsNewFormat reports whether magic begins with a zlib stream header, which
// is how the new weights format is recognised.
//
// magic is the leading bytes of the file; fewer than two bytes can never
// match, so that case returns false. The first byte must be 0x78 and the
// second must be one of the four flag bytes 0x01, 0x5E, 0x9C or 0xDA.
//
// The LZW magic numbers (0x1F 0x9D / 0x1F 0xA0) are deliberately NOT accepted
// here: they identify the old format, and a leftover LZW check placed ahead
// of the zlib check would make the zlib test unreachable.
func (DictGetter) IsNewFormat(magic []byte) bool {
	if len(magic) < 2 {
		return false
	}
	// ZLIB: every accepted header starts with 0x78.
	if magic[0] != 0x78 {
		return false
	}
	switch magic[1] {
	case 0x01, 0x5E, 0x9C, 0xDA:
		return true
	}
	return false
}

func GetDict(lang, filename string) ([]byte, error) {
Expand Down
2 changes: 1 addition & 1 deletion dicts/slovak/language.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ package slovak

import "embed"

//go:embed *.tsv language.json weights1.json.lzw language_reverse.json weights1_reverse.json.lzw
//go:embed *.tsv language.json weights1.json.lzw language_reverse.json weights1_reverse.json.zlib

Check failure on line 5 in dicts/slovak/language.go

View workflow job for this annotation

GitHub Actions / build

pattern weights1_reverse.json.zlib: no matching files found
var Language embed.FS
162 changes: 94 additions & 68 deletions repo/hashtron_phonemizer_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ package repo

import (
"bytes"
"compress/gzip"
"encoding/json"
"github.com/neurlang/classifier/datasets/phonemizer"
"github.com/neurlang/classifier/hash"
"github.com/neurlang/classifier/hashtron"
"github.com/neurlang/classifier/layer/majpool2d"
"github.com/neurlang/classifier/net/feedforward"
"github.com/neurlang/goruut/helpers/log"
"github.com/neurlang/goruut/repo/interfaces"
"io/ioutil"
"strings"
"sync"
"unicode"
Expand Down Expand Up @@ -246,19 +245,50 @@ func (r *HashtronPhonemizerRepository) LoadLanguage(isReverse bool, lang string)
log.Now().Debugf("Language %s made map of nets", lang)
}
if (*nets)[lang+reverse] == nil {
var net feedforward.FeedforwardNetwork
const fanout1 = 3
const fanout2 = 12
//const fanout3 = 3
//const fanout4 = 10
//net.NewLayerP(fanout1*fanout2*fanout3*fanout4, 0, 1033)
//net.NewCombiner(majpool2d.MustNew(fanout1*fanout2*fanout4, 1, fanout3, 1, fanout4, 1, 1))
net.NewLayerP(fanout1*fanout2, 0, 1<<fanout2)
net.NewCombiner(majpool2d.MustNew(fanout2, 1, fanout1, 1, fanout2, 1, 1))
net.NewLayer(1, 0)
r.mut.Lock()
(*r.nets)[lang+reverse] = &net
r.mut.Unlock()

if isReverse {
const fanout0 = 5
const fanout1 = 1
const fanout2 = 5
const fanout3 = 1
const fanout4 = 5
//const fanout5 = 1
//const fanout6 = 4
//const fanout7 = 1
//const fanout8 = 5

var net feedforward.FeedforwardNetwork
//net.NewLayerP(fanout1*fanout2*fanout3*fanout4*fanout5*fanout6*fanout7*fanout8, 0, 1<<fanout8)
//net.NewCombiner(majpool2d.MustNew2(fanout1*fanout2*fanout3*fanout4*fanout5*fanout6*fanout8, 1, fanout7, 1, fanout8, 1, 1, 0))
//net.NewLayerP(fanout1*fanout2*fanout3*fanout4*fanout5*fanout6, 0, 1<<(fanout6*fanout6*2/3))
//net.NewCombiner(majpool2d.MustNew2(fanout1*fanout2*fanout3*fanout4*fanout6, 1, fanout5, 1, fanout6, 1, 1, 0))
net.NewLayerP(fanout1*fanout2*fanout3*fanout4, 0, 1<<(fanout4*fanout4*2/3))
net.NewCombiner(majpool2d.MustNew2(fanout1*fanout2*fanout4, 1, fanout3, 1, fanout4, 1, 1, 0))
net.NewLayerP(fanout1*fanout2, 0, 1<<(fanout2*fanout2*2/3))
//net.NewCombiner(full.MustNew(fanout2, 1, 1))
net.NewCombiner(majpool2d.MustNew2(fanout2, 1, fanout1, 1, fanout2, 1, 1, 0))
net.NewLayerP(1, 0, 1<<(fanout0))
r.mut.Lock()
(*r.nets)[lang+reverse] = &net
r.mut.Unlock()

} else {

var net feedforward.FeedforwardNetwork
const fanout1 = 3
const fanout2 = 12
//const fanout3 = 3
//const fanout4 = 10
//net.NewLayerP(fanout1*fanout2*fanout3*fanout4, 0, 1033)
//net.NewCombiner(majpool2d.MustNew(fanout1*fanout2*fanout4, 1, fanout3, 1, fanout4, 1, 1))
net.NewLayerP(fanout1*fanout2, 0, 1<<fanout2)
net.NewCombiner(majpool2d.MustNew(fanout2, 1, fanout1, 1, fanout2, 1, 1))
net.NewLayer(1, 0)
r.mut.Lock()
(*r.nets)[lang+reverse] = &net
r.mut.Unlock()

}
} else {
log.Now().Debugf("Language %s already loaded", lang)
return
Expand Down Expand Up @@ -291,47 +321,29 @@ func (r *HashtronPhonemizerRepository) LoadLanguage(isReverse bool, lang string)
r.mut.Unlock()
}

var files = []string{"weights1" + reverse + ".json.lzw"}
var files = []string{"weights1" + reverse + ".json.lzw", "weights1" + reverse + ".json.zlib"}

if isReverse {
files[0], files[1] = files[1], files[0]
}

for _, file := range files {
compressedData := log.Error1((*r.getter).GetDict(lang, file))

if compressedData == nil {
continue
}

if (*r.getter).IsOldFormat(compressedData) {

// Step 3: Decompress the data in memory
gzipReader, err := gzip.NewReader(bytes.NewReader(compressedData))
if err != nil {
log.Now().Errorf("Failed to create gzip reader: %v", err)
continue
}
defer gzipReader.Close()

// Step 4: Read the decompressed data into memory
decompressedData, err := ioutil.ReadAll(gzipReader)
if err != nil {
log.Now().Errorf("Failed to read decompressed data: %v", err)
continue
}

// Step 5: Parse the JSON data into the specified type
var data [][][2]uint32
err = json.Unmarshal(decompressedData, &data)
if err != nil {
log.Now().Errorf("Failed to parse JSON data: %v", err)
continue
}
if !isReverse && (*r.getter).IsOldFormat(compressedData) {
bytesReader := bytes.NewReader(compressedData)
r.mut.Lock()
// Load the weights into the network
for i, v := range data {
*((*r.nets)[lang+reverse].GetHashtron(i)) = get_hashtron(hashtron.New(v, 1))
}
(*r.nets)[lang+reverse].ReadCompressedWeights(bytesReader)
r.mut.Unlock()
return

} else {
} else if isReverse && (*r.getter).IsNewFormat(compressedData) {
bytesReader := bytes.NewReader(compressedData)
r.mut.Lock()
(*r.nets)[lang+reverse].ReadCompressedWeights(bytesReader)
(*r.nets)[lang+reverse].ReadZlibWeights(bytesReader)
r.mut.Unlock()
return
}
Expand Down Expand Up @@ -524,33 +536,47 @@ outer:
origi := i
i := i - lastspace
j := len(srcaR) - i
var buf = [...]uint32{
hash.StringsHash(0, srcaR[1*i/2:i+j/2]),
hash.StringsHash(0, srcaR[2*i/3:i+j/3]),
hash.StringsHash(0, srcaR[4*i/5:i+j/5]),
hash.StringsHash(0, srcaR[6*i/7:i+j/11]),
hash.StringsHash(0, srcaR[10*i/11:i+j/11]),
hash.StringsHash(0, dstaR[0:i]),
hash.StringsHash(0, srcaR),
hash.StringsHash(0, dstaR[0:4*i/7]),
hash.StringsHash(0, dstaR[4*i/7:6*i/7]),
hash.StringsHash(0, dstaR[6*i/7:i]),
hash.StringsHash(0, srcaR[i:i+j/7]),
hash.StringsHash(0, srcaR[i+j/7:i+3*j/7]),
hash.StringsHash(0, srcaR[i+3*j/7:i+j]),
hash.StringHash(0, option),
}
var input = sample(buf)
r.mut.RLock()
net := (*r.nets)[lang+reverse]
r.mut.RUnlock()
if net == nil {
continue
}
r.mut.RLock()
var predicted = net.Infer2(&input)
r.mut.RUnlock()
if predicted == 1 {
var predicted bool
if isReverse {
var input = phonemizer.NewSample{
SrcA: srcaR,
DstA: dstaR[0:i],
SrcCut: srcaR[0:i],
SrcFut: srcaR[i:],
Option: srcaR[i],
}
r.mut.RLock()
predicted = net.Infer2(&input) != 0
r.mut.RUnlock()
} else {
var buf = [...]uint32{
hash.StringsHash(0, srcaR[1*i/2:i+j/2]),
hash.StringsHash(0, srcaR[2*i/3:i+j/3]),
hash.StringsHash(0, srcaR[4*i/5:i+j/5]),
hash.StringsHash(0, srcaR[6*i/7:i+j/11]),
hash.StringsHash(0, srcaR[10*i/11:i+j/11]),
hash.StringsHash(0, dstaR[0:i]),
hash.StringsHash(0, srcaR),
hash.StringsHash(0, dstaR[0:4*i/7]),
hash.StringsHash(0, dstaR[4*i/7:6*i/7]),
hash.StringsHash(0, dstaR[6*i/7:i]),
hash.StringsHash(0, srcaR[i:i+j/7]),
hash.StringsHash(0, srcaR[i+j/7:i+3*j/7]),
hash.StringsHash(0, srcaR[i+3*j/7:i+j]),
hash.StringHash(0, option),
}
var input = sample(buf)
r.mut.RLock()
predicted = net.Infer2(&input) != 0
r.mut.RUnlock()
}
if predicted {
if strings.HasPrefix(option, "_") {
lastspace = origi
}
Expand Down
2 changes: 1 addition & 1 deletion repo/interfaces/dict_getter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ package interfaces
// DictGetter abstracts how dictionary and weight files are fetched and how
// their on-disk encoding is told apart by the leading bytes.
type DictGetter interface {
	// GetDict returns the bytes stored under filename for the language lang,
	// or an error when they cannot be obtained.
	GetDict(lang, filename string) ([]byte, error)

	// IsOldFormat reports whether magic, the first bytes of a file, carries
	// the signature of the old weights format. Which signature that is gets
	// decided by the implementation.
	IsOldFormat(magic []byte) bool

	// IsNewFormat reports whether magic, the first bytes of a file, carries
	// the signature of the new weights format. Which signature that is gets
	// decided by the implementation.
	IsNewFormat(magic []byte) bool
}

0 comments on commit c6505fa

Please sign in to comment.