<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# Automatic speech recognition[[automatic-speech-recognition]]

[[open-in-colab]]

<Youtube id="TksaY_FDgnk"/>

Automatic speech recognition (ASR) converts a speech signal to text, mapping a sequence of audio inputs to text outputs.
Virtual assistants like Siri and Alexa use ASR models to help users every day, and there are many other useful user-facing applications, such as live captioning and note-taking during meetings.

This guide will show you how to:

1. Fine-tune [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base) on the [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) dataset to transcribe audio to text.
2. Use your fine-tuned model for inference.
<Tip>
The task illustrated in this tutorial is supported by the following model architectures:

<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->

[Data2VecAudio](../model_doc/data2vec-audio), [Hubert](../model_doc/hubert), [M-CTC-T](../model_doc/mctct), [SEW](../model_doc/sew), [SEW-D](../model_doc/sew-d), [UniSpeech](../model_doc/unispeech), [UniSpeechSat](../model_doc/unispeech-sat), [Wav2Vec2](../model_doc/wav2vec2), [Wav2Vec2-Conformer](../model_doc/wav2vec2-conformer), [WavLM](../model_doc/wavlm)

<!--End of the generated tip-->

</Tip>
Before you begin, make sure you have all the necessary libraries installed:

```bash
pip install transformers datasets evaluate jiwer
```

Logging in to your Hugging Face account lets you upload your model and share it with the community. Enter your token to log in:

```py
>>> from huggingface_hub import notebook_login

>>> notebook_login()
```
## Load the MInDS-14 dataset[[load-minds-14-dataset]]

Start by loading a small subset of the [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) dataset from the 🤗 Datasets library.
This gives you a chance to experiment and make sure everything works before spending more time training on the full dataset.

```py
>>> from datasets import load_dataset, Audio

>>> minds = load_dataset("PolyAI/minds14", name="en-US", split="train[:100]")
```

Split the dataset's `train` split into a train set and a test set with the [`~Dataset.train_test_split`] method:

```py
>>> minds = minds.train_test_split(test_size=0.2)
```
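To make the effect of `test_size=0.2` concrete, here is a minimal pure-Python sketch of a shuffled split over row indices. The helper `split_indices` and its rounding rule are illustrative assumptions, not the 🤗 Datasets implementation:

```python
import random


def split_indices(n_rows, test_size, seed=0):
    """Shuffle row indices and hold out a fraction of them for testing.

    Hypothetical helper for illustration; the rounding here is an assumption.
    """
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)
    n_test = round(n_rows * test_size)
    # The first n_test shuffled indices form the test set, the rest the train set
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = split_indices(100, 0.2)
print(len(train_idx), len(test_idx))  # 80 20
```

With the 100 rows loaded above, a 0.2 test fraction leaves 80 rows for training and 20 for evaluation.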
Then take a look at the dataset:

```py
>>> minds
DatasetDict({
    train: Dataset({
        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
        num_rows: 80
    })
    test: Dataset({
        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
        num_rows: 20
    })
})
```
The dataset contains a lot of useful information, like `lang_id` and `english_transcription`, but this guide focuses on `audio` and `transcription`. Remove the other columns with the [`~datasets.Dataset.remove_columns`] method:

```py
>>> minds = minds.remove_columns(["english_transcription", "intent_class", "lang_id"])
```
Take a look at an example again:

```py
>>> minds["train"][0]
{'audio': {'array': array([-0.00024414,  0.        ,  0.        , ...,  0.00024414,
         0.00024414,  0.00024414], dtype=float32),
  'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
  'sampling_rate': 8000},
 'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
 'transcription': "hi I'm trying to use the banking app on my phone and currently my checking and savings account balance is not refreshing"}
```

There are two fields:

- `audio`: a 1-dimensional `array` of the speech signal that must be called to load and resample the audio file.
- `transcription`: the target text.
## Preprocess[[preprocess]]

The next step is to load a Wav2Vec2 processor to process the audio signal:

```py
>>> from transformers import AutoProcessor

>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")
```
The MInDS-14 dataset has a sampling rate of 8000Hz (you can find this in its [dataset card](https://huggingface.co/datasets/PolyAI/minds14)), so you need to resample the dataset to 16000Hz to use the pretrained Wav2Vec2 model:

```py
>>> minds = minds.cast_column("audio", Audio(sampling_rate=16_000))
>>> minds["train"][0]
{'audio': {'array': array([-2.38064706e-04, -1.58618059e-04, -5.43987835e-06, ...,
          2.78103951e-04,  2.38446111e-04,  1.18740834e-04], dtype=float32),
  'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
  'sampling_rate': 16000},
 'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',
 'transcription': "hi I'm trying to use the banking app on my phone and currently my checking and savings account balance is not refreshing"}
```
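Resampling changes how many samples represent the same stretch of audio: going from 8000Hz to 16000Hz doubles the array length. The sketch below illustrates this with plain linear interpolation. It is only an illustration under that assumption; the resampler used behind `Audio(sampling_rate=...)` is more sophisticated, and `resample_linear` is a hypothetical helper:

```python
def resample_linear(samples, orig_rate, target_rate):
    """Resample a 1-D list of floats by linear interpolation (illustration only)."""
    if not samples:
        return []
    n_out = int(len(samples) * target_rate / orig_rate)
    step = orig_rate / target_rate  # spacing of output samples, in input-sample units
    out = []
    for i in range(n_out):
        pos = i * step
        left = int(pos)
        right = min(left + 1, len(samples) - 1)  # clamp at the last input sample
        frac = pos - left
        out.append(samples[left] * (1 - frac) + samples[right] * frac)
    return out


one_second_8k = [0.0] * 8000
one_second_16k = resample_linear(one_second_8k, 8000, 16000)
print(len(one_second_16k))  # 16000
```

One second of audio stays one second long; only the number of samples per second changes.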
As you can see in the `transcription` above, the text contains a mix of uppercase and lowercase characters. The Wav2Vec2 tokenizer is only trained on uppercase characters, so you need to make sure the text matches the tokenizer's vocabulary:

```py
>>> def uppercase(example):
...     return {"transcription": example["transcription"].upper()}


>>> minds = minds.map(uppercase)
```
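A quick way to see why the uppercasing matters is to list the characters that fall outside the vocabulary. The sketch below uses a hypothetical character set (uppercase letters, apostrophe, and space) as a stand-in; the real vocabulary lives in the tokenizer and should be inspected there:

```python
import string

# Assumed stand-in for the tokenizer's character vocabulary (not read from the model)
ASSUMED_VOCAB = set(string.ascii_uppercase) | {"'", " "}


def out_of_vocab_chars(text, vocab=ASSUMED_VOCAB):
    """Return the sorted characters of `text` that are missing from `vocab`."""
    return sorted(set(text) - vocab)


print(out_of_vocab_chars("hi I'm"))          # ['h', 'i', 'm']
print(out_of_vocab_chars("hi I'm".upper()))  # []
```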
Now create a preprocessing function that:

1. Calls the `audio` column to load and resample the audio file.
2. Extracts the `input_values` from the audio file and tokenizes the `transcription` column with the processor.

```py
>>> def prepare_dataset(batch):
...     audio = batch["audio"]
...     batch = processor(audio["array"], sampling_rate=audio["sampling_rate"], text=batch["transcription"])
...     batch["input_length"] = len(batch["input_values"][0])
...     return batch
```
To apply the preprocessing function over the entire dataset, use the 🤗 Datasets [`~datasets.Dataset.map`] function. You can speed up `map` by increasing the number of processes with the `num_proc` parameter. Remove the columns you don't need with the [`~datasets.Dataset.remove_columns`] method:

```py
>>> encoded_minds = minds.map(prepare_dataset, remove_columns=minds.column_names["train"], num_proc=4)
```
🤗 Transformers doesn't have a data collator for automatic speech recognition, so you need to adapt [`DataCollatorWithPadding`] to create a batch of examples. The collator dynamically pads your text and labels to the length of the longest element in each batch so they are a uniform length. You could pad your text in the `tokenizer` function by setting `padding=True`, but dynamic padding is more efficient.

Unlike other data collators, this specific data collator needs to apply a different padding method to `input_values` and `labels`:

```py
>>> import torch

>>> from dataclasses import dataclass, field
>>> from typing import Any, Dict, List, Optional, Union


>>> @dataclass
... class DataCollatorCTCWithPadding:
...     processor: AutoProcessor
...     padding: Union[bool, str] = "longest"
...
...     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
...         # split inputs and labels since they have different lengths
...         # and need different padding methods
...         input_features = [{"input_values": feature["input_values"][0]} for feature in features]
...         label_features = [{"input_ids": feature["labels"]} for feature in features]
...
...         batch = self.processor.pad(input_features, padding=self.padding, return_tensors="pt")
...
...         labels_batch = self.processor.pad(labels=label_features, padding=self.padding, return_tensors="pt")
...
...         # replace padding with -100 so it is ignored by the loss
...         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
...
...         batch["labels"] = labels
...
...         return batch
```
Now instantiate your `DataCollatorCTCWithPadding`:

```py
>>> data_collator = DataCollatorCTCWithPadding(processor=processor, padding="longest")
```
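The idea behind the collator can be sketched without tensors: pad the audio inputs with a neutral value, and pad the labels with `-100` so the padded positions don't contribute to the loss. This is a simplified sketch, not the `processor.pad` implementation; `pad_batch` and the audio padding value of `0.0` are assumptions for illustration:

```python
PAD_VALUE = 0.0      # assumed padding value for audio samples
IGNORE_INDEX = -100  # label value the loss is expected to skip


def pad_batch(features):
    """Pad inputs and labels separately to the longest item in the batch."""
    max_in = max(len(f["input_values"]) for f in features)
    max_lab = max(len(f["labels"]) for f in features)
    input_values, attention_mask, labels = [], [], []
    for f in features:
        n_in, n_lab = len(f["input_values"]), len(f["labels"])
        input_values.append(list(f["input_values"]) + [PAD_VALUE] * (max_in - n_in))
        attention_mask.append([1] * n_in + [0] * (max_in - n_in))
        labels.append(list(f["labels"]) + [IGNORE_INDEX] * (max_lab - n_lab))
    return {"input_values": input_values, "attention_mask": attention_mask, "labels": labels}


batch = pad_batch(
    [
        {"input_values": [0.1, 0.2, 0.3], "labels": [5, 6]},
        {"input_values": [0.4], "labels": [7, 8, 9]},
    ]
)
print(batch["input_values"])  # [[0.1, 0.2, 0.3], [0.4, 0.0, 0.0]]
print(batch["labels"])        # [[5, 6, -100], [7, 8, 9]]
```

Because the padded length is computed per batch, short clips grouped together need far less padding than they would if every example were padded to the longest clip in the dataset.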
## Evaluate[[evaluate]]

Including a metric during training is often helpful for evaluating your model's performance. You can quickly load an evaluation method with the 🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) library.
For this task, load the [word error rate (WER)](https://huggingface.co/spaces/evaluate-metric/wer) metric
(see the 🤗 Evaluate [quick tour](https://huggingface.co/docs/evaluate/a_quick_tour) to learn more about how to load and compute a metric):

```py
>>> import evaluate

>>> wer = evaluate.load("wer")
```
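WER counts the word-level substitutions, deletions, and insertions needed to turn the prediction into the reference, divided by the number of reference words. The pure-Python sketch below shows that definition with a word-level edit distance. It is not the implementation behind `evaluate.load("wer")`, and the example strings are made up for illustration:

```python
def word_error_rate(reference, hypothesis):
    """WER = word-level edit distance / number of reference words."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] holds the edit distance between the processed ref words and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        cur = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            cur.append(min(prev[j] + 1,          # deletion
                           cur[j - 1] + 1,       # insertion
                           prev[j - 1] + cost))  # match or substitution
        prev = cur
    return prev[-1] / len(ref)


# One deleted word ("A") and one substituted word ("ACCOUNT" -> "ACOUNT"): 2 errors / 5 words
print(word_error_rate("SET UP A JOINT ACCOUNT", "SET UP JOINT ACOUNT"))  # 0.4
```

Lower is better: a WER of 0 means the prediction matches the reference word for word.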
Then create a function that passes your predictions and labels to [`~evaluate.EvaluationModule.compute`] to calculate the WER:

```py
>>> import numpy as np


>>> def compute_metrics(pred):
...     pred_logits = pred.predictions
...     pred_ids = np.argmax(pred_logits, axis=-1)
...
...     # restore the pad token where labels were masked with -100
...     pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
...
...     pred_str = processor.batch_decode(pred_ids)
...     label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
...
...     # use a separate name so the global `wer` metric is not shadowed inside the function
...     wer_score = wer.compute(predictions=pred_str, references=label_str)
...
...     return {"wer": wer_score}
```

Your `compute_metrics` function is ready to go now, and you'll return to it when you set up your training.
## Train[[train]]

<frameworkcontent>
<pt>
<Tip>

If you aren't familiar with fine-tuning a model with the [`Trainer`], take a look at the basic tutorial [here](../training#train-with-pytorch-trainer)!

</Tip>

You're ready to start training your model now! Load Wav2Vec2 with [`AutoModelForCTC`]. Specify the reduction to apply to the CTC loss with the `ctc_loss_reduction` parameter. It is often better to use the mean instead of the default sum:

```py
>>> from transformers import AutoModelForCTC, TrainingArguments, Trainer

>>> model = AutoModelForCTC.from_pretrained(
...     "facebook/wav2vec2-base",
...     ctc_loss_reduction="mean",
...     pad_token_id=processor.tokenizer.pad_token_id,
... )
```
At this point, only three steps remain:

1. Define your training hyperparameters in [`TrainingArguments`]. The only required parameter is `output_dir`, which specifies where to save your model. Set `push_to_hub=True` to push the model to the Hub (you need to be signed in to Hugging Face to upload your model). With the settings below, the [`Trainer`] evaluates the WER and saves a training checkpoint every 1000 steps.
2. Pass the training arguments to [`Trainer`] along with the model, dataset, tokenizer, data collator, and `compute_metrics` function.
3. Call [`~Trainer.train`] to fine-tune your model.

```py
>>> training_args = TrainingArguments(
...     output_dir="my_awesome_asr_mind_model",
...     per_device_train_batch_size=8,
...     gradient_accumulation_steps=2,
...     learning_rate=1e-5,
...     warmup_steps=500,
...     max_steps=2000,
...     gradient_checkpointing=True,
...     fp16=True,
...     group_by_length=True,
...     evaluation_strategy="steps",
...     per_device_eval_batch_size=8,
...     save_steps=1000,
...     eval_steps=1000,
...     logging_steps=25,
...     load_best_model_at_end=True,
...     metric_for_best_model="wer",
...     greater_is_better=False,
...     push_to_hub=True,
... )

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=encoded_minds["train"],
...     eval_dataset=encoded_minds["test"],
...     tokenizer=processor.feature_extractor,
...     data_collator=data_collator,
...     compute_metrics=compute_metrics,
... )

>>> trainer.train()
```
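Two of these hyperparameters interact in ways that are easy to misread. The sketch below works through the arithmetic: the effective batch size per optimizer step, and the learning rate at a given step. The schedule shape (linear warmup followed by linear decay to zero) is an assumption here; check the `lr_scheduler_type` of your installed version before relying on it:

```python
PER_DEVICE_BATCH = 8
GRAD_ACCUM = 2
BASE_LR = 1e-5
WARMUP_STEPS = 500
MAX_STEPS = 2000


def effective_batch_size(num_devices=1):
    """Number of examples that contribute to one optimizer step."""
    return PER_DEVICE_BATCH * GRAD_ACCUM * num_devices


def lr_at(step):
    """Learning rate under an assumed linear warmup followed by linear decay."""
    if step < WARMUP_STEPS:
        return BASE_LR * step / WARMUP_STEPS
    return BASE_LR * max(0.0, (MAX_STEPS - step) / (MAX_STEPS - WARMUP_STEPS))


print(effective_batch_size())  # 16
print(lr_at(500))              # 1e-05
print(lr_at(2000))             # 0.0
```

On a single device, each optimizer step therefore sees 16 examples, and the learning rate peaks at step 500 before decaying.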
Once training is completed, share your model to the Hub with the [`~transformers.Trainer.push_to_hub`] method so everyone can use your model:

```py
>>> trainer.push_to_hub()
```
</pt>
</frameworkcontent>

<Tip>

For a more in-depth example of how to fine-tune a model for automatic speech recognition, take a look at this [blog post](https://huggingface.co/blog/fine-tune-wav2vec2-english) for English ASR and this [post](https://huggingface.co/blog/fine-tune-xlsr-wav2vec2) for multilingual ASR.

</Tip>
## Inference[[inference]]

Great, now that you've fine-tuned a model, you can use it for inference!

Load an audio file you'd like to run inference on. Remember to resample the audio file to match the sampling rate of the model if you need to!

```py
>>> from datasets import load_dataset, Audio

>>> dataset = load_dataset("PolyAI/minds14", "en-US", split="train")
>>> dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> audio_file = dataset[0]["audio"]["path"]
```
The simplest way to try out your fine-tuned model for inference is to use it in a [`pipeline`]. Instantiate a `pipeline` for automatic speech recognition with your model, and pass your audio file to it:

```py
>>> from transformers import pipeline

>>> transcriber = pipeline("automatic-speech-recognition", model="stevhliu/my_awesome_asr_minds_model")
>>> transcriber(audio_file)
{'text': 'I WOUD LIKE O SET UP JOINT ACOUNT WTH Y PARTNER'}
```

<Tip>

The transcription is decent, but it could be better! Try fine-tuning your model on more examples to get even better results!

</Tip>

You can also manually replicate the results of the `pipeline` if you'd like:
<frameworkcontent>
<pt>
Load a processor to preprocess the audio file and transcription and return the `input` as PyTorch tensors:

```py
>>> from transformers import AutoProcessor

>>> processor = AutoProcessor.from_pretrained("stevhliu/my_awesome_asr_mind_model")
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
```
Pass your inputs to the model and return the logits:

```py
>>> import torch
>>> from transformers import AutoModelForCTC

>>> model = AutoModelForCTC.from_pretrained("stevhliu/my_awesome_asr_mind_model")
>>> with torch.no_grad():
...     logits = model(**inputs).logits
```
Get the predicted `input_ids` with the highest probability, and use the processor to decode the predicted `input_ids` back into text:

```py
>>> import torch

>>> predicted_ids = torch.argmax(logits, dim=-1)
>>> transcription = processor.batch_decode(predicted_ids)
>>> transcription
['I WOUL LIKE O SET UP JOINT ACOUNT WTH Y PARTNER']
```
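Decoding CTC output involves more than a lookup: the model emits one token per audio frame, so consecutive repeats are collapsed and blank tokens are dropped. The sketch below shows this greedy CTC decoding on a toy example. The vocabulary and the choice of id `0` as the blank are hypothetical, and `ctc_greedy_decode` is an illustrative helper rather than the processor's own code:

```python
from itertools import groupby


def ctc_greedy_decode(frame_ids, id_to_char, blank_id):
    """Collapse consecutive repeated ids, then drop blanks and map ids to characters."""
    collapsed = [token_id for token_id, _ in groupby(frame_ids)]
    return "".join(id_to_char[i] for i in collapsed if i != blank_id)


# Hypothetical toy vocabulary; id 0 plays the role of the CTC blank
toy_vocab = {0: "<pad>", 1: "H", 2: "I"}

print(ctc_greedy_decode([0, 1, 1, 0, 2, 2, 2, 0, 0], toy_vocab, blank_id=0))  # HI
print(ctc_greedy_decode([1, 0, 1], toy_vocab, blank_id=0))                    # HH
```

The second call shows why a blank between two identical tokens matters: it keeps a genuine double letter from being merged. This is also why reference labels, which are not frame-level, are decoded with `group_tokens=False` in `compute_metrics`.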
</pt>
</frameworkcontent>