2019-04-20 23:21:41 +02:00
|
|
|
from __future__ import annotations
|
|
|
|
from typing import List
|
|
|
|
from dataclasses import dataclass
|
2019-04-25 01:39:35 +02:00
|
|
|
from fuzzywuzzy import fuzz
|
2019-04-20 23:21:41 +02:00
|
|
|
|
|
|
|
|
2019-04-24 21:18:31 +02:00
|
|
|
# Per-word confidence cut-off: PredictedFrame keeps a word only when its
# confidence is greater than or equal to this value.
CONF_THRESHOLD = 60
|
|
|
|
# word predictions with lower confidence will be filtered out
|
2019-04-20 23:21:41 +02:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class PredictedWord:
    """One predicted word paired with the confidence assigned to it.

    Instances are slotted, so they carry no per-instance ``__dict__``.
    """

    __slots__ = ('confidence', 'text')

    confidence: int
    text: str
|
|
|
|
|
|
|
|
|
|
|
|
class PredictedFrame:
    """Words predicted for a single frame, with their total confidence and text.

    The raw prediction is parsed line by line: the first line is skipped,
    every other line is split on whitespace, and lines with fewer than 12
    fields are ignored. The third field is read as the block number and the
    last two fields as the confidence and the text of the word.
    NOTE(review): this layout looks like Tesseract TSV output - confirm
    against the caller.
    """

    index: int  # 0-based index of the frame
    words: List[PredictedWord]
    confidence: int  # total confidence of all words
    text: str

    # Built once for all instances: characters that are obviously OCR errors
    # are dropped and '|' is replaced by 'I'.
    _TRANSLATE_TABLE = str.maketrans('|', 'I', '<>{};`@#$%^*_=\\')

    def __init__(self, index, pred_data: str, conf_threshold=None):
        """Parse ``pred_data`` and keep the sufficiently confident words.

        Args:
            index: 0-based index of the frame.
            pred_data: raw prediction text, one word per line after the
                first (skipped) line.
            conf_threshold: minimum confidence for a word to be kept; when
                ``None`` the module-level ``CONF_THRESHOLD`` is used.

        Raises:
            ValueError: if a block number or confidence field is not numeric.
        """
        self.index = index
        self.words = []
        if conf_threshold is None:
            conf_threshold = CONF_THRESHOLD

        block = 0  # highest block number seen so far; tracks line breaks

        for line in pred_data.splitlines()[1:]:
            word_data = line.split()
            if len(word_data) < 12:
                # no word is predicted
                continue
            _, _, block_num, *_, conf, text = word_data
            block_num = int(block_num)
            # The confidence may be written as a decimal (e.g. "96.53");
            # int() alone would raise ValueError on such input.
            conf = int(float(conf))

            # handle line breaks: a new block starts a new line, but never
            # add a leading or a doubled line-break marker
            if block < block_num:
                block = block_num
                if self.words and self.words[-1].text != '\n':
                    self.words.append(PredictedWord(0, '\n'))

            if conf >= conf_threshold:
                self.words.append(PredictedWord(conf, text))

        self.confidence = sum(word.confidence for word in self.words)

        joined = ' '.join(word.text for word in self.words)
        # remove chars that are obviously ocr errors
        self.text = joined.translate(self._TRANSLATE_TABLE).strip()

    def is_similar_to(self, other: PredictedFrame, threshold=70) -> bool:
        """Return True when the fuzzy partial ratio of both texts reaches ``threshold``."""
        return fuzz.partial_ratio(self.text, other.text) >= threshold
|
2019-04-25 01:39:35 +02:00
|
|
|
|
|
|
|
|
|
|
|
class PredictedSubtitle:
    """A run of predicted frames treated as one subtitle.

    Only frames whose confidence is greater than zero are retained.
    """

    frames: List[PredictedFrame]

    def __init__(self, frames: List[PredictedFrame]):
        """Store the given frames, dropping those with no confidence."""
        self.frames = [frame for frame in frames if frame.confidence > 0]

    @property
    def text(self) -> str:
        """Text of the first frame with the highest confidence, or ''."""
        if not self.frames:
            return ''
        # max() returns the first maximal element, i.e. the earliest frame
        # among those sharing the top confidence.
        best = max(self.frames, key=lambda frame: frame.confidence)
        return best.text

    @property
    def index_start(self) -> int:
        """Index of the first retained frame, or 0 when there is none."""
        return self.frames[0].index if self.frames else 0

    @property
    def index_end(self) -> int:
        """Index of the last retained frame, or 0 when there is none."""
        return self.frames[-1].index if self.frames else 0

    def is_similar_to(self, other: PredictedSubtitle, threshold=70) -> bool:
        """Return True when the fuzzy partial ratio of both texts reaches ``threshold``."""
        score = fuzz.partial_ratio(self.text, other.text)
        return score >= threshold

    def __repr__(self):
        """Render as '<index_start> - <index_end>. <text>'."""
        parts = (self.index_start, self.index_end, self.text)
        return '{} - {}. {}'.format(*parts)
|