diff --git a/videocr/models.py b/videocr/models.py index bd828be..31259c4 100644 --- a/videocr/models.py +++ b/videocr/models.py @@ -62,12 +62,12 @@ class PredictedSubtitle: def __init__(self, frames: List[PredictedFrame]): self.frames = [f for f in frames if f.confidence > 0] - @property - def text(self) -> str: if self.frames: conf_max = max(f.confidence for f in self.frames) - return next(f.text for f in self.frames if f.confidence == conf_max) - return '' + self.text = next(f.text for f in self.frames + if f.confidence == conf_max) + else: + self.text = '' @property def index_start(self) -> int: diff --git a/videocr/video.py b/videocr/video.py index 703c1d4..e1a2767 100644 --- a/videocr/video.py +++ b/videocr/video.py @@ -46,12 +46,22 @@ class Video: v.release() def get_subtitles(self) -> str: + self._generate_subtitles() + return ''.join( + '{}\n{} --> {}\n{}\n'.format( + i, + self._srt_timestamp(sub.index_start), + self._srt_timestamp(sub.index_end), + sub.text) + for i, sub in enumerate(self.pred_subs)) + + def _generate_subtitles(self) -> None: + self.pred_subs = [] + if self.pred_frames is None: raise AttributeError( 'Please call self.run_ocr() first to generate ocr of frames') - self.pred_subs = [] - # divide ocr of frames into subtitle paragraphs using sliding window WIN_BOUND = int(self.fps / 2) # 1/2 sec sliding window boundary bound = WIN_BOUND @@ -75,18 +85,10 @@ class Video: j += 1 + # also handle the last remaining frames if i < self.num_frames - 1: self._append_sub(PredictedSubtitle(self.pred_frames[i:])) - for i, sub in enumerate(self.pred_subs): - print('{}\n{} --> {}\n{}\n'.format( - i, - self._srt_timestamp(sub.index_start), - self._srt_timestamp(sub.index_end), - sub.text)) - - return '' - def _append_sub(self, sub: PredictedSubtitle) -> None: if len(sub.text) == 0: return @@ -101,7 +103,7 @@ class Video: def _srt_timestamp(self, frame_index) -> str: time = str(datetime.timedelta(seconds=frame_index / self.fps)) - return time.replace('.', ',') # srt uses comma as fractional separator + return time.replace('.', ',') # srt uses comma, not dot time_start = timeit.default_timer()