make conf_threshold adjustable through api
This commit is contained in:
parent
a5e6845a1b
commit
77362dce1a
@ -5,8 +5,8 @@ from . import constants
|
||||
from .video import Video
|
||||
|
||||
|
||||
def get_subtitles(video_path: str, lang='eng',
|
||||
time_start='0:00', time_end='', use_fullframe=False) -> str:
|
||||
def get_subtitles(video_path: str, lang='eng', time_start='0:00', time_end='',
|
||||
conf_threshold=65, use_fullframe=False) -> str:
|
||||
# download tesseract data file to ~/tessdata if necessary
|
||||
fpath = constants.TESSDATA_DIR / '{}.traineddata'.format(lang)
|
||||
if not fpath.is_file():
|
||||
@ -19,13 +19,15 @@ def get_subtitles(video_path: str, lang='eng',
|
||||
shutil.copyfileobj(res, f)
|
||||
|
||||
v = Video(video_path)
|
||||
v.run_ocr(lang, time_start, time_end, use_fullframe)
|
||||
v.run_ocr(lang, time_start, time_end, conf_threshold, use_fullframe)
|
||||
return v.get_subtitles()
|
||||
|
||||
|
||||
def save_subtitles_to_file(
|
||||
video_path: str, file_path='subtitle.srt', lang='eng',
|
||||
time_start='0:00', time_end='', use_fullframe=False) -> None:
|
||||
time_start='0:00', time_end='', conf_threshold=65,
|
||||
use_fullframe=False) -> None:
|
||||
with open(file_path, 'w+') as f:
|
||||
f.write(get_subtitles(
|
||||
video_path, lang, time_start, time_end, use_fullframe))
|
||||
video_path, lang, time_start, time_end, conf_threshold,
|
||||
use_fullframe))
|
||||
|
@ -17,7 +17,7 @@ class PredictedFrame:
|
||||
confidence: int # total confidence of all words
|
||||
text: str
|
||||
|
||||
def __init__(self, index: int, pred_data: str, conf_threshold=70):
|
||||
def __init__(self, index: int, pred_data: str, conf_threshold: int):
|
||||
self.index = index
|
||||
self.words = []
|
||||
|
||||
|
@ -14,18 +14,22 @@ class Video:
|
||||
use_fullframe: bool
|
||||
num_frames: int
|
||||
fps: float
|
||||
height: int
|
||||
pred_frames: List[PredictedFrame]
|
||||
pred_subs: List[PredictedSubtitle]
|
||||
|
||||
def __init__(self, path: str):
|
||||
self.path = path
|
||||
v = cv2.VideoCapture(path)
|
||||
if not v.isOpened():
|
||||
raise IOError('can not open video format {}'.format(path))
|
||||
self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
self.fps = v.get(cv2.CAP_PROP_FPS)
|
||||
self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
v.release()
|
||||
|
||||
def run_ocr(self, lang: str, time_start: str, time_end: str,
|
||||
use_fullframe: bool) -> None:
|
||||
conf_threshold: int, use_fullframe: bool) -> None:
|
||||
self.lang = lang
|
||||
self.use_fullframe = use_fullframe
|
||||
|
||||
@ -44,7 +48,8 @@ class Video:
|
||||
# perform ocr to frames in parallel
|
||||
with futures.ProcessPoolExecutor() as pool:
|
||||
ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10)
|
||||
self.pred_frames = [PredictedFrame(i + ocr_start, data)
|
||||
self.pred_frames = [
|
||||
PredictedFrame(i + ocr_start, data, conf_threshold)
|
||||
for i, data in enumerate(ocr_map)]
|
||||
|
||||
v.release()
|
||||
|
Loading…
Reference in New Issue
Block a user