calculate PredictedSubtitle.text early

This commit is contained in:
Yi Ge 2019-04-26 00:32:47 +02:00
parent 3a73f1f508
commit f5d27a7a46
2 changed files with 18 additions and 16 deletions

View File

@ -62,12 +62,12 @@ class PredictedSubtitle:
def __init__(self, frames: List[PredictedFrame]): def __init__(self, frames: List[PredictedFrame]):
self.frames = [f for f in frames if f.confidence > 0] self.frames = [f for f in frames if f.confidence > 0]
@property
def text(self) -> str:
if self.frames: if self.frames:
conf_max = max(f.confidence for f in self.frames) conf_max = max(f.confidence for f in self.frames)
return next(f.text for f in self.frames if f.confidence == conf_max) self.text = next(f.text for f in self.frames
return '' if f.confidence == conf_max)
else:
self.text = ''
@property @property
def index_start(self) -> int: def index_start(self) -> int:

View File

@ -46,12 +46,22 @@ class Video:
v.release() v.release()
def get_subtitles(self) -> str: def get_subtitles(self) -> str:
self._generate_subtitles()
return ''.join(
'{}\n{} --> {}\n{}\n'.format(
i,
self._srt_timestamp(sub.index_start),
self._srt_timestamp(sub.index_end),
sub.text)
for i, sub in enumerate(self.pred_subs))
def _generate_subtitles(self) -> None:
self.pred_subs = []
if self.pred_frames is None: if self.pred_frames is None:
raise AttributeError( raise AttributeError(
'Please call self.run_ocr() first to generate ocr of frames') 'Please call self.run_ocr() first to generate ocr of frames')
self.pred_subs = []
# divide ocr of frames into subtitle paragraphs using sliding window # divide ocr of frames into subtitle paragraphs using sliding window
WIN_BOUND = int(self.fps / 2) # 1/2 sec sliding window boundary WIN_BOUND = int(self.fps / 2) # 1/2 sec sliding window boundary
bound = WIN_BOUND bound = WIN_BOUND
@ -75,18 +85,10 @@ class Video:
j += 1 j += 1
# also handle the last remaining frames
if i < self.num_frames - 1: if i < self.num_frames - 1:
self._append_sub(PredictedSubtitle(self.pred_frames[i:])) self._append_sub(PredictedSubtitle(self.pred_frames[i:]))
for i, sub in enumerate(self.pred_subs):
print('{}\n{} --> {}\n{}\n'.format(
i,
self._srt_timestamp(sub.index_start),
self._srt_timestamp(sub.index_end),
sub.text))
return ''
def _append_sub(self, sub: PredictedSubtitle) -> None: def _append_sub(self, sub: PredictedSubtitle) -> None:
if len(sub.text) == 0: if len(sub.text) == 0:
return return
@ -101,7 +103,7 @@ class Video:
def _srt_timestamp(self, frame_index) -> str: def _srt_timestamp(self, frame_index) -> str:
time = str(datetime.timedelta(seconds=frame_index / self.fps)) time = str(datetime.timedelta(seconds=frame_index / self.fps))
return time.replace('.', ',') # srt uses comma as fractional separator return time.replace('.', ',') # srt uses comma, not dot
time_start = timeit.default_timer() time_start = timeit.default_timer()