move video parameters to run_ocr() function
This commit is contained in:
parent
3f73cb9bca
commit
bc84ee39ff
@ -14,32 +14,45 @@ class Video:
|
|||||||
use_fullframe: bool
|
use_fullframe: bool
|
||||||
num_frames: int
|
num_frames: int
|
||||||
fps: float
|
fps: float
|
||||||
ocr_frame_start: int
|
|
||||||
num_ocr_frames: int
|
|
||||||
pred_frames: List[PredictedFrame]
|
pred_frames: List[PredictedFrame]
|
||||||
pred_subs: List[PredictedSubtitle]
|
pred_subs: List[PredictedSubtitle]
|
||||||
|
|
||||||
def __init__(self, path: str, lang: str, use_fullframe=False,
|
def __init__(self, path: str):
|
||||||
time_start='0:00', time_end=''):
|
|
||||||
self.path = path
|
self.path = path
|
||||||
self.lang = lang
|
|
||||||
self.use_fullframe = use_fullframe
|
|
||||||
v = cv2.VideoCapture(path)
|
v = cv2.VideoCapture(path)
|
||||||
self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
|
self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
self.fps = v.get(cv2.CAP_PROP_FPS)
|
self.fps = v.get(cv2.CAP_PROP_FPS)
|
||||||
v.release()
|
v.release()
|
||||||
|
|
||||||
self.ocr_frame_start = self._frame_index(time_start)
|
def run_ocr(self, lang: str, use_fullframe=False,
|
||||||
|
time_start='0:00', time_end='') -> None:
|
||||||
|
self.lang = lang
|
||||||
|
self.use_fullframe = use_fullframe
|
||||||
|
|
||||||
|
ocr_start = self._frame_index(time_start)
|
||||||
|
|
||||||
if time_end:
|
if time_end:
|
||||||
ocr_end = self._frame_index(time_end)
|
ocr_end = self._frame_index(time_end)
|
||||||
|
if ocr_end < ocr_start:
|
||||||
|
raise ValueError('time_start is later than time_end')
|
||||||
else:
|
else:
|
||||||
ocr_end = self.num_frames
|
ocr_end = self.num_frames
|
||||||
self.num_ocr_frames = ocr_end - self.ocr_frame_start
|
num_ocr_frames = ocr_end - ocr_start
|
||||||
|
|
||||||
if self.num_ocr_frames < 0:
|
# get frames from ocr_start to ocr_end
|
||||||
raise ValueError('time_start is later than time_end')
|
v = cv2.VideoCapture(self.path)
|
||||||
|
v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
|
||||||
|
frames = (v.read()[1] for _ in range(num_ocr_frames))
|
||||||
|
|
||||||
|
# perform ocr to frames in parallel
|
||||||
|
with futures.ProcessPoolExecutor() as pool:
|
||||||
|
ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10)
|
||||||
|
self.pred_frames = [PredictedFrame(i + ocr_start, data)
|
||||||
|
for i, data in enumerate(ocr_map)]
|
||||||
|
|
||||||
|
v.release()
|
||||||
|
|
||||||
|
# convert time str to frame index
|
||||||
def _frame_index(self, time: str) -> int:
|
def _frame_index(self, time: str) -> int:
|
||||||
t = time.split(':')
|
t = time.split(':')
|
||||||
t = list(map(int, t))
|
t = list(map(int, t))
|
||||||
@ -58,19 +71,6 @@ class Video:
|
|||||||
|
|
||||||
return index
|
return index
|
||||||
|
|
||||||
def run_ocr(self) -> None:
|
|
||||||
v = cv2.VideoCapture(self.path)
|
|
||||||
v.set(cv2.CAP_PROP_POS_FRAMES, self.ocr_frame_start)
|
|
||||||
frames = (v.read()[1] for _ in range(self.num_ocr_frames))
|
|
||||||
|
|
||||||
# perform ocr to all frames in parallel
|
|
||||||
with futures.ProcessPoolExecutor() as pool:
|
|
||||||
ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10)
|
|
||||||
self.pred_frames = [PredictedFrame(i + self.ocr_frame_start, data)
|
|
||||||
for i, data in enumerate(ocr_map)]
|
|
||||||
|
|
||||||
v.release()
|
|
||||||
|
|
||||||
def _single_frame_ocr(self, img) -> str:
|
def _single_frame_ocr(self, img) -> str:
|
||||||
if not self.use_fullframe:
|
if not self.use_fullframe:
|
||||||
# only use bottom half of the frame by default
|
# only use bottom half of the frame by default
|
||||||
@ -99,7 +99,7 @@ class Video:
|
|||||||
bound = WIN_BOUND
|
bound = WIN_BOUND
|
||||||
i = 0
|
i = 0
|
||||||
j = 1
|
j = 1
|
||||||
while j < self.num_ocr_frames:
|
while j < len(self.pred_frames):
|
||||||
fi, fj = self.pred_frames[i], self.pred_frames[j]
|
fi, fj = self.pred_frames[i], self.pred_frames[j]
|
||||||
|
|
||||||
if fi.is_similar_to(fj):
|
if fi.is_similar_to(fj):
|
||||||
@ -118,7 +118,7 @@ class Video:
|
|||||||
j += 1
|
j += 1
|
||||||
|
|
||||||
# also handle the last remaining frames
|
# also handle the last remaining frames
|
||||||
if i < self.num_ocr_frames - 1:
|
if i < len(self.pred_frames) - 1:
|
||||||
self._append_sub(PredictedSubtitle(self.pred_frames[i:]))
|
self._append_sub(PredictedSubtitle(self.pred_frames[i:]))
|
||||||
|
|
||||||
def _append_sub(self, sub: PredictedSubtitle) -> None:
|
def _append_sub(self, sub: PredictedSubtitle) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user