make conf_threshold adjustable through api

This commit is contained in:
Yi Ge 2019-04-29 03:05:02 +02:00
parent a5e6845a1b
commit 77362dce1a
3 changed files with 16 additions and 9 deletions

View File

@ -5,8 +5,8 @@ from . import constants
from .video import Video
def get_subtitles(video_path: str, lang='eng',
time_start='0:00', time_end='', use_fullframe=False) -> str:
def get_subtitles(video_path: str, lang='eng', time_start='0:00', time_end='',
conf_threshold=65, use_fullframe=False) -> str:
# download tesseract data file to ~/tessdata if necessary
fpath = constants.TESSDATA_DIR / '{}.traineddata'.format(lang)
if not fpath.is_file():
@ -19,13 +19,15 @@ def get_subtitles(video_path: str, lang='eng',
shutil.copyfileobj(res, f)
v = Video(video_path)
v.run_ocr(lang, time_start, time_end, use_fullframe)
v.run_ocr(lang, time_start, time_end, conf_threshold, use_fullframe)
return v.get_subtitles()
def save_subtitles_to_file(
video_path: str, file_path='subtitle.srt', lang='eng',
time_start='0:00', time_end='', use_fullframe=False) -> None:
time_start='0:00', time_end='', conf_threshold=65,
use_fullframe=False) -> None:
with open(file_path, 'w+') as f:
f.write(get_subtitles(
video_path, lang, time_start, time_end, use_fullframe))
video_path, lang, time_start, time_end, conf_threshold,
use_fullframe))

View File

@ -17,7 +17,7 @@ class PredictedFrame:
confidence: int # total confidence of all words
text: str
def __init__(self, index: int, pred_data: str, conf_threshold=70):
def __init__(self, index: int, pred_data: str, conf_threshold: int):
self.index = index
self.words = []

View File

@ -14,18 +14,22 @@ class Video:
use_fullframe: bool
num_frames: int
fps: float
height: int
pred_frames: List[PredictedFrame]
pred_subs: List[PredictedSubtitle]
def __init__(self, path: str):
self.path = path
v = cv2.VideoCapture(path)
if not v.isOpened():
raise IOError('can not open video format {}'.format(path))
self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
self.fps = v.get(cv2.CAP_PROP_FPS)
self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
v.release()
def run_ocr(self, lang: str, time_start: str, time_end: str,
use_fullframe: bool) -> None:
conf_threshold: int, use_fullframe: bool) -> None:
self.lang = lang
self.use_fullframe = use_fullframe
@ -44,7 +48,8 @@ class Video:
# perform ocr to frames in parallel
with futures.ProcessPoolExecutor() as pool:
ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10)
self.pred_frames = [PredictedFrame(i + ocr_start, data)
self.pred_frames = [
PredictedFrame(i + ocr_start, data, conf_threshold)
for i, data in enumerate(ocr_map)]
v.release()