add api definition
This commit is contained in:
parent
bc84ee39ff
commit
bd6f15978b
17
videocr/api.py
Normal file
17
videocr/api.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
|
||||||
|
from .video import Video
|
||||||
|
|
||||||
|
|
||||||
|
def get_subtitles(
        video_path: str, lang='eng', time_start='0:00', time_end='',
        use_fullframe=False) -> str:
    """Run OCR over a video and return the extracted subtitles as a string.

    Args:
        video_path: Path of the video file handed to ``Video``.
        lang: OCR language passed through to ``Video.run_ocr``
            (presumably a Tesseract language code -- confirm in video.py).
        time_start: Start time string forwarded to ``Video.run_ocr``.
        time_end: End time string forwarded to ``Video.run_ocr``.
        use_fullframe: Forwarded unchanged to ``Video.run_ocr``.

    Returns:
        Whatever ``Video.get_subtitles()`` produces after the OCR pass.
    """
    video = Video(video_path)
    # Arguments are positional and must follow run_ocr's parameter order.
    video.run_ocr(lang, time_start, time_end, use_fullframe)
    subtitles = video.get_subtitles()
    return subtitles
|
||||||
|
|
||||||
|
|
||||||
|
def save_subtitles_to_file(
        video_path: str, file_path='subtitle.srt', lang='eng',
        time_start='0:00', time_end='', use_fullframe=False) -> None:
    """Run OCR over a video and write the extracted subtitles to a file.

    Args:
        video_path: Path of the video file to process.
        file_path: Destination file; it is created or overwritten.
        lang: OCR language, forwarded to ``get_subtitles``.
        time_start: Start time string, forwarded to ``get_subtitles``.
        time_end: End time string, forwarded to ``get_subtitles``.
        use_fullframe: Forwarded unchanged to ``get_subtitles``.

    Raises:
        Any exception raised by ``get_subtitles`` propagates before the
        output file is touched; ``OSError`` if the file cannot be written.
    """
    # Generate the text first: if OCR fails, an existing file at file_path
    # is not truncated and no empty file is left behind.
    subtitles = get_subtitles(
        video_path, lang, time_start, time_end, use_fullframe)

    # Explicit UTF-8 so non-ASCII subtitle text does not depend on the
    # platform's default encoding. Write-only mode: nothing is read back.
    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(subtitles)
