make conf_threshold adjustable through api
This commit is contained in:
parent
a5e6845a1b
commit
77362dce1a
@ -5,8 +5,8 @@ from . import constants
|
|||||||
from .video import Video
|
from .video import Video
|
||||||
|
|
||||||
|
|
||||||
def get_subtitles(video_path: str, lang='eng',
|
def get_subtitles(video_path: str, lang='eng', time_start='0:00', time_end='',
|
||||||
time_start='0:00', time_end='', use_fullframe=False) -> str:
|
conf_threshold=65, use_fullframe=False) -> str:
|
||||||
# download tesseract data file to ~/tessdata if necessary
|
# download tesseract data file to ~/tessdata if necessary
|
||||||
fpath = constants.TESSDATA_DIR / '{}.traineddata'.format(lang)
|
fpath = constants.TESSDATA_DIR / '{}.traineddata'.format(lang)
|
||||||
if not fpath.is_file():
|
if not fpath.is_file():
|
||||||
@ -19,13 +19,15 @@ def get_subtitles(video_path: str, lang='eng',
|
|||||||
shutil.copyfileobj(res, f)
|
shutil.copyfileobj(res, f)
|
||||||
|
|
||||||
v = Video(video_path)
|
v = Video(video_path)
|
||||||
v.run_ocr(lang, time_start, time_end, use_fullframe)
|
v.run_ocr(lang, time_start, time_end, conf_threshold, use_fullframe)
|
||||||
return v.get_subtitles()
|
return v.get_subtitles()
|
||||||
|
|
||||||
|
|
||||||
def save_subtitles_to_file(
|
def save_subtitles_to_file(
|
||||||
video_path: str, file_path='subtitle.srt', lang='eng',
|
video_path: str, file_path='subtitle.srt', lang='eng',
|
||||||
time_start='0:00', time_end='', use_fullframe=False) -> None:
|
time_start='0:00', time_end='', conf_threshold=65,
|
||||||
|
use_fullframe=False) -> None:
|
||||||
with open(file_path, 'w+') as f:
|
with open(file_path, 'w+') as f:
|
||||||
f.write(get_subtitles(
|
f.write(get_subtitles(
|
||||||
video_path, lang, time_start, time_end, use_fullframe))
|
video_path, lang, time_start, time_end, conf_threshold,
|
||||||
|
use_fullframe))
|
||||||
|
@ -17,7 +17,7 @@ class PredictedFrame:
|
|||||||
confidence: int # total confidence of all words
|
confidence: int # total confidence of all words
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
def __init__(self, index: int, pred_data: str, conf_threshold=70):
|
def __init__(self, index: int, pred_data: str, conf_threshold: int):
|
||||||
self.index = index
|
self.index = index
|
||||||
self.words = []
|
self.words = []
|
||||||
|
|
||||||
|
@ -14,18 +14,22 @@ class Video:
|
|||||||
use_fullframe: bool
|
use_fullframe: bool
|
||||||
num_frames: int
|
num_frames: int
|
||||||
fps: float
|
fps: float
|
||||||
|
height: int
|
||||||
pred_frames: List[PredictedFrame]
|
pred_frames: List[PredictedFrame]
|
||||||
pred_subs: List[PredictedSubtitle]
|
pred_subs: List[PredictedSubtitle]
|
||||||
|
|
||||||
def __init__(self, path: str):
|
def __init__(self, path: str):
|
||||||
self.path = path
|
self.path = path
|
||||||
v = cv2.VideoCapture(path)
|
v = cv2.VideoCapture(path)
|
||||||
|
if not v.isOpened():
|
||||||
|
raise IOError('can not open video format {}'.format(path))
|
||||||
self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
|
self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
self.fps = v.get(cv2.CAP_PROP_FPS)
|
self.fps = v.get(cv2.CAP_PROP_FPS)
|
||||||
|
self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||||
v.release()
|
v.release()
|
||||||
|
|
||||||
def run_ocr(self, lang: str, time_start: str, time_end: str,
|
def run_ocr(self, lang: str, time_start: str, time_end: str,
|
||||||
use_fullframe: bool) -> None:
|
conf_threshold: int, use_fullframe: bool) -> None:
|
||||||
self.lang = lang
|
self.lang = lang
|
||||||
self.use_fullframe = use_fullframe
|
self.use_fullframe = use_fullframe
|
||||||
|
|
||||||
@ -44,7 +48,8 @@ class Video:
|
|||||||
# perform ocr to frames in parallel
|
# perform ocr to frames in parallel
|
||||||
with futures.ProcessPoolExecutor() as pool:
|
with futures.ProcessPoolExecutor() as pool:
|
||||||
ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10)
|
ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10)
|
||||||
self.pred_frames = [PredictedFrame(i + ocr_start, data)
|
self.pred_frames = [
|
||||||
|
PredictedFrame(i + ocr_start, data, conf_threshold)
|
||||||
for i, data in enumerate(ocr_map)]
|
for i, data in enumerate(ocr_map)]
|
||||||
|
|
||||||
v.release()
|
v.release()
|
||||||
|
Loading…
Reference in New Issue
Block a user