Spaces:
Build error
Build error
pengdaqian
commited on
Commit
·
ca13e69
1
Parent(s):
f0c5f90
add more
Browse files- audio_to_text.py +2 -1
audio_to_text.py
CHANGED
|
@@ -10,6 +10,7 @@ from transformers import pipeline
|
|
| 10 |
|
| 11 |
class AudioPipeline(object):
|
| 12 |
def __init__(self, audio_text_path, audio_text_embeddings_path):
|
|
|
|
| 13 |
self.model = laion_clap.CLAP_Module(enable_fusion=False)
|
| 14 |
self.model.load_ckpt() # download the default pretrained checkpoint.
|
| 15 |
self.audio_text_path = audio_text_path
|
|
@@ -39,7 +40,7 @@ class AudioPipeline(object):
|
|
| 39 |
texts = json.load(f)
|
| 40 |
|
| 41 |
tensors = {}
|
| 42 |
-
with safe_open(self.audio_text_embeddings_path, framework="pt", device=
|
| 43 |
for k in f.keys():
|
| 44 |
tensors[k] = f.get_tensor(k)
|
| 45 |
text_embed = tensors["text_embed"]
|
|
|
|
| 10 |
|
| 11 |
class AudioPipeline(object):
|
| 12 |
def __init__(self, audio_text_path, audio_text_embeddings_path):
|
| 13 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 14 |
self.model = laion_clap.CLAP_Module(enable_fusion=False)
|
| 15 |
self.model.load_ckpt() # download the default pretrained checkpoint.
|
| 16 |
self.audio_text_path = audio_text_path
|
|
|
|
| 40 |
texts = json.load(f)
|
| 41 |
|
| 42 |
tensors = {}
|
| 43 |
+
with safe_open(self.audio_text_embeddings_path, framework="pt", device=self.device) as f:
|
| 44 |
for k in f.keys():
|
| 45 |
tensors[k] = f.get_tensor(k)
|
| 46 |
text_embed = tensors["text_embed"]
|