File size: 3,706 Bytes
13d3de7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from smolagents import Tool
import os
import tempfile
import shutil
import torch
import subprocess
from transcription import run_whisper_transcription
from logging_config import logger, log_buffer
from ffmpeg_setup import ensure_ffmpeg_in_path


class TranscriptTool(Tool):
    """Agent tool that converts a media file to WAV and transcribes it.

    The file is located (optionally by searching ``audio_directory``),
    copied into a temporary directory, converted with ffmpeg to 16 kHz mono
    WAV, and handed to ``run_whisper_transcription``. Failures are reported
    as strings rather than raised, so ``forward`` always returns text.
    """

    # Metadata read by the agent framework (the ``Tool`` base class).
    name = "TranscriptTool"
    description = """
    A smolagent tool for transcribing audio and video files into text. This tool utilises Whisper for transcription 
    and ffmpeg for media conversion, enabling agents to process multimedia inputs into text. The tool supports robust 
    file handling, including format conversion to WAV and dynamic device selection for optimal performance.
    """
    inputs = {
        "file_path": {
            "type": "string",
            "description": "Path to the audio or video file for transcription."
        }
    }
    output_type = "string"

    def __init__(self, audio_directory=None):
        """Initialise the tool.

        Args:
            audio_directory: Directory searched recursively when the path
                given to ``forward`` does not exist. Defaults to the current
                working directory at construction time.
        """
        super().__init__()
        # NOTE(review): presumably makes an ffmpeg binary discoverable via
        # PATH; the helper is defined outside this file - confirm.
        ensure_ffmpeg_in_path()
        self.audio_directory = audio_directory or os.getcwd()

    def locate_audio_file(self, file_name):
        """Return the first path under ``audio_directory`` named ``file_name``.

        Args:
            file_name: Bare file name (no directory part) to look for.

        Returns:
            The joined path of the first match found by ``os.walk``, or
            ``None`` when no file with that name exists in the tree.
        """
        for root, _, files in os.walk(self.audio_directory):
            if file_name in files:
                return os.path.join(root, file_name)
        return None

    def convert_audio_to_wav(self, input_file: str, output_file: str, ffmpeg_path: str) -> str:
        """Convert ``input_file`` to a 16 kHz mono WAV using ffmpeg.

        Args:
            input_file: Path of the media file to convert.
            output_file: Path the WAV file is written to (overwritten if it
                already exists).
            ffmpeg_path: Path of the ffmpeg executable to run.

        Returns:
            ``output_file``, once ffmpeg has exited successfully.

        Raises:
            RuntimeError: If ffmpeg exits with a non-zero status. The
                original ``CalledProcessError`` is chained as the cause.
        """
        logger.info(f"Converting {input_file} to WAV format: {output_file}")
        cmd = [
            ffmpeg_path,
            "-y",  # Overwrite output files without asking
            "-i", input_file,
            "-ar", "16000",  # Set audio sampling rate to 16kHz
            "-ac", "1",      # Set number of audio channels to mono
            output_file
        ]
        try:
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as e:
            # Decode leniently: ffmpeg may echo file names or metadata that
            # are not valid UTF-8, and a UnicodeDecodeError raised here
            # would hide the real conversion failure.
            ffmpeg_error = (e.stderr or b"").decode(errors="replace")
            logger.error(f"ffmpeg error: {ffmpeg_error}")
            raise RuntimeError("Failed to convert audio to WAV.") from e
        logger.info("Audio conversion to WAV completed successfully.")
        return output_file

    def forward(self, file_path: str) -> str:
        """Transcribe the media file at ``file_path``.

        Args:
            file_path: Path to the audio or video file. If it does not point
                to an existing file, its base name is searched for under
                ``audio_directory``.

        Returns:
            The transcription text, or a human-readable error message when
            the file cannot be found, ffmpeg is unavailable, conversion
            fails, or the transcriber produces no output.
        """
        # Reset the shared log buffer so it only holds this run's messages.
        log_buffer.seek(0)
        log_buffer.truncate()

        try:
            # Locate the file if it does not exist
            if not os.path.isfile(file_path):
                file_name = os.path.basename(file_path)
                file_path = self.locate_audio_file(file_name)
                if not file_path:
                    return f"Error: File '{file_name}' not found in '{self.audio_directory}'."

            # Fail fast before doing any file copying.
            ffmpeg_path = shutil.which("ffmpeg")
            if not ffmpeg_path:
                raise RuntimeError("ffmpeg is not accessible in PATH.")

            with tempfile.TemporaryDirectory() as tmpdir:
                # Copy the source into its own sub-directory so that an
                # input literally named "converted_audio.wav" can never
                # collide with the conversion output below (ffmpeg refuses
                # to read and write the same path).
                input_dir = os.path.join(tmpdir, "input")
                os.mkdir(input_dir)
                input_file_path = os.path.join(input_dir, os.path.basename(file_path))
                shutil.copy(file_path, input_file_path)

                # Convert to wav
                wav_file_path = os.path.join(tmpdir, "converted_audio.wav")
                self.convert_audio_to_wav(input_file_path, wav_file_path, ffmpeg_path)

                device = "cuda" if torch.cuda.is_available() else "cpu"

                # Transcribe audio
                # NOTE(review): only the first item yielded by the generator
                # is returned - confirm that run_whisper_transcription
                # yields the complete transcript in its first item.
                transcription_generator = run_whisper_transcription(wav_file_path, device)
                for transcription, _ in transcription_generator:
                    return transcription

            # The generator yielded nothing; return an explicit message
            # instead of falling through and returning None.
            logger.error("Transcription produced no output.")
            return "An error occurred: transcription produced no output."

        except Exception as e:
            logger.error(f"Error in transcription: {str(e)}")
            return f"An error occurred: {str(e)}"