|
@ -48,9 +48,8 @@ class PredictedFrame:
|
|||||||
|
|
||||||
self.text = ' '.join(word.text for word in self.words)
|
self.text = ' '.join(word.text for word in self.words)
|
||||||
# remove chars that are obviously ocr errors
|
# remove chars that are obviously ocr errors
|
||||||
translate_table = {ord(c): None for c in '<>{}[];`@#$%^*_=~\\'}
|
table = str.maketrans('|', 'I', '<>{}[];`@#$%^*_=~\\')
|
||||||
translate_table[ord('|')] = 'I'
|
self.text = self.text.translate(table).replace(' \n ', '\n').strip()
|
||||||
self.text = self.text.translate(translate_table).strip()
|
|
||||||
|
|
||||||
def is_similar_to(self, other: PredictedFrame, threshold=70) -> bool:
|
def is_similar_to(self, other: PredictedFrame, threshold=70) -> bool:
|
||||||
return fuzz.ratio(self.text, other.text) >= threshold
|
return fuzz.ratio(self.text, other.text) >= threshold
|
||||||
@ -58,14 +57,13 @@ class PredictedFrame:
|
|||||||
|
|
||||||
class PredictedSubtitle:
|
class PredictedSubtitle:
|
||||||
frames: List[PredictedFrame]
|
frames: List[PredictedFrame]
|
||||||
|
text: str
|
||||||
|
|
||||||
def __init__(self, frames: List[PredictedFrame]):
|
def __init__(self, frames: List[PredictedFrame]):
|
||||||
self.frames = [f for f in frames if f.confidence > 0]
|
self.frames = [f for f in frames if f.confidence > 0]
|
||||||
|
|
||||||
if self.frames:
|
if self.frames:
|
||||||
conf_max = max(f.confidence for f in self.frames)
|
self.text = max(self.frames, key=lambda f: f.confidence).text
|
||||||
self.text = next(f.text for f in self.frames
|
|
||||||
if f.confidence == conf_max)
|
|
||||||
else:
|
else:
|
||||||
self.text = ''
|
self.text = ''
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@ from concurrent import futures
|
|||||||
import datetime
|
import datetime
|
||||||
import pytesseract
|
import pytesseract
|
||||||
import cv2
|
import cv2
|
||||||
import timeit
|
|
||||||
|
|
||||||
from .models import PredictedFrame, PredictedSubtitle
|
from .models import PredictedFrame, PredictedSubtitle
|
||||||
|
|
||||||
@ -24,19 +23,16 @@ class Video:
|
|||||||
self.fps = v.get(cv2.CAP_PROP_FPS)
|
self.fps = v.get(cv2.CAP_PROP_FPS)
|
||||||
v.release()
|
v.release()
|
||||||
|
|
||||||
def run_ocr(self, lang: str, use_fullframe=False,
|
def run_ocr(self, lang: str, time_start: str, time_end: str,
|
||||||
time_start='0:00', time_end='') -> None:
|
use_fullframe: bool) -> None:
|
||||||
self.lang = lang
|
self.lang = lang
|
||||||
self.use_fullframe = use_fullframe
|
self.use_fullframe = use_fullframe
|
||||||
|
|
||||||
ocr_start = self._frame_index(time_start)
|
ocr_start = self._frame_index(time_start) if time_start else 0
|
||||||
|
ocr_end = self._frame_index(time_end) if time_end else self.num_frames
|
||||||
|
|
||||||
if time_end:
|
if ocr_end < ocr_start:
|
||||||
ocr_end = self._frame_index(time_end)
|
raise ValueError('time_start is later than time_end')
|
||||||
if ocr_end < ocr_start:
|
|
||||||
raise ValueError('time_start is later than time_end')
|
|
||||||
else:
|
|
||||||
ocr_end = self.num_frames
|
|
||||||
num_ocr_frames = ocr_end - ocr_start
|
num_ocr_frames = ocr_end - ocr_start
|
||||||
|
|
||||||
# get frames from ocr_start to ocr_end
|
# get frames from ocr_start to ocr_end
|
||||||
@ -55,7 +51,7 @@ class Video:
|
|||||||
# convert time str to frame index
|
# convert time str to frame index
|
||||||
def _frame_index(self, time: str) -> int:
|
def _frame_index(self, time: str) -> int:
|
||||||
t = time.split(':')
|
t = time.split(':')
|
||||||
t = list(map(int, t))
|
t = list(map(float, t))
|
||||||
if len(t) == 3:
|
if len(t) == 3:
|
||||||
td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
|
td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
|
||||||
elif len(t) == 2:
|
elif len(t) == 2:
|
||||||
@ -139,19 +135,3 @@ class Video:
|
|||||||
m, s = divmod(td.seconds, 60)
|
m, s = divmod(td.seconds, 60)
|
||||||
h, m = divmod(m, 60)
|
h, m = divmod(m, 60)
|
||||||
return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)
|
return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)
|
||||||
|
|
||||||
def save_subtitles_to_file(self, path='subtitle.srt') -> None:
|
|
||||||
with open(path, 'w+') as f:
|
|
||||||
f.write(self.get_subtitles())
|
|
||||||
|
|
||||||
|
|
||||||
time_start = timeit.default_timer()
|
|
||||||
v = Video('1.mp4', 'HanS')
|
|
||||||
v.run_ocr()
|
|
||||||
time_stop = timeit.default_timer()
|
|
||||||
print('time for ocr: ', time_stop - time_start)
|
|
||||||
|
|
||||||
time_start = timeit.default_timer()
|
|
||||||
v.save_subtitles_to_file()
|
|
||||||
time_stop = timeit.default_timer()
|
|
||||||
print('time for save sub: ', time_stop - time_start)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user