move video parameters to run_ocr() function

This commit is contained in:
Yi Ge 2019-04-27 21:41:19 +02:00
parent 3f73cb9bca
commit bc84ee39ff

View File

@ -14,32 +14,45 @@ class Video:
use_fullframe: bool use_fullframe: bool
num_frames: int num_frames: int
fps: float fps: float
ocr_frame_start: int
num_ocr_frames: int
pred_frames: List[PredictedFrame] pred_frames: List[PredictedFrame]
pred_subs: List[PredictedSubtitle] pred_subs: List[PredictedSubtitle]
def __init__(self, path: str, lang: str, use_fullframe=False, def __init__(self, path: str):
time_start='0:00', time_end=''):
self.path = path self.path = path
self.lang = lang
self.use_fullframe = use_fullframe
v = cv2.VideoCapture(path) v = cv2.VideoCapture(path)
self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT)) self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
self.fps = v.get(cv2.CAP_PROP_FPS) self.fps = v.get(cv2.CAP_PROP_FPS)
v.release() v.release()
self.ocr_frame_start = self._frame_index(time_start) def run_ocr(self, lang: str, use_fullframe=False,
time_start='0:00', time_end='') -> None:
self.lang = lang
self.use_fullframe = use_fullframe
ocr_start = self._frame_index(time_start)
if time_end: if time_end:
ocr_end = self._frame_index(time_end) ocr_end = self._frame_index(time_end)
if ocr_end < ocr_start:
raise ValueError('time_start is later than time_end')
else: else:
ocr_end = self.num_frames ocr_end = self.num_frames
self.num_ocr_frames = ocr_end - self.ocr_frame_start num_ocr_frames = ocr_end - ocr_start
if self.num_ocr_frames < 0: # get frames from ocr_start to ocr_end
raise ValueError('time_start is later than time_end') v = cv2.VideoCapture(self.path)
v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
frames = (v.read()[1] for _ in range(num_ocr_frames))
# perform ocr to frames in parallel
with futures.ProcessPoolExecutor() as pool:
ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10)
self.pred_frames = [PredictedFrame(i + ocr_start, data)
for i, data in enumerate(ocr_map)]
v.release()
# convert time str to frame index
def _frame_index(self, time: str) -> int: def _frame_index(self, time: str) -> int:
t = time.split(':') t = time.split(':')
t = list(map(int, t)) t = list(map(int, t))
@ -58,19 +71,6 @@ class Video:
return index return index
def run_ocr(self) -> None:
v = cv2.VideoCapture(self.path)
v.set(cv2.CAP_PROP_POS_FRAMES, self.ocr_frame_start)
frames = (v.read()[1] for _ in range(self.num_ocr_frames))
# perform ocr to all frames in parallel
with futures.ProcessPoolExecutor() as pool:
ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10)
self.pred_frames = [PredictedFrame(i + self.ocr_frame_start, data)
for i, data in enumerate(ocr_map)]
v.release()
def _single_frame_ocr(self, img) -> str: def _single_frame_ocr(self, img) -> str:
if not self.use_fullframe: if not self.use_fullframe:
# only use bottom half of the frame by default # only use bottom half of the frame by default
@ -99,7 +99,7 @@ class Video:
bound = WIN_BOUND bound = WIN_BOUND
i = 0 i = 0
j = 1 j = 1
while j < self.num_ocr_frames: while j < len(self.pred_frames):
fi, fj = self.pred_frames[i], self.pred_frames[j] fi, fj = self.pred_frames[i], self.pred_frames[j]
if fi.is_similar_to(fj): if fi.is_similar_to(fj):
@ -118,7 +118,7 @@ class Video:
j += 1 j += 1
# also handle the last remaining frames # also handle the last remaining frames
if i < self.num_ocr_frames - 1: if i < len(self.pred_frames) - 1:
self._append_sub(PredictedSubtitle(self.pred_frames[i:])) self._append_sub(PredictedSubtitle(self.pred_frames[i:]))
def _append_sub(self, sub: PredictedSubtitle) -> None: def _append_sub(self, sub: PredictedSubtitle) -> None: