Coco-18 committed
Commit e085921 · verified · 1 Parent(s): 168acfa

Update app.py

Files changed (1):
  1. app.py +99 -49
app.py CHANGED
@@ -4,100 +4,150 @@ import torchaudio
 import soundfile as sf
 from flask import Flask, request, jsonify, send_file
 from flask_cors import CORS
-from transformers import VitsModel, AutoTokenizer
+from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
 
-# Set ALL cache directories to /tmp (writable in Hugging Face Spaces)
+# Set cache directories
 os.environ["HF_HOME"] = "/tmp/hf_home"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
 os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
 os.environ["TORCH_HOME"] = "/tmp/torch_home"
 
 app = Flask(__name__)
-CORS(app)  # Allow external requests
+CORS(app)
+
+# ASR Model (facebook/mms-1b-all)
+ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
+asr_processor = AutoProcessor.from_pretrained(ASR_MODEL_ID)
+asr_model = Wav2Vec2ForCTC.from_pretrained(ASR_MODEL_ID)
+
+# Language-specific configurations
+LANGUAGE_CODES = {
+    "kapampangan": "pam",
+    "tagalog": "tgl",
+    "english": "eng"
+}
 
-# Model paths for different languages (Hugging Face Hub)
-MODELS = {
+# TTS Models (Kapampangan, Tagalog, English)
+TTS_MODELS = {
     "kapampangan": "facebook/mms-tts-pam",
     "tagalog": "facebook/mms-tts-tgl",
    "english": "facebook/mms-tts-eng"
 }
 
-loaded_models = {}
-loaded_processors = {}
-
-for lang, path in MODELS.items():
+tts_models = {}
+tts_processors = {}
+for lang, model_id in TTS_MODELS.items():
     try:
-        print(f"🔄 Loading {lang} model: {path}...")
-        # Force models to save in /tmp
-        loaded_models[lang] = VitsModel.from_pretrained(path, cache_dir="/tmp/huggingface_cache")
-        loaded_processors[lang] = AutoTokenizer.from_pretrained(path, cache_dir="/tmp/huggingface_cache")
-        print(f"✅ {lang.capitalize()} model loaded successfully!")
+        tts_models[lang] = VitsModel.from_pretrained(model_id, cache_dir="/tmp/huggingface_cache")
+        tts_processors[lang] = AutoTokenizer.from_pretrained(model_id, cache_dir="/tmp/huggingface_cache")
+        print(f"✅ TTS Model loaded: {lang}")
     except Exception as e:
-        print(f"❌ Error loading {lang} model: {str(e)}")
-        loaded_models[lang] = None  # Mark as unavailable
-        loaded_processors[lang] = None
 
+        print(f"❌ Error loading {lang} TTS model: {e}")
+        tts_models[lang] = None
 
 # Constants
+SAMPLE_RATE = 16000
 OUTPUT_DIR = "/tmp/"
 os.makedirs(OUTPUT_DIR, exist_ok=True)
 
+
 @app.route("/", methods=["GET"])
 def home():
-    """Root route to check if the API is running"""
-    return jsonify({"message": "TTS API is running. Use /tts to generate speech."})
+    return jsonify({"message": "Speech API is running."})
+
+
+@app.route("/asr", methods=["POST"])
+def transcribe_audio():
+    try:
+        if "audio" not in request.files:
+            return jsonify({"error": "No audio file uploaded"}), 400
+
+        audio_file = request.files["audio"]
+        language = request.form.get("language", "english").lower()
+
+        # Validate language
+        if language not in LANGUAGE_CODES:
+            return jsonify({"error": f"Unsupported language: {language}"}), 400
+
+        # Get the language code for the ASR model
+        lang_code = LANGUAGE_CODES[language]
+
+        # Save audio file temporarily
+        audio_path = os.path.join(OUTPUT_DIR, "input_audio.wav")
+        audio_file.save(audio_path)
+
+        # Load and process audio
+        try:
+            waveform, sr = torchaudio.load(audio_path)
+            if sr != SAMPLE_RATE:
+                waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
+
+            # Normalize audio (recommended for Wav2Vec2)
+            waveform = waveform / torch.max(torch.abs(waveform))
+
+            # Process audio for ASR
+            inputs = asr_processor(
+                waveform.squeeze().numpy(),
+                sampling_rate=SAMPLE_RATE,
+                return_tensors="pt",
+                language=lang_code  # Set the language code
+            )
+        except Exception as e:
+            return jsonify({"error": f"Error processing audio: {str(e)}"}), 400
+
+        # Transcribe
+        with torch.no_grad():
+            logits = asr_model(**inputs).logits
+            ids = torch.argmax(logits, dim=-1)[0]
+            transcription = asr_processor.decode(ids)
+
+        # Log the transcription
+        print(f"Transcription ({language}): {transcription}")
+
+        return jsonify({"transcription": transcription})
+    except Exception as e:
+        print(f"ASR error: {str(e)}")
+        return jsonify({"error": f"ASR failed: {str(e)}"}), 500
+
 
 @app.route("/tts", methods=["POST"])
 def generate_tts():
-    """API endpoint to generate TTS audio"""
     try:
-        # Get request data
         data = request.get_json()
         text_input = data.get("text", "").strip()
         language = data.get("language", "kapampangan").lower()
 
-        # Validate inputs
-        if language not in MODELS:
-            return jsonify({"error": "Invalid language. Choose 'kapampangan', 'tagalog', or 'english'."}), 400
+        if language not in TTS_MODELS:
+            return jsonify({"error": "Invalid language"}), 400
         if not text_input:
             return jsonify({"error": "No text provided"}), 400
-        if loaded_models[language] is None:
-            return jsonify({"error": f"Model for {language} failed to load"}), 500
-
-        print(f"🔄 Generating speech for '{text_input}' in {language}...")
+        if tts_models[language] is None:
+            return jsonify({"error": "TTS model not available"}), 500
 
-        # Process text input
-        processor = loaded_processors[language]
-        model = loaded_models[language]
+        processor = tts_processors[language]
+        model = tts_models[language]
         inputs = processor(text_input, return_tensors="pt")
 
-        # Generate speech - using model(**inputs) instead of model.generate()
         with torch.no_grad():
-            output = model(**inputs).waveform
-        waveform = output.squeeze().cpu().numpy()
-
-        # Save to file
-        output_filename = os.path.join(OUTPUT_DIR, f"{language}_output.wav")
-        # Use the model's sampling rate
-        sampling_rate = model.config.sampling_rate
-        sf.write(output_filename, waveform, sampling_rate)
-        print(f"✅ Speech generated! File saved: {output_filename}")
-
-        return jsonify({
-            "message": "TTS audio generated",
-            "file_url": f"/download/{language}_output.wav"
-        })
+            output = model.generate(**inputs)
+
+        waveform = output.cpu().numpy().flatten()
+        output_filename = os.path.join(OUTPUT_DIR, f"{language}_tts.wav")
+        sf.write(output_filename, waveform, SAMPLE_RATE)
+
+        return jsonify({"file_url": f"/download/{language}_tts.wav"})
     except Exception as e:
-        print(f"❌ Error generating TTS: {e}")
-        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
+        return jsonify({"error": f"TTS failed: {e}"}), 500
+
 
 @app.route("/download/<filename>", methods=["GET"])
 def download_audio(filename):
-    """Serve generated audio files"""
     file_path = os.path.join(OUTPUT_DIR, filename)
     if os.path.exists(file_path):
         return send_file(file_path, mimetype="audio/wav", as_attachment=True)
     return jsonify({"error": "File not found"}), 404
 
+
 if __name__ == "__main__":
     app.run(host="0.0.0.0", port=7860, debug=True)
 
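
For reference, here is a minimal client sketch for the routes this commit touches. It is an illustration only, not part of the repository: it assumes the Space is reachable at http://localhost:7860 (the port passed to app.run above), that the requests package is installed, and that sample.wav and tts_output.wav are placeholder file names.

# Hypothetical client for the /asr, /tts, and /download routes (illustration only).
import requests

BASE_URL = "http://localhost:7860"  # assumed host/port for the running app

# /asr expects a multipart "audio" file plus a "language" form field.
with open("sample.wav", "rb") as f:  # placeholder audio file
    asr_resp = requests.post(f"{BASE_URL}/asr", files={"audio": f}, data={"language": "tagalog"})
print(asr_resp.json())  # {"transcription": "..."} on success

# /tts expects JSON with "text" and "language" and returns a file_url.
tts_resp = requests.post(f"{BASE_URL}/tts", json={"text": "Kumusta", "language": "kapampangan"})
file_url = tts_resp.json().get("file_url")

# /download/<filename> streams the generated WAV as an attachment.
if file_url:
    audio = requests.get(f"{BASE_URL}{file_url}")
    with open("tts_output.wav", "wb") as out:
        out.write(audio.content)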