forked from pradana.aumars/videocr
support ocr on part of the video
This commit is contained in:
parent
e55c17c325
commit
a3986b3279
@ -14,10 +14,13 @@ class Video:
|
||||
use_fullframe: bool
|
||||
num_frames: int
|
||||
fps: float
|
||||
ocr_frame_start: int
|
||||
num_ocr_frames: int
|
||||
pred_frames: List[PredictedFrame]
|
||||
pred_subs: List[PredictedSubtitle]
|
||||
|
||||
def __init__(self, path, lang, use_fullframe=False):
|
||||
def __init__(self, path: str, lang: str, use_fullframe=False,
|
||||
time_start='0:00', time_end=''):
|
||||
self.path = path
|
||||
self.lang = lang
|
||||
self.use_fullframe = use_fullframe
|
||||
@ -26,24 +29,53 @@ class Video:
|
||||
self.fps = v.get(cv2.CAP_PROP_FPS)
|
||||
v.release()
|
||||
|
||||
self.ocr_frame_start = self._frame_index(time_start)
|
||||
|
||||
if time_end:
|
||||
ocr_end = self._frame_index(time_end)
|
||||
else:
|
||||
ocr_end = self.num_frames
|
||||
self.num_ocr_frames = ocr_end - self.ocr_frame_start
|
||||
|
||||
if self.num_ocr_frames < 0:
|
||||
raise ValueError('time_start is later than time_end')
|
||||
|
||||
def _frame_index(self, time: str) -> int:
|
||||
t = time.split(':')
|
||||
t = list(map(int, t))
|
||||
if len(t) == 3:
|
||||
td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
|
||||
elif len(t) == 2:
|
||||
td = datetime.timedelta(minutes=t[0], seconds=t[1])
|
||||
else:
|
||||
raise ValueError(
|
||||
'time data "{}" does not match format "%H:%M:%S"'.format(time))
|
||||
|
||||
index = int(td.total_seconds() * self.fps)
|
||||
if index > self.num_frames or index < 0:
|
||||
raise ValueError(
|
||||
'time data "{}" exceeds video duration'.format(time))
|
||||
|
||||
return index
|
||||
|
||||
def run_ocr(self) -> None:
|
||||
v = cv2.VideoCapture(self.path)
|
||||
v.set(cv2.CAP_PROP_POS_FRAMES, self.ocr_frame_start)
|
||||
frames = (v.read()[1] for _ in range(self.num_ocr_frames))
|
||||
|
||||
# perform ocr to all frames in parallel
|
||||
with futures.ProcessPoolExecutor() as pool:
|
||||
ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10)
|
||||
self.pred_frames = [PredictedFrame(i + self.ocr_frame_start, data)
|
||||
for i, data in enumerate(ocr_map)]
|
||||
|
||||
v.release()
|
||||
|
||||
def _single_frame_ocr(self, img) -> str:
|
||||
if not self.use_fullframe:
|
||||
# only use bottom half of the frame by default
|
||||
img = img[img.shape[0] // 2:, :]
|
||||
data = pytesseract.image_to_data(img, lang=self.lang)
|
||||
return data
|
||||
|
||||
def run_ocr(self) -> None:
|
||||
v = cv2.VideoCapture(self.path)
|
||||
frames = (v.read()[1] for _ in range(self.num_frames))
|
||||
|
||||
# perform ocr to all frames in parallel
|
||||
with futures.ProcessPoolExecutor() as pool:
|
||||
frames_ocr = pool.map(self._single_frame_ocr, frames, chunksize=10)
|
||||
self.pred_frames = [PredictedFrame(i, data)
|
||||
for i, data in enumerate(frames_ocr)]
|
||||
|
||||
v.release()
|
||||
return pytesseract.image_to_data(img, lang=self.lang)
|
||||
|
||||
def get_subtitles(self) -> str:
|
||||
self._generate_subtitles()
|
||||
@ -67,7 +99,7 @@ class Video:
|
||||
bound = WIN_BOUND
|
||||
i = 0
|
||||
j = 1
|
||||
while j < self.num_frames:
|
||||
while j < self.num_ocr_frames:
|
||||
fi, fj = self.pred_frames[i], self.pred_frames[j]
|
||||
|
||||
if fi.is_similar_to(fj):
|
||||
@ -86,7 +118,7 @@ class Video:
|
||||
j += 1
|
||||
|
||||
# also handle the last remaining frames
|
||||
if i < self.num_frames - 1:
|
||||
if i < self.num_ocr_frames - 1:
|
||||
self._append_sub(PredictedSubtitle(self.pred_frames[i:]))
|
||||
|
||||
def _append_sub(self, sub: PredictedSubtitle) -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user