JustinLin610
committed on

Commit 10b0761 · 1 parent: 7f92696

update

Browse files. This view is limited to 50 files because it contains too many changes. See raw diff.
- fairseq/.github/ISSUE_TEMPLATE.md +3 -0
- fairseq/.github/ISSUE_TEMPLATE/bug_report.md +43 -0
- fairseq/.github/ISSUE_TEMPLATE/documentation.md +15 -0
- fairseq/.github/ISSUE_TEMPLATE/feature_request.md +24 -0
- fairseq/.github/ISSUE_TEMPLATE/how-to-question.md +33 -0
- fairseq/.github/PULL_REQUEST_TEMPLATE.md +16 -0
- fairseq/.github/stale.yml +30 -0
- fairseq/.github/workflows/build.yml +55 -0
- fairseq/.github/workflows/build_wheels.yml +41 -0
- fairseq/.gitmodules +4 -0
- fairseq/CODE_OF_CONDUCT.md +77 -0
- fairseq/CONTRIBUTING.md +28 -0
- fairseq/LICENSE +21 -0
- fairseq/README.md +229 -0
- fairseq/examples/__init__.py +9 -0
- fairseq/examples/adaptive_span/README.md +90 -0
- fairseq/examples/adaptive_span/__init__.py +19 -0
- fairseq/examples/adaptive_span/adagrad_with_grad_clip.py +128 -0
- fairseq/examples/adaptive_span/adaptive_span_attention.py +160 -0
- fairseq/examples/adaptive_span/adaptive_span_loss.py +106 -0
- fairseq/examples/adaptive_span/adaptive_span_model.py +263 -0
- fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py +145 -0
- fairseq/examples/adaptive_span/truncated_bptt_lm_task.py +281 -0
- fairseq/examples/backtranslation/README.md +297 -0
- fairseq/examples/backtranslation/deduplicate_lines.py +41 -0
- fairseq/examples/backtranslation/extract_bt_data.py +72 -0
- fairseq/examples/backtranslation/prepare-de-monolingual.sh +98 -0
- fairseq/examples/backtranslation/prepare-wmt18en2de.sh +135 -0
- fairseq/examples/backtranslation/sacrebleu.sh +37 -0
- fairseq/examples/backtranslation/tokenized_bleu.sh +46 -0
- fairseq/examples/bart/README.glue.md +99 -0
- fairseq/examples/bart/README.md +228 -0
- fairseq/examples/bart/README.summarization.md +102 -0
- fairseq/examples/bart/summarize.py +100 -0
- fairseq/examples/byte_level_bpe/README.md +88 -0
- fairseq/examples/byte_level_bpe/get_bitext.py +254 -0
- fairseq/examples/byte_level_bpe/get_data.sh +47 -0
- fairseq/examples/byte_level_bpe/gru_transformer.py +107 -0
- fairseq/examples/camembert/README.md +75 -0
- fairseq/examples/constrained_decoding/README.md +123 -0
- fairseq/examples/constrained_decoding/normalize.py +27 -0
- fairseq/examples/constrained_decoding/tok.py +34 -0
- fairseq/examples/conv_seq2seq/README.md +25 -0
- fairseq/examples/criss/README.md +61 -0
- fairseq/examples/criss/download_and_preprocess_flores_test.sh +64 -0
- fairseq/examples/criss/download_and_preprocess_tatoeba.sh +46 -0
- fairseq/examples/criss/mining/mine.py +240 -0
- fairseq/examples/criss/mining/mine_example.sh +103 -0
- fairseq/examples/criss/save_encoder.py +214 -0
- fairseq/examples/criss/sentence_retrieval/encoder_analysis.py +92 -0
 
    	
fairseq/.github/ISSUE_TEMPLATE.md ADDED
@@ -0,0 +1,3 @@
+## 👉 [Please follow one of these issue templates](https://github.com/pytorch/fairseq/issues/new/choose) 👈
+
+Note: to keep the backlog clean and actionable, issues may be immediately closed if they do not follow one of the above issue templates.
    	
fairseq/.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,43 @@
+---
+name: 🐛 Bug Report
+about: Submit a bug report to help us improve
+labels: 'bug, needs triage'
+---
+
+## 🐛 Bug
+
+<!-- A clear and concise description of what the bug is. -->
+
+### To Reproduce
+
+Steps to reproduce the behavior (**always include the command you ran**):
+
+1. Run cmd '....'
+2. See error
+
+<!-- If you have a code sample, error messages, stack traces, please provide it here as well -->
+
+
+#### Code sample
+<!-- Ideally attach a minimal code sample to reproduce the described issue.
+Minimal means having the shortest code but still preserving the bug. -->
+
+### Expected behavior
+
+<!-- A clear and concise description of what you expected to happen. -->
+
+### Environment
+
+ - fairseq Version (e.g., 1.0 or main):
+ - PyTorch Version (e.g., 1.0):
+ - OS (e.g., Linux):
+ - How you installed fairseq (`pip`, source):
+ - Build command you used (if compiling from source):
+ - Python version:
+ - CUDA/cuDNN version:
+ - GPU models and configuration:
+ - Any other relevant information:
+
+### Additional context
+
+<!-- Add any other context about the problem here. -->
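The Environment section of this template can be filled in programmatically. A minimal sketch (not part of the commit) that prints most of the requested fields, assuming PyTorch is installed; the CUDA/cuDNN fields print None on CPU-only machines:

```python
# Collect the fields requested by the bug-report template's Environment
# section. Assumes torch is importable; not part of the committed files.
import platform
import torch

print("PyTorch Version:", torch.__version__)
print("OS:", platform.platform())
print("Python version:", platform.python_version())
print("CUDA version:", torch.version.cuda)                    # None on CPU builds
print("cuDNN version:", torch.backends.cudnn.version())       # None on CPU builds
print("GPU models:", [torch.cuda.get_device_name(i)
                      for i in range(torch.cuda.device_count())])
```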
    	
fairseq/.github/ISSUE_TEMPLATE/documentation.md ADDED
@@ -0,0 +1,15 @@
+---
+name: 📚 Documentation/Typos
+about: Report an issue related to documentation or a typo
+labels: 'documentation, needs triage'
+---
+
+## 📚 Documentation
+
+For typos and doc fixes, please go ahead and:
+
+1. Create an issue.
+2. Fix the typo.
+3. Submit a PR.
+
+Thanks!
    	
fairseq/.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,24 @@
+---
+name: 🚀 Feature Request
+about: Submit a proposal/request for a new feature
+labels: 'enhancement, help wanted, needs triage'
+---
+
+## 🚀 Feature Request
+<!-- A clear and concise description of the feature proposal -->
+
+### Motivation
+
+<!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
+
+### Pitch
+
+<!-- A clear and concise description of what you want to happen. -->
+
+### Alternatives
+
+<!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
+
+### Additional context
+
+<!-- Add any other context or screenshots about the feature request here. -->
    	
fairseq/.github/ISSUE_TEMPLATE/how-to-question.md ADDED
@@ -0,0 +1,33 @@
+---
+name: ❓ Questions/Help
+about: If you have questions, please first search existing issues and docs
+labels: 'question, needs triage'
+---
+
+## ❓ Questions and Help
+
+### Before asking:
+1. search the issues.
+2. search the docs.
+
+<!-- If you still can't find what you need: -->
+
+#### What is your question?
+
+#### Code
+
+<!-- Please paste a code snippet if your question requires it! -->
+
+#### What have you tried?
+
+#### What's your environment?
+
+ - fairseq Version (e.g., 1.0 or main):
+ - PyTorch Version (e.g., 1.0):
+ - OS (e.g., Linux):
+ - How you installed fairseq (`pip`, source):
+ - Build command you used (if compiling from source):
+ - Python version:
+ - CUDA/cuDNN version:
+ - GPU models and configuration:
+ - Any other relevant information:
    	
fairseq/.github/PULL_REQUEST_TEMPLATE.md ADDED
@@ -0,0 +1,16 @@
+# Before submitting
+
+- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
+- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
+- [ ] Did you make sure to update the docs?
+- [ ] Did you write any new necessary tests?
+
+## What does this PR do?
+Fixes # (issue).
+
+## PR review
+Anyone in the community is free to review the PR once the tests have passed.
+If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
+
+## Did you have fun?
+Make sure you had fun coding 🙃
    	
fairseq/.github/stale.yml ADDED
@@ -0,0 +1,30 @@
+# Configuration for probot-stale - https://github.com/probot/stale
+# Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml
+# Number of days of inactivity before an issue becomes stale
+daysUntilStale: 90
+# Number of days of inactivity before a stale issue is closed
+daysUntilClose: 7
+# Issues with these labels will never be considered stale
+exemptLabels:
+  - bug
+# Label to use when marking an issue as stale
+staleLabel: stale
+issues:
+  # Comment to post when marking an issue as stale.
+  markComment: >
+    This issue has been automatically marked as stale.
+    **If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open.
+    We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!
+  # Comment to post when closing a stale issue.
+  closeComment: >
+    Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!
+pulls:
+  # Comment to post when marking a pull request as stale.
+  markComment: >
+    This pull request has been automatically marked as stale.
+    **If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open.
+    We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated.
+  # Comment to post when closing a stale pull request.
+  closeComment: >
+    Closing this pull request after a prolonged period of inactivity. If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you!
+
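To make the configured timeline concrete, a small back-of-envelope sketch (not part of the commit; the last-activity date is hypothetical) of when an inactive issue would be marked stale and then closed under the settings above:

```python
# Timeline implied by daysUntilStale/daysUntilClose in stale.yml above.
from datetime import date, timedelta

last_activity = date(2021, 1, 1)                # hypothetical last comment
stale_on = last_activity + timedelta(days=90)   # daysUntilStale: 90
closed_on = stale_on + timedelta(days=7)        # daysUntilClose: 7
print(stale_on)   # 2021-04-01 -> marked stale (unless labeled "bug")
print(closed_on)  # 2021-04-08 -> closed if nobody comments in between
```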
    	
fairseq/.github/workflows/build.yml ADDED
@@ -0,0 +1,55 @@
+name: build
+
+on:
+  # Trigger the workflow on push to main or any pull request
+  push:
+    branches:
+      - main
+  pull_request:
+
+jobs:
+  build:
+
+    strategy:
+      max-parallel: 4
+      matrix:
+        platform: [ubuntu-latest, macos-latest]
+        python-version: [3.6, 3.7]
+
+    runs-on: ${{ matrix.platform }}
+
+    steps:
+    - uses: actions/checkout@v2
+
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+
+    - name: Conditionally install pytorch
+      if: matrix.platform == 'windows-latest'
+      run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
+
+    - name: Install locally
+      run: |
+        python -m pip install --upgrade pip
+        git submodule update --init --recursive
+        python setup.py build_ext --inplace
+        python -m pip install --editable .
+
+    - name: Install optional test requirements
+      run: |
+        python -m pip install iopath transformers pyarrow
+        python -m pip install git+https://github.com/facebookresearch/fairscale.git@main
+
+    - name: Lint with flake8
+      run: |
+        pip install flake8
+        # stop the build if there are Python syntax errors or undefined names
+        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --extend-exclude fairseq/model_parallel/megatron
+        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --extend-exclude fairseq/model_parallel/megatron
+
+    - name: Run tests
+      run: |
+          python setup.py test
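The workflow's "Install locally" and "Run tests" steps can be reproduced on a developer machine. A minimal sketch (not part of the commit) that runs the same commands in order, assuming it is executed from the repository root with git and pip available:

```python
# Local equivalent of the build.yml "Install locally" and "Run tests"
# steps; commands are taken verbatim from the workflow above.
import subprocess

STEPS = [
    ["python", "-m", "pip", "install", "--upgrade", "pip"],
    ["git", "submodule", "update", "--init", "--recursive"],
    ["python", "setup.py", "build_ext", "--inplace"],
    ["python", "-m", "pip", "install", "--editable", "."],
    ["python", "setup.py", "test"],
]

for cmd in STEPS:
    print(">>>", " ".join(cmd))
    subprocess.run(cmd, check=True)  # check=True stops at the first failing step
```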
    	
fairseq/.github/workflows/build_wheels.yml ADDED
@@ -0,0 +1,41 @@
+name: build_wheels
+
+on:
+  push:
+    branches:
+      - v[0-9]+.[0-9]+.[x0-9]+
+    tags:
+      - v*
+
+jobs:
+  build_wheels:
+    name: Build wheels on ${{ matrix.os }}
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest, macos-latest]
+
+    steps:
+      - uses: actions/checkout@v2
+
+      - name: Install Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.7'
+
+      - name: Install cibuildwheel
+        run: |
+          python -m pip install cibuildwheel
+
+      - name: Build wheels for CPython
+        run: |
+          python -m cibuildwheel --output-dir dist
+        env:
+          CIBW_BUILD: "cp36-*64 cp37-*64 cp38-*64"
+          CIBW_MANYLINUX_X86_64_IMAGE: manylinux1
+          CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install .
+
+      - uses: actions/upload-artifact@v2
+        with:
+          name: wheels
+          path: ./dist/*.whl
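The `CIBW_BUILD` value controls which CPython wheels get built: cibuildwheel matches these shell-style glob patterns against build identifiers such as `cp37-manylinux_x86_64`. An illustrative sketch only, approximating that selection with `fnmatch`; the identifiers below are made-up examples, not cibuildwheel's full list:

```python
# Approximate how CIBW_BUILD glob patterns select build identifiers.
from fnmatch import fnmatch

patterns = "cp36-*64 cp37-*64 cp38-*64".split()
identifiers = [
    "cp36-manylinux_x86_64",   # matched by cp36-*64
    "cp37-macosx_x86_64",      # matched by cp37-*64
    "cp38-manylinux_i686",     # skipped: does not end in "64"
    "cp39-manylinux_x86_64",   # skipped: no cp39 pattern
]

for ident in identifiers:
    selected = any(fnmatch(ident, pat) for pat in patterns)
    print(f"{ident}: {'build' if selected else 'skip'}")
```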
    	
fairseq/.gitmodules ADDED
@@ -0,0 +1,4 @@
+[submodule "fairseq/model_parallel/megatron"]
+    path = fairseq/model_parallel/megatron
+    url = https://github.com/ngoyal2707/Megatron-LM
+    branch = fairseq
    	
fairseq/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,77 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+  advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+  address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+  professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at <[email protected]>. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
+
    	
fairseq/CONTRIBUTING.md ADDED
@@ -0,0 +1,28 @@
+# Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Pull Requests
+We actively welcome your pull requests.
+
+1. Fork the repo and create your branch from `main`.
+2. If you've added code that should be tested, add tests.
+3. If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Facebook's open source projects.
+
+Complete your CLA here: <https://code.facebook.com/cla>
+
+## Issues
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+## License
+By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
+you agree that your contributions will be licensed under the LICENSE file in
+the root directory of this source tree.
    	
fairseq/LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) Facebook, Inc. and its affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
    	
        fairseq/README.md
    ADDED
    
    | 
         @@ -0,0 +1,229 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            <p align="center">
         
     | 
| 2 | 
         
            +
              <img src="docs/fairseq_logo.png" width="150">
         
     | 
| 3 | 
         
            +
              <br />
         
     | 
| 4 | 
         
            +
              <br />
         
     | 
| 5 | 
         
            +
              <a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
         
     | 
| 6 | 
         
            +
              <a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a>
         
     | 
| 7 | 
         
            +
              <a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a>
         
     | 
| 8 | 
         
            +
              <a href="https://fairseq.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation Status" src="https://readthedocs.org/projects/fairseq/badge/?version=latest" /></a>
         
     | 
| 9 | 
         
            +
            </p>
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            --------------------------------------------------------------------------------
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            Fairseq(-py) is a sequence modeling toolkit that allows researchers and
         
     | 
| 14 | 
         
            +
            developers to train custom models for translation, summarization, language
         
     | 
| 15 | 
         
            +
            modeling and other text generation tasks.
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
            We provide reference implementations of various sequence modeling papers:
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            <details><summary>List of implemented papers</summary><p>
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            * **Convolutional Neural Networks (CNN)**
         
     | 
| 22 | 
         
            +
              + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
         
     | 
| 23 | 
         
            +
              + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
         
     | 
| 24 | 
         
            +
              + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
         
     | 
| 25 | 
         
            +
              + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
         
     | 
| 26 | 
         
            +
              + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
         
     | 
| 27 | 
         
            +
            * **LightConv and DynamicConv models**
         
     | 
| 28 | 
         
            +
              + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
         
     | 
| 29 | 
         
            +
            * **Long Short-Term Memory (LSTM) networks**
         
     | 
| 30 | 
         
            +
              + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
         
     | 
| 31 | 
         
            +
            * **Transformer (self-attention) networks**
         
     | 
| 32 | 
         
            +
              + Attention Is All You Need (Vaswani et al., 2017)
         
     | 
| 33 | 
         
            +
              + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
         
     | 
| 34 | 
         
            +
              + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
         
     | 
| 35 | 
         
            +
              + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
         
     | 
| 36 | 
         
            +
              + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
         
     | 
| 37 | 
         
            +
              + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
         
     | 
| 38 | 
         
            +
              + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
         
     | 
| 39 | 
         
            +
              + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
         
     | 
| 40 | 
         
            +
              + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
         
     | 
| 41 | 
         
            +
              + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
         
     | 
| 42 | 
         
            +
              + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md )
         
     | 
| 43 | 
         
            +
              + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
         
     | 
| 44 | 
         
            +
              + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
         
     | 
| 45 | 
         
            +
              + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
         
     | 
| 46 | 
         
            +
              + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
         
     | 
| 47 | 
         
            +
              + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
         
     | 
| 48 | 
         
            +
              + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
         
     | 
| 49 | 
         
            +
              + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
         
     | 
| 50 | 
         
            +
              + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
         
     | 
| 51 | 
         
            +
              + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
         
     | 
| 52 | 
         
            +
              + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027)
         
     | 
| 53 | 
         
            +
              + [Unsupervised Speech Recognition (Baevski, et al., 2021)](https://arxiv.org/abs/2105.11084)
         
     | 
| 54 | 
         
            +
            * **Non-autoregressive Transformers**
         
     | 
| 55 | 
         
            +
              + Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
         
     | 
| 56 | 
         
            +
              + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
         
     | 
| 57 | 
         
            +
              + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)
         
     | 
| 58 | 
         
            +
              + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
         
     | 
| 59 | 
         
            +
              + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
         
     | 
| 60 | 
         
            +
            * **Finetuning**
         
     | 
| 61 | 
         
            +
              + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)
         
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
            </p></details>
         

### What's New:

* September 2021: [`master` branch renamed to `main`](https://github.com/github/renaming)
* July 2021: [Released DrNMT code](examples/discriminative_reranking_nmt/README.md)
* July 2021: [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md)
* June 2021: [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md)
* May 2021: [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md)
* March 2021: [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
* February 2021: [Added LASER training code](examples/laser/README.md)
* December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
* December 2020: [GottBERT model and code released](examples/gottbert/README.md)
* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
  * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0)
* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
* October 2020: [Added CRISS models and code](examples/criss/README.md)

<details><summary>Previous updates</summary><p>
         

* September 2020: [Added Linformer code](examples/linformer/README.md)
* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
* April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
* February 2020: [mBART model and code released](examples/mbart/README.md)
* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
* November 2019: [CamemBERT model and code released](examples/camembert/README.md)
* November 2019: [BART model and code released](examples/bart/README.md)
* November 2019: [XLM-R models and code released](examples/xlmr/README.md)
* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
* August 2019: [WMT'19 models released](examples/wmt19/README.md)
* July 2019: fairseq relicensed under MIT license
* July 2019: [RoBERTa models and code released](examples/roberta/README.md)
* June 2019: [wav2vec models and code released](examples/wav2vec/README.md)

</p></details>
         
### Features:

* multi-GPU training on one machine or across multiple machines (data and model parallel)
* fast generation on both CPU and GPU with multiple search algorithms implemented:
  + beam search
  + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
  + sampling (unconstrained, top-k and top-p/nucleus)
  + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018)
* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU (see the example command after this list)
* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
* [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
* [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
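
The gradient accumulation and mixed precision features above are plain command-line flags; a minimal sketch (the dataset path and hyperparameters are illustrative, not a tuned recipe):

``` bash
# Accumulate gradients over 16 forward/backward passes (--update-freq) to
# simulate a 16x larger mini-batch on a single GPU, training in fp16 (--fp16).
fairseq-train data-bin/wmt16_en_de \
    --arch transformer --optimizer adam --lr 0.0005 \
    --max-tokens 4096 --update-freq 16 --fp16
```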
         

We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
with a convenient `torch.hub` interface:

``` python
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
en2de.translate('Hello world', beam=5)
# 'Hallo Welt'
```

See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
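
The RoBERTa hub interface follows the same pattern; a minimal sketch (`roberta.large` is one of the published hub model names):

``` python
import torch

roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()  # disable dropout for deterministic outputs

tokens = roberta.encode('Hello world!')      # BPE-encode into a tensor of token ids
features = roberta.extract_features(tokens)  # hidden states of shape (1, seq_len, 1024)
```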
         

# Requirements and Installation

* [PyTorch](http://pytorch.org/) version >= 1.5.0
* Python version >= 3.6
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* **To install fairseq** and develop locally:

``` bash
git clone https://github.com/pytorch/fairseq
cd fairseq
pip install --editable ./

# on MacOS:
# CFLAGS="-stdlib=libc++" pip install --editable ./

# to install the latest stable release (0.10.x)
# pip install fairseq
```

* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:

``` bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
  --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
  --global-option="--fast_multihead_attn" ./
```

* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
  as command line options to `nvidia-docker run`.
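
  For example (the image name is illustrative):

  ``` bash
  nvidia-docker run --ipc=host -it --rm pytorch/pytorch:latest bash
  ```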
         

# Getting Started

The [full documentation](https://fairseq.readthedocs.io/) contains instructions
for getting started, training new models and extending fairseq with new model
types and tasks.

# Pre-trained models and examples

We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
as well as example training and evaluation commands.

* [Translation](examples/translation/README.md): convolutional and transformer models are available
* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available

We also have more detailed READMEs to reproduce results from specific papers:
         

* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)
         

# Join the fairseq community

* Twitter: https://twitter.com/fairseq
* Facebook page: https://www.facebook.com/groups/fairseq.users
* Google group: https://groups.google.com/forum/#!forum/fairseq-users

# License

fairseq(-py) is MIT-licensed.
The license applies to the pre-trained models as well.

# Citation

Please cite as:

``` bibtex
@inproceedings{ott2019fairseq,
  title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
  author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
  booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
  year = {2019},
}
```
         

fairseq/examples/__init__.py
ADDED
@@ -0,0 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

try:
    from fairseq.version import __version__  # noqa
except ImportError:
    pass
    	
fairseq/examples/adaptive_span/README.md
ADDED
@@ -0,0 +1,90 @@
# Adaptive Span

Adaptive Span is a novel self-attention mechanism that can learn its optimal
attention span. This allows us to significantly extend the maximum context size
used in Transformers, while maintaining control over memory footprint and
computation time. It uses the Truncated BPTT technique for training,
as in [transformerXL](https://github.com/pytorch/fairseq/blob/main/examples/truncated_bptt/README.md).
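
Concretely, each attention head scales its attention weight for a token at distance $x$ with a soft mask; a sketch of the masking function from the paper, where $z$ is the head's learned span parameter and $R$ is the ramp length (`--attn-span` bounds $z$):

$$m_z(x) = \min\left(\max\left(\tfrac{1}{R}\,(R + z - x),\ 0\right),\ 1\right)$$

Because the ramp is piecewise linear in $z$, gradients flow through it and each head can grow or shrink its own span during training.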
         

Adaptive Span was introduced in the paper
[Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799),
which achieved state-of-the-art language modeling results at the time of publication.

We managed to reproduce their result in fairseq and keep most of the
[original implementation](https://github.com/facebookresearch/adaptive-span) untouched.
You can also refer to their sweep file if any hyperparameter combination is unclear.

##### 0. Setup
         

First you need to preprocess the enwik8 dataset; we use the pre-tokenized dataset
from the [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh).
You can download the dataset, and then run:
```bash
fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
    --validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
    --destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
```
         

##### 1. Train an Adaptive Span model on enwik8

We will train a 12-layer Adaptive Span model following the [hyperparameters
used in the original
paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).

The following command assumes 4 GPUs, so that the total batch size is 64
sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
    --user-dir examples/adaptive_span \
    --data ~/data/enwik8/data-bin/ \
    --fp16 --fp16-no-flatten-grads --max-update 600000 \
    --task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
    --n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
    --attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
    --validate-interval-updates 1000 \
    --lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
    --lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
    --seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
```
         

This should land around 1.05 on validation and 1.03 on test. You can lower
`--aux-loss-scaler` for better performance (a longer learned span); it gives a
~0.03 bpc improvement over the transformerXL baseline here.
If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
and simulate training on 4 GPUs.
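
For reference, `--aux-loss-scaler` is the coefficient $\lambda$ on the span penalty minimized together with the LM loss; a sketch of the objective from the paper, with $z_h$ the learned span of head $h$:

$$\mathcal{L} = \mathcal{L}_{\text{lm}} + \lambda \sum_h z_h$$

A smaller $\lambda$ lets spans grow, trading memory and compute for bpc.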
         

You can also reproduce the transformerXL result on enwik8 using this code base.
It should land around 1.06 on test, matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh).
You can try it with:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
    --user-dir examples/truncated_bptt \
    ~/data/enwik8/data-bin/ \
    --task truncated_bptt_lm --fp16 --max-update 400000 \
    --tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
    --d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
    --dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
    --lr-scheduler cosine --warmup-updates 0 \
    --lr 0.0 --lr 0.00025 --batch-size 15 \
    --update-freq 1 --seed 2 --log-format json --log-interval 25
```
         

##### 2. Evaluate
For Adaptive Span:
```bash
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
 --user-dir examples/adaptive_span \
 --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test
```
For Transformer-XL evaluation:
```bash
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
    --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
    --tokens-per-sample 80 \
    --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
    --gen-subset valid
```

*Note:* During training the model saw 512 tokens of context
(`--tokens-per-sample=512`) with batch size 8. These settings match the evaluation
settings from [the original
paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
         

fairseq/examples/adaptive_span/__init__.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os

# automatically import any Python files in the current directory
cur_dir = os.path.dirname(__file__)
for file in os.listdir(cur_dir):
    path = os.path.join(cur_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        mod_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module(__name__ + "." + mod_name)
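
(This `__init__.py` is what makes `--user-dir examples/adaptive_span` work: fairseq imports the package, and the loop above then imports every sibling module so that their `@register_optimizer`, `@register_criterion` and `@register_model` decorators run and the components become visible to `fairseq-train`.)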
    	

fairseq/examples/adaptive_span/adagrad_with_grad_clip.py
ADDED
@@ -0,0 +1,128 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch.optim import Adagrad

from fairseq.optim import LegacyFairseqOptimizer, register_optimizer


@register_optimizer("adagrad_with_grad_clip")
class FairseqAdagradWithGradClip(LegacyFairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = AdagradWithGradClip(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        parser.add_argument('--adagrad-clip', default=0.0, type=float, metavar='D',
                            help='internal grad clip')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.args.lr[0],
            "weight_decay": self.args.weight_decay,
            "grad_clip": self.args.adagrad_clip,
        }

    @property
    def supports_flat_params(self):
        return False


def _clip_grad(clr, grad, group_grad_clip):
    if group_grad_clip > 0:
        norm = grad.norm(2).item()
        if norm > group_grad_clip:
            clr *= group_grad_clip / (norm + 1e-10)
    return clr


class AdagradWithGradClip(Adagrad):
    """Adagrad algorithm with custom gradient clipping"""

    def __init__(
        self,
        params,
        lr=1e-2,
        lr_decay=0,
        weight_decay=0,
        initial_accumulator_value=0,
        grad_clip=0,
    ):
        Adagrad.__init__(
            self,
            params,
            lr=lr,
            lr_decay=lr_decay,
            weight_decay=weight_decay,
            initial_accumulator_value=initial_accumulator_value,
        )
        self.defaults["grad_clip"] = grad_clip
        self.param_groups[0].setdefault("grad_clip", grad_clip)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad.data
                state = self.state[p]

                state["step"] += 1

                if group["weight_decay"] != 0:
                    if p.grad.data.is_sparse:
                        raise RuntimeError(
                            "weight_decay option is "
                            "not compatible with sparse "
                            "gradients"
                        )
                    grad = grad.add(group["weight_decay"], p.data)

                clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"])

                # clip
                clr = _clip_grad(clr=clr, grad=grad, group_grad_clip=group["grad_clip"])

                if grad.is_sparse:
                    # the update is non-linear so indices must be unique
                    grad = grad.coalesce()
                    grad_indices = grad._indices()
                    grad_values = grad._values()
                    size = grad.size()

                    def make_sparse(values):
                        constructor = grad.new
                        if grad_indices.dim() == 0 or values.dim() == 0:
                            return constructor().resize_as_(grad)
                        return constructor(grad_indices, values, size)

                    state["sum"].add_(make_sparse(grad_values.pow(2)))
                    std = state["sum"]._sparse_mask(grad)
                    std_values = std._values().sqrt_().add_(1e-10)
                    p.data.add_(-clr, make_sparse(grad_values / std_values))
                else:
                    state["sum"].addcmul_(1, grad, grad)
                    std = state["sum"].sqrt().add_(1e-10)
                    p.data.addcdiv_(-clr, grad, std)

        return loss
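
A minimal standalone sketch of the clipping behavior above (illustrative tensors; assumes the fairseq repo root is on `PYTHONPATH`):

```python
import torch

from examples.adaptive_span.adagrad_with_grad_clip import AdagradWithGradClip

# One parameter with a deliberately large gradient to trigger clipping.
w = torch.nn.Parameter(torch.ones(4))
opt = AdagradWithGradClip([w], lr=0.07, grad_clip=0.03)

loss = (100 * w).sum()  # gradient is 100 per element, L2 norm 200 >> grad_clip
loss.backward()
opt.step()  # the effective learning rate for this update is scaled by ~0.03/200
```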
    	
        fairseq/examples/adaptive_span/adaptive_span_attention.py
    ADDED
    
    | 
         @@ -0,0 +1,160 @@ 
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveMask(nn.Module):
    """Soft masking function for adaptive size.
    It masks out the last K values of an input. The masking value
    goes from 1 to 0 gradually, so K can be learned with
    back-propagation.
    Args:
        max_size: maximum size (i.e. input dimension)
        ramp_size: size of the ramp going from 0 to 1
        init_val: initial size proportion not to be masked out
        shape: learn multiple sizes independent of each other
    """

    def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)):
        nn.Module.__init__(self)
        self._max_size = max_size
        self._ramp_size = ramp_size
        self.current_val = nn.Parameter(torch.zeros(*shape) + init_val)
        mask_template = torch.linspace(1 - max_size, 0, steps=max_size)
        self.register_buffer("mask_template", mask_template)

    def forward(self, x):
        mask = self.mask_template.float() + self.current_val.float() * self._max_size
        mask = mask / self._ramp_size + 1
        mask = mask.clamp(0, 1)
        if x.size(-1) < self._max_size:
            # the input could have been trimmed beforehand to save computation
            mask = mask.narrow(-1, self._max_size - x.size(-1), x.size(-1))
        x = (x * mask).type_as(x)
        return x

    def get_current_max_size(self, include_ramp=True):
        current_size = math.ceil(self.current_val.max().item() * self._max_size)
        if include_ramp:
            current_size += self._ramp_size
        current_size = max(0, min(self._max_size, current_size))
        return current_size

    def get_current_avg_size(self, include_ramp=True):
        current_size = math.ceil(
            self.current_val.float().mean().item() * self._max_size
        )
        if include_ramp:
            current_size += self._ramp_size
        current_size = max(0, min(self._max_size, current_size))
        return current_size

    def clamp_param(self):
        """this needs to be called after each update"""
        self.current_val.data.clamp_(0, 1)


class AdaptiveSpan(nn.Module):
    """Adaptive attention span for Transformer self-attention.
    This module learns an attention span length from data for each
    self-attention head.
    Args:
        attn_span: maximum attention span
        adapt_span_loss: loss coefficient for the span length
        adapt_span_ramp: length of the masking ramp
        adapt_span_init: initial size ratio
        adapt_span_cache: adapt cache size to reduce memory usage
        n_head: number of attention heads
        adapt_span_layer: learn one span per layer instead of one per head
    """

    def __init__(
        self,
        attn_span,
        adapt_span_ramp,
        adapt_span_init,
        n_head,
        adapt_span_layer,
        **kargs
    ):
        nn.Module.__init__(self)
        self._max_span = attn_span
        self._n_head = n_head
        self._adapt_span_layer = adapt_span_layer
        if self._adapt_span_layer:
            self._mask = AdaptiveMask(
                max_size=self._max_span,
                ramp_size=adapt_span_ramp,
                init_val=adapt_span_init,
            )
        else:
            self._mask = AdaptiveMask(
                max_size=self._max_span,
                ramp_size=adapt_span_ramp,
                init_val=adapt_span_init,
                shape=(n_head, 1, 1),
            )

    def forward(self, attn, normalize=True):
        """mask attention with the right span"""
        # batch and head dimensions are merged together, so separate them first
        self.clamp_param()
        if self._adapt_span_layer:
            attn = self._mask(attn)
        else:
            B = attn.size(0)  # batch size
            M = attn.size(1)  # block size
            attn = attn.reshape(B // self._n_head, self._n_head, M, -1)
            attn = self._mask(attn)
            attn = attn.view(B, M, -1)
        return attn

    def get_trim_len(self):
        """how much of memory can be trimmed to reduce computation"""
        L = self._max_span
        trim_len = min(L - 1, L - self._mask.get_current_max_size())
        # too fine granularity might be bad for the memory management
        trim_len = math.floor(trim_len / 64) * 64
        return trim_len

    def trim_memory(self, query, key, value, key_pe):
        """trim out unnecessary memory beforehand to reduce computation"""
        trim_len = self.get_trim_len()
        cache_size = key.size(1) - query.size(1)
        trim_len_cache = trim_len - (self._max_span - cache_size)
        if trim_len_cache > 0:
            key = key[:, trim_len_cache:, :]
            value = value[:, trim_len_cache:, :]
        elif trim_len_cache < 0:
            # cache is too short! this happens when validation resumes
            # after a lot of updates.
            key = F.pad(key, [0, 0, -trim_len_cache, 0])
            value = F.pad(value, [0, 0, -trim_len_cache, 0])
        if trim_len > 0:
            if key_pe is not None:
                key_pe = key_pe[:, :, trim_len:]
        return key, value, key_pe

    def get_cache_size(self):
        """determine how long the cache should be"""
        trim_len = self.get_trim_len()
        # give a buffer of 64 steps since a span might increase
        # in future updates
        return min(self._max_span, self._max_span - trim_len + 64)

    def get_loss(self):
        """a loss term for regularizing the span length"""
        return self._max_span * self._mask.current_val.float().mean()

    def get_current_max_span(self):
        return self._mask.get_current_max_size()

    def get_current_avg_span(self):
        return self._mask.get_current_avg_size()

    def clamp_param(self):
        self._mask.clamp_param()
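The key idea in `AdaptiveMask.forward` above is that the mask is a clamped linear ramp over positions: everything older than roughly `current_val * max_size` steps is zeroed, and a `ramp_size`-wide transition keeps the cutoff differentiable so the span can be trained. A small self-contained sketch of that computation with toy values (not part of the file):

    import torch

    max_size, ramp_size = 8, 2
    current_val = torch.tensor([0.5])  # learned span proportion, kept in [0, 1]

    # same template as AdaptiveMask: [1 - max_size, ..., -1, 0]
    mask_template = torch.linspace(1 - max_size, 0, steps=max_size)
    mask = ((mask_template + current_val * max_size) / ramp_size + 1).clamp(0, 1)
    print(mask)
    # tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

The oldest positions (leftmost entries) are fully masked, the newest are fully kept, and gradients flow through the ramp entries into `current_val`.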
    	
fairseq/examples/adaptive_span/adaptive_span_loss.py ADDED
@@ -0,0 +1,106 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass

import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.cross_entropy import CrossEntropyCriterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II


@dataclass
class AdaptiveSpanCriterionConfig(FairseqDataclass):
    sentence_avg: bool = II("optimization.sentence_avg")


@register_criterion("adaptive_span_loss", dataclass=AdaptiveSpanCriterionConfig)
class AdaptiveSpanCriterion(CrossEntropyCriterion):
    def __init__(self, task, sentence_avg):
        super().__init__(task, sentence_avg)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss, summed here (different from the adaptive span code)
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        loss, aux_loss, avg_span, max_span = self.compute_loss(
            model, net_output, sample, reduce=reduce
        )
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        loss /= sample_size
        total_loss = loss + aux_loss
        sample_size = 1

        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
            "total_loss": total_loss.data,
            "avg_span": avg_span * sample_size,
            "max_span": max_span * sample_size,
        }
        return total_loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        loss, _ = super().compute_loss(model, net_output, sample, reduce)
        aux_loss = model.get_aux_loss()
        avg_span = model.get_current_avg_span()
        max_span = model.get_current_max_span()
        return loss, aux_loss, avg_span, max_span

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs)
        avg_span_sum = sum(log.get("avg_span", 0) for log in logging_outputs)
        max_span_sum = sum(log.get("max_span", 0) for log in logging_outputs)

        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar("avg_span", avg_span_sum / sample_size, sample_size, round=3)
        metrics.log_scalar("max_span", max_span_sum / sample_size, sample_size, round=3)
        # total loss contains the L1 norm on adaptive-span
        metrics.log_scalar(
            "total_loss",
            total_loss_sum / sample_size / math.log(2),
            sample_size,
            round=3,
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
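One detail worth noting in `forward` above: the summed cross-entropy is first divided by the token (or sentence) count, the span regularizer from `get_aux_loss()` is added to that already-averaged value, and `sample_size` is then reset to 1 so the trainer does not normalize a second time. A toy illustration of the arithmetic (all numbers made up):

    ce_loss_sum = 440.0  # summed cross-entropy over the batch
    ntokens = 100        # sample_size when sentence_avg is off
    aux_loss = 0.25      # aux_loss_scaler * mean span, from model.get_aux_loss()

    loss = ce_loss_sum / ntokens  # 4.4 nats per token
    total_loss = loss + aux_loss  # 4.65, the value actually backpropagated
    sample_size = 1               # gradient denominator stays 1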
    	
fairseq/examples/adaptive_span/adaptive_span_model.py ADDED
@@ -0,0 +1,263 @@
| 1 | 
         
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         
     | 
| 2 | 
         
            +
            # All rights reserved.
         
     | 
| 3 | 
         
            +
            #
         
     | 
| 4 | 
         
            +
            # This source code is licensed under the license found in the
         
     | 
| 5 | 
         
            +
            # LICENSE file in the root directory of this source tree.
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            import math
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            import torch
         
     | 
| 10 | 
         
            +
            import torch.nn as nn
         
     | 
| 11 | 
         
            +
            import torch.nn.functional as F
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            from fairseq.modules.layer_norm import LayerNorm
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            from .adaptive_span_attention import AdaptiveSpan
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
            # Size notations:
         
     | 
| 18 | 
         
            +
            # B = batch_size, H = d_model, M = block_size, L = attn_span
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            def _skew(X, pad_value):
         
     | 
| 22 | 
         
            +
                """shift every row 1 step to right"""
         
     | 
| 23 | 
         
            +
                # X = B x M x L
         
     | 
| 24 | 
         
            +
                B, M, L = X.size()
         
     | 
| 25 | 
         
            +
                X = F.pad(X, (0, M + 1), value=pad_value)  # B x M x (L+M+1)
         
     | 
| 26 | 
         
            +
                X = X.view(B, -1)  # B x ML+MM+M
         
     | 
| 27 | 
         
            +
                X = X[:, :-M]  # B x ML+MM
         
     | 
| 28 | 
         
            +
                X = X.view(B, M, M + L)  # B x M x L+M
         
     | 
| 29 | 
         
            +
                return X
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
            def _unskew(X):
         
     | 
| 33 | 
         
            +
                """reverse _skew operation"""
         
     | 
| 34 | 
         
            +
                # X = B x M x L+M
         
     | 
| 35 | 
         
            +
                B, M, L = X.size()
         
     | 
| 36 | 
         
            +
                L -= M
         
     | 
| 37 | 
         
            +
                X = X.view(B, -1)  # B x ML+MM
         
     | 
| 38 | 
         
            +
                X = F.pad(X, (0, M))  # B x ML+MM+M
         
     | 
| 39 | 
         
            +
                X = X.view(B, M, M + L + 1)  # B x M x L+M+1
         
     | 
| 40 | 
         
            +
                X = X[:, :, :L]  # B x M x L
         
     | 
| 41 | 
         
            +
                return X
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
            class SeqAttention(nn.Module):
         
     | 
| 45 | 
         
            +
                """Sequential self-attention layer.
         
     | 
| 46 | 
         
            +
                Each token will attend to its previous fixed number of steps.
         
     | 
| 47 | 
         
            +
                Note that attention doesn't include the current step itself.
         
     | 
| 48 | 
         
            +
                """
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
                def __init__(self, d_model, n_head, attn_span, dropout, adapt_span_layer, **kargs):
         
     | 
| 51 | 
         
            +
                    nn.Module.__init__(self)
         
     | 
| 52 | 
         
            +
                    self.dropout = nn.Dropout(dropout)
         
     | 
| 53 | 
         
            +
                    self.d_model = d_model  # size of a single head
         
     | 
| 54 | 
         
            +
                    self.attn_span = attn_span
         
     | 
| 55 | 
         
            +
                    self.adaptive_span = AdaptiveSpan(
         
     | 
| 56 | 
         
            +
                        attn_span=attn_span,
         
     | 
| 57 | 
         
            +
                        n_head=n_head,
         
     | 
| 58 | 
         
            +
                        adapt_span_layer=adapt_span_layer,
         
     | 
| 59 | 
         
            +
                        **kargs
         
     | 
| 60 | 
         
            +
                    )
         
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
                def forward(self, query, key, value, key_pe):
         
     | 
| 63 | 
         
            +
                    # query size = B x M x H
         
     | 
| 64 | 
         
            +
                    # key, value sizes = B x (M+L) x H
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
                    key, value, key_pe = self.adaptive_span.trim_memory(query, key, value, key_pe)
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
                    # compute attention from context
         
     | 
| 69 | 
         
            +
                    # B x M (dest) x (M+L) (src)
         
     | 
| 70 | 
         
            +
                    attn_cont = torch.matmul(query, key.transpose(-1, -2))
         
     | 
| 71 | 
         
            +
                    attn_cont = _unskew(attn_cont)  # B x M x L
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
                    # compute the effect of position embedding
         
     | 
| 74 | 
         
            +
                    attn_pos = torch.matmul(query, key_pe)  # B x M x L_pos
         
     | 
| 75 | 
         
            +
                    attn = attn_cont + attn_pos
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
                    attn = attn / math.sqrt(self.d_model)  # B x M X L_pos
         
     | 
| 78 | 
         
            +
             
     | 
| 79 | 
         
            +
                    attn = F.softmax(attn.float(), dim=-1).type_as(attn)
         
     | 
| 80 | 
         
            +
             
     | 
| 81 | 
         
            +
                    # trim attention lengths according to the learned span
         
     | 
| 82 | 
         
            +
                    attn = self.adaptive_span(attn)
         
     | 
| 83 | 
         
            +
             
     | 
| 84 | 
         
            +
                    attn = self.dropout(attn)  # B x M X L_pos
         
     | 
| 85 | 
         
            +
             
     | 
| 86 | 
         
            +
                    attn_cont = _skew(attn, 0)  # B x M X (L+M)
         
     | 
| 87 | 
         
            +
                    out = torch.matmul(attn_cont, value)  # B x M x H
         
     | 
| 88 | 
         
            +
                    return out
         
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
                def get_cache_size(self):
         
     | 
| 91 | 
         
            +
                    return self.adaptive_span.get_cache_size()
         
     | 
| 92 | 
         
            +
             
     | 
| 93 | 
         
            +
             
     | 
| 94 | 
         
            +
            class MultiHeadSeqAttention(nn.Module):
         
     | 
| 95 | 
         
            +
                def __init__(self, d_model, n_head, **kargs):
         
     | 
| 96 | 
         
            +
                    nn.Module.__init__(self)
         
     | 
| 97 | 
         
            +
                    assert d_model % n_head == 0
         
     | 
| 98 | 
         
            +
                    self.n_head = n_head
         
     | 
| 99 | 
         
            +
                    self.head_dim = d_model // n_head
         
     | 
| 100 | 
         
            +
                    self.attn = SeqAttention(d_model=self.head_dim, n_head=n_head, **kargs)
         
     | 
| 101 | 
         
            +
                    self.proj_query = nn.Linear(d_model, d_model, bias=False)
         
     | 
| 102 | 
         
            +
                    nn.init.xavier_normal_(self.proj_query.weight)
         
     | 
| 103 | 
         
            +
                    self.proj_out = nn.Linear(d_model, d_model, bias=False)
         
     | 
| 104 | 
         
            +
                    nn.init.xavier_normal_(self.proj_out.weight)
         
     | 
| 105 | 
         
            +
                    self.proj_val = nn.Linear(d_model, d_model, bias=False)
         
     | 
| 106 | 
         
            +
                    nn.init.xavier_normal_(self.proj_val.weight)
         
     | 
| 107 | 
         
            +
                    self.proj_key = nn.Linear(d_model, d_model, bias=False)
         
     | 
| 108 | 
         
            +
                    nn.init.xavier_normal_(self.proj_key.weight)
         
     | 
| 109 | 
         
            +
             
     | 
| 110 | 
         
            +
                def head_reshape(self, x):
         
     | 
| 111 | 
         
            +
                    K = self.n_head
         
     | 
| 112 | 
         
            +
                    D = self.head_dim
         
     | 
| 113 | 
         
            +
                    x = x.view(x.size()[:-1] + (K, D))  # B x (M+L) x K x D
         
     | 
| 114 | 
         
            +
                    x = x.transpose(1, 2).contiguous()  # B x K x (M+L) x D
         
     | 
| 115 | 
         
            +
                    x = x.view(-1, x.size(-2), x.size(-1))  # B_K x (M+L) x D
         
     | 
| 116 | 
         
            +
                    return x
         
     | 
| 117 | 
         
            +
             
     | 
| 118 | 
         
            +
                def forward(self, query, key, value, key_pe):
         
     | 
| 119 | 
         
            +
                    B = query.size(0)
         
     | 
| 120 | 
         
            +
                    K = self.n_head
         
     | 
| 121 | 
         
            +
                    D = self.head_dim
         
     | 
| 122 | 
         
            +
                    M = query.size(1)
         
     | 
| 123 | 
         
            +
             
     | 
| 124 | 
         
            +
                    query = self.proj_query(query)
         
     | 
| 125 | 
         
            +
                    query = self.head_reshape(query)
         
     | 
| 126 | 
         
            +
                    value = self.proj_val(value)
         
     | 
| 127 | 
         
            +
                    value = self.head_reshape(value)
         
     | 
| 128 | 
         
            +
                    key = self.proj_key(key)
         
     | 
| 129 | 
         
            +
                    key = self.head_reshape(key)
         
     | 
| 130 | 
         
            +
             
     | 
| 131 | 
         
            +
                    out = self.attn(query, key, value, key_pe)  # B_K x M x D
         
     | 
| 132 | 
         
            +
                    out = out.view(B, K, M, D)  # B x K x M x D
         
     | 
| 133 | 
         
            +
                    out = out.transpose(1, 2).contiguous()  # B x M x K x D
         
     | 
| 134 | 
         
            +
                    out = out.view(B, M, -1)  # B x M x K_D
         
     | 
| 135 | 
         
            +
                    out = self.proj_out(out)
         
     | 
| 136 | 
         
            +
                    return out
         
     | 
| 137 | 
         
            +
             
     | 
| 138 | 
         
            +
             
     | 
| 139 | 
         
            +
            class FeedForwardLayer(nn.Module):
         
     | 
| 140 | 
         
            +
                def __init__(self, d_model, d_inner, dropout, **kargs):
         
     | 
| 141 | 
         
            +
                    nn.Module.__init__(self)
         
     | 
| 142 | 
         
            +
                    self.fc1 = nn.Linear(d_model, d_inner)
         
     | 
| 143 | 
         
            +
                    self.fc2 = nn.Linear(d_inner, d_model)
         
     | 
| 144 | 
         
            +
                    nn.init.xavier_uniform_(self.fc1.weight)
         
     | 
| 145 | 
         
            +
                    nn.init.xavier_uniform_(self.fc2.weight)
         
     | 
| 146 | 
         
            +
                    self.dropout = nn.Dropout(dropout)
         
     | 
| 147 | 
         
            +
             
     | 
| 148 | 
         
            +
                def forward(self, h):
         
     | 
| 149 | 
         
            +
                    h1 = F.relu(self.fc1(h))
         
     | 
| 150 | 
         
            +
                    h1 = self.dropout(h1)
         
     | 
| 151 | 
         
            +
                    h2 = self.fc2(h1)
         
     | 
| 152 | 
         
            +
                    return h2
         
     | 
| 153 | 
         
            +
             
     | 
| 154 | 
         
            +
             
     | 
| 155 | 
         
            +
            class TransformerSeqLayer(nn.Module):
         
     | 
| 156 | 
         
            +
                def __init__(self, d_model, **kargs):
         
     | 
| 157 | 
         
            +
                    nn.Module.__init__(self)
         
     | 
| 158 | 
         
            +
                    self.attn = MultiHeadSeqAttention(d_model=d_model, **kargs)
         
     | 
| 159 | 
         
            +
                    self.norm1 = LayerNorm(d_model)
         
     | 
| 160 | 
         
            +
                    self.ff = FeedForwardLayer(d_model=d_model, **kargs)
         
     | 
| 161 | 
         
            +
                    self.norm2 = LayerNorm(d_model)
         
     | 
| 162 | 
         
            +
             
     | 
| 163 | 
         
            +
                def forward(self, h, h_cache, key_pe):
         
     | 
| 164 | 
         
            +
                    # h = B x M x H
         
     | 
| 165 | 
         
            +
                    # h_cache = B x L x H
         
     | 
| 166 | 
         
            +
                    h_all = torch.cat([h_cache, h], dim=1)  # B x (M+L) x H
         
     | 
| 167 | 
         
            +
                    attn_out = self.attn(h, h_all, h_all, key_pe)
         
     | 
| 168 | 
         
            +
                    h = self.norm1(h + attn_out)  # B x M x H
         
     | 
| 169 | 
         
            +
                    if self.ff is not None:
         
     | 
| 170 | 
         
            +
                        ff_out = self.ff(h)
         
     | 
| 171 | 
         
            +
                        out = self.norm2(h + ff_out)  # B x M x H
         
     | 
| 172 | 
         
            +
                    else:
         
     | 
| 173 | 
         
            +
                        out = h
         
     | 
| 174 | 
         
            +
                    return out
         
     | 
| 175 | 
         
            +
             
     | 
| 176 | 
         
            +
                def get_cache_size(self):
         
     | 
| 177 | 
         
            +
                    return self.attn.attn.get_cache_size()
         
     | 
| 178 | 
         
            +
             
     | 
| 179 | 
         
            +
             
     | 
| 180 | 
         
            +
            class TransformerSeq(nn.Module):
         
     | 
| 181 | 
         
            +
                def __init__(
         
     | 
| 182 | 
         
            +
                    self,
         
     | 
| 183 | 
         
            +
                    vocab_size,
         
     | 
| 184 | 
         
            +
                    d_model,
         
     | 
| 185 | 
         
            +
                    n_head,
         
     | 
| 186 | 
         
            +
                    n_layer,
         
     | 
| 187 | 
         
            +
                    attn_span,
         
     | 
| 188 | 
         
            +
                    emb_dropout,
         
     | 
| 189 | 
         
            +
                    aux_loss_scaler,
         
     | 
| 190 | 
         
            +
                    adapt_span_layer,
         
     | 
| 191 | 
         
            +
                    **kargs
         
     | 
| 192 | 
         
            +
                ):
         
     | 
| 193 | 
         
            +
                    nn.Module.__init__(self)
         
     | 
| 194 | 
         
            +
                    # token embeddings
         
     | 
| 195 | 
         
            +
                    self.in_emb = nn.Embedding(vocab_size, d_model)
         
     | 
| 196 | 
         
            +
                    nn.init.normal_(self.in_emb.weight, mean=0, std=d_model ** -0.5)
         
     | 
| 197 | 
         
            +
                    self.out_emb = nn.Linear(d_model, vocab_size)
         
     | 
| 198 | 
         
            +
                    self.aux_loss_scaler = aux_loss_scaler
         
     | 
| 199 | 
         
            +
                    if emb_dropout > 0:
         
     | 
| 200 | 
         
            +
                        self.emb_dropout = nn.Dropout(emb_dropout)
         
     | 
| 201 | 
         
            +
                    else:
         
     | 
| 202 | 
         
            +
                        self.emb_dropout = None
         
     | 
| 203 | 
         
            +
                    # position embeddings
         
     | 
| 204 | 
         
            +
                    self.key_pe = nn.Parameter(torch.randn(1, d_model // n_head, attn_span))
         
     | 
| 205 | 
         
            +
             
     | 
| 206 | 
         
            +
                    self.layers = nn.ModuleList()
         
     | 
| 207 | 
         
            +
                    self.layers.extend(
         
     | 
| 208 | 
         
            +
                        TransformerSeqLayer(
         
     | 
| 209 | 
         
            +
                            d_model=d_model,
         
     | 
| 210 | 
         
            +
                            n_head=n_head,
         
     | 
| 211 | 
         
            +
                            attn_span=attn_span,
         
     | 
| 212 | 
         
            +
                            adapt_span_layer=adapt_span_layer,
         
     | 
| 213 | 
         
            +
                            **kargs
         
     | 
| 214 | 
         
            +
                        )
         
     | 
| 215 | 
         
            +
                        for _ in range(n_layer)
         
     | 
| 216 | 
         
            +
                    )
         
     | 
| 217 | 
         
            +
             
     | 
| 218 | 
         
            +
                def forward(self, x, h_cache, target=None):
         
     | 
        # x size = B x M
        block_size = x.size(1)
        h = self.in_emb(x)  # B x M x H
        if self.emb_dropout is not None:
            h = self.emb_dropout(h)

        h_cache_next = []
        for l, layer in enumerate(self.layers):
            cache_size = layer.attn.attn.get_cache_size()
            if cache_size > block_size:
                h_cache_next_l = torch.cat(
                    [h_cache[l][:, -cache_size + block_size :, :], h], dim=1
                ).detach()
            else:
                h_cache_next_l = h[:, -cache_size:, :].detach()
            h_cache_next.append(h_cache_next_l)
            h = layer(h, h_cache[l], self.key_pe)  # B x M x H

        if self.emb_dropout is not None:
            h = self.emb_dropout(h)

        out = F.log_softmax(self.out_emb(h).float(), dim=-1).type_as(h)
        dummy_loss = None

        return out, h_cache_next, dummy_loss

    def get_aux_loss(self):
        loss = 0.0
        for layer in self.layers:
            loss += layer.attn.attn.adaptive_span.get_loss()
        return self.aux_loss_scaler * loss

    def get_current_max_span(self):
        max_span = 0.0
        for layer in self.layers:
            max_span = max(
                max_span, layer.attn.attn.adaptive_span.get_current_max_span()
            )
        return max_span

    def get_current_avg_span(self):
        avg_span = 0.0
        for layer in self.layers:
            avg_span += layer.attn.attn.adaptive_span.get_current_avg_span()
        return avg_span / len(self.layers)
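The loop above rolls the per-layer hidden-state cache forward: each layer keeps exactly `cache_size` of the most recent hidden states, either concatenating the old cache tail with the new block or, when the block alone fills the cache, keeping only the newest states. A minimal standalone sketch of that bookkeeping, with made-up sizes (illustrative only, not part of the commit):

```python
import torch

# Illustrative sketch of the cache rollover (all sizes are made up).
B, M, H = 2, 8, 4            # batch, block length, hidden size
cache_size = 12              # per-layer attention cache length
h_cache_l = torch.zeros(B, cache_size, H)  # previous cache for one layer
h = torch.randn(B, M, H)                   # hidden states for the new block

if cache_size > M:
    # keep the last (cache_size - M) old states, then append the new block
    h_cache_next_l = torch.cat(
        [h_cache_l[:, -cache_size + M :, :], h], dim=1
    ).detach()
else:
    # the block alone fills the cache: keep only its newest states
    h_cache_next_l = h[:, -cache_size:, :].detach()

assert h_cache_next_l.size(1) == cache_size  # cache length is invariant
```

The `.detach()` calls are what make this truncated BPTT: gradients never flow across block boundaries.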
    	
fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py
ADDED
@@ -0,0 +1,145 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
from fairseq.dataclass import FairseqDataclass
from fairseq.models import (
    FairseqIncrementalDecoder,
    FairseqLanguageModel,
    register_model,
)
from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel


logger = logging.getLogger(__name__)


@dataclass
class AdaptiveSpanSmallConfig(FairseqDataclass):
    # defaults come from https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8_small.sh
    vocab_size: int = 50
    d_model: int = 256
    n_head: int = 4
    d_inner: int = 1024
    n_layer: int = 8
    attn_span: int = 1024
    dropout: float = 0.0
    emb_dropout: float = 0.0
    adapt_span_ramp: int = 32
    adapt_span_init: float = 0.0
    aux_loss_scaler: float = 0.000002
    adapt_span_layer: bool = False


@register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig)
class AdaptiveSpanTransformer(FairseqLanguageModel):
    @classmethod
    def build_model(cls, cfg: AdaptiveSpanSmallConfig, task):
        return cls(AdaptiveSpanDecoder(cfg, task))

    def get_aux_loss(self):
        return self.decoder.get_aux_loss()

    def get_current_max_span(self):
        return self.decoder.get_current_max_span()

    def get_current_avg_span(self):
        return self.decoder.get_current_avg_span()


class AdaptiveSpanDecoder(FairseqIncrementalDecoder):
    def __init__(self, cfg, task):

        super().__init__(task.target_dictionary)

        self.config = cfg
        config = AdaptiveSpanSmallConfig(
            vocab_size=len(task.target_dictionary),
            d_model=cfg.d_model,
            n_head=cfg.n_head,
            d_inner=cfg.d_inner,
            n_layer=cfg.n_layer,
            attn_span=cfg.attn_span,
            dropout=cfg.dropout,
            emb_dropout=cfg.emb_dropout,
            adapt_span_ramp=cfg.adapt_span_ramp,
            adapt_span_init=cfg.adapt_span_init,
            aux_loss_scaler=cfg.aux_loss_scaler,
            adapt_span_layer=cfg.adapt_span_layer,
        )
        logger.info(config)
        self.model = AdaptiveSpanTransformerModel(**config.__dict__)

        self._mems = None

    def forward(
        self,
        src_tokens,
        incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
        encoder_out=None,
    ):
        bsz = src_tokens.size(0)
        if incremental_state is not None:  # used during inference
            mems = self.get_incremental_state("mems")
            src_tokens = src_tokens[:, -1:]  # only keep the most recent token
        else:
            mems = self._mems

        if mems is None:
            # first-time initialization
            mems = self.init_hid_cache(bsz)
        output = self.model(x=src_tokens, h_cache=mems)
        if incremental_state is not None:
            self.set_incremental_state(incremental_state, "mems", output[1])
        else:
            self._mems = output[1]
        return (output[0],)

    def max_positions(self):
        return self.config.attn_span

    def init_hid_cache(self, batch_sz):
        hid = []
        for layer in self.model.layers:
            param = next(self.model.parameters())
            h = torch.zeros(
                batch_sz,
                layer.get_cache_size(),
                self.config.d_model,
                dtype=param.dtype,
                device=param.device,
            )
            hid.append(h)
        return hid

    def get_aux_loss(self):
        return self.model.get_aux_loss()

    def get_current_max_span(self):
        return self.model.get_current_max_span()

    def get_current_avg_span(self):
        return self.model.get_current_avg_span()

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
        new_order: torch.Tensor,
    ):
        """Reorder incremental state.

        This will be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the selection of beams.
        """
        raise NotImplementedError("This is required for generation/beam search")
        # mems = self.get_incremental_state(incremental_state, "mems")
        # if mems is not None:
        #     new_mems = [mems_i.index_select(1, new_order) for mems_i in mems]
        #     self.set_incremental_state(incremental_state, "mems", new_mems)
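`reorder_incremental_state` is left unimplemented above, so generation with beam search will raise until it is filled in. The commented-out lines sketch the intent; note that `init_hid_cache` builds each cache as `(batch, cache_size, d_model)`, so a batch-first implementation would presumably reorder along dim 0 (the commented `index_select(1, ...)` matches a time-first layout instead). A hedged standalone sketch, under that batch-first assumption:

```python
import torch

# Hypothetical beam reordering for batch-first caches (illustrative only,
# not part of the commit).
mems = [torch.randn(4, 12, 8) for _ in range(2)]  # 2 layers, beam/batch of 4
new_order = torch.tensor([0, 0, 2, 3])            # e.g. beam 1 replaced by beam 0
new_mems = [mems_i.index_select(0, new_order) for mems_i in mems]
assert new_mems[0].shape == mems[0].shape
```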
    	
fairseq/examples/adaptive_span/truncated_bptt_lm_task.py
ADDED
@@ -0,0 +1,281 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import torch
from fairseq import utils
from fairseq.data import (
    Dictionary,
    TokenBlockDataset,
    data_utils,
    iterators,
)
from fairseq.dataclass import FairseqDataclass
from fairseq.distributed import utils as dist_utils
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II


logger = logging.getLogger(__name__)


@dataclass
class TruncatedBPTTLMConfig(FairseqDataclass):
    data: str = field(default="???", metadata={"help": "path to data directory"})
    tokens_per_sample: int = field(
        default=1024,
        metadata={"help": "max number of tokens per sequence"},
    )
    batch_size: int = II("dataset.batch_size")
    # Some models use *max_target_positions* to know how many positional
    # embeddings to learn. We use II(...) to make it default to
    # *tokens_per_sample*, but in principle there could be more positional
    # embeddings than tokens in a single batch. This may also be irrelevant for
    # custom model implementations.
    max_target_positions: int = II("task.tokens_per_sample")
    # these will be populated automatically if not provided
    data_parallel_rank: Optional[int] = None
    data_parallel_size: Optional[int] = None


@register_task("truncated_bptt_lm", dataclass=TruncatedBPTTLMConfig)
class TruncatedBPTTLMTask(FairseqTask):
    def __init__(self, cfg: TruncatedBPTTLMConfig):
        super().__init__(cfg)

        if cfg.data_parallel_rank is None or cfg.data_parallel_size is None:
            if torch.distributed.is_initialized():
                cfg.data_parallel_rank = dist_utils.get_data_parallel_rank()
                cfg.data_parallel_size = dist_utils.get_data_parallel_world_size()
            else:
                cfg.data_parallel_rank = 0
                cfg.data_parallel_size = 1

        # load the dictionary
        paths = utils.split_paths(cfg.data)
        assert len(paths) > 0
        self.dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
        logger.info("dictionary: {} types".format(len(self.dictionary)))

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)"""

        # support sharded datasets
        paths = utils.split_paths(self.cfg.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        # each element of *data* will be a tensorized line from the original
        # text dataset, similar to ``open(split_path).readlines()``
        data = data_utils.load_indexed_dataset(
            split_path, self.dictionary, combine=combine
        )
        if data is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        # this is similar to ``data.view(-1).split(tokens_per_sample)``
        data = TokenBlockDataset(
            data,
            data.sizes,
            block_size=self.cfg.tokens_per_sample,
            pad=None,  # unused
            eos=None,  # unused
            break_mode="none",
        )

        self.datasets[split] = TruncatedBPTTDataset(
            data=data,
            bsz_per_shard=self.cfg.batch_size,
            shard_id=self.cfg.data_parallel_rank,
            num_shards=self.cfg.data_parallel_size,
        )

    def dataset(self, split):
        return self.datasets[split]

    def get_batch_iterator(
        self, dataset, num_workers=0, epoch=1, data_buffer_size=0, **kwargs
    ):
        return iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=self._collate_fn,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size,
            # we don't use the batching functionality from EpochBatchIterator;
            # instead every item in *dataset* is a whole batch
            batch_sampler=[[i] for i in range(len(dataset))],
            disable_shuffling=True,
        )

    def _collate_fn(self, items: List[List[torch.Tensor]]):
        # we don't use fairseq's batching functionality, so we expect a single
        # item: an (id, List[torch.Tensor]) pair
        assert len(items) == 1

        # after collation, item will have shape B x T (the last batch may have
        # length < T)
        id, item = items[0]
        item = data_utils.collate_tokens(item, pad_idx=self.source_dictionary.pad())
        B, T = item.size()

        # shift item one position over and append a padding token for the target
        target = torch.nn.functional.pad(
            item[:, 1:], (0, 1, 0, 0), value=self.target_dictionary.pad()
        )

        # fairseq expects batches to have the following structure
        return {
            "id": torch.tensor([id] * item.size(0)),
            "net_input": {
                "src_tokens": item,
            },
            "target": target,
            "nsentences": item.size(0),
            "ntokens": item.numel(),
        }

    def build_dataset_for_inference(
        self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
    ) -> torch.utils.data.Dataset:
        eos = self.source_dictionary.eos()
        dataset = TokenBlockDataset(
            src_tokens,
            src_lengths,
            block_size=None,  # ignored for "eos" break mode
            pad=self.source_dictionary.pad(),
            eos=eos,
            break_mode="eos",
        )

        class Dataset(torch.utils.data.Dataset):
            def __getitem__(self, i):
                item = dataset[i]
                if item[-1] == eos:
                    # remove eos to support generating with a prefix
                    item = item[:-1]
                return (i, [item])

            def __len__(self):
                return len(dataset)

        return Dataset()

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            if constraints is not None:
                raise NotImplementedError

            # SequenceGenerator doesn't use *src_tokens* directly; we need to
            # pass the *prefix_tokens* argument instead.
            if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
                prefix_tokens = sample["net_input"]["src_tokens"]

            # begin generation with the end-of-sentence token
            bos_token = self.source_dictionary.eos()

            return generator.generate(
                models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
            )

    def eval_lm_dataloader(
        self,
        dataset,
        max_tokens: Optional[int] = 36000,
        batch_size: Optional[int] = None,
        max_positions: Optional[int] = None,
        num_shards: int = 1,
        shard_id: int = 0,
        num_workers: int = 1,
        data_buffer_size: int = 10,
        context_window: int = 0,
    ):
        if context_window > 0:
            raise NotImplementedError(
                "Transformer-XL doesn't need --context-window, try "
                "--model-overrides '{\"mem_len\":42}' instead"
            )
        return self.get_batch_iterator(
            dataset=dataset,
            max_tokens=max_tokens,
            max_sentences=batch_size,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            data_buffer_size=data_buffer_size,
        ).next_epoch_itr(shuffle=False)

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary


class TruncatedBPTTDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data: List[torch.Tensor],  # ordered list of items
        bsz_per_shard,  # number of items processed per GPU per forward pass
        shard_id,  # current GPU ID
        num_shards,  # number of GPUs
    ):
        super().__init__()
        self.data = data

        def batchify(data, bsz):
            # Work out how cleanly we can divide the dataset into bsz parts.
            nbatch = data.size(0) // bsz
            # Trim off any extra elements that wouldn't cleanly fit (remainders).
            data = data.narrow(0, 0, nbatch * bsz)
            # Evenly divide the data across the bsz batches.
            data = data.view(bsz, -1).contiguous()
            return data

        # total number of sequences processed by all GPUs in each forward pass
        global_batch_size = bsz_per_shard * num_shards

        """
        With a 16 item dataset, bsz_per_shard=2 and num_shards=3,
        *indices* might look like:

            indices = [[0, 1],
                       [2, 3],
                       [4, 5],
                       [6, 7],
                       [8, 9],
                       [10, 11]]

        The size of the TruncatedBPTTDataset instance will be 2,
        and shard 1 will see items:

            [(0, [data[4], data[6]]),
             (1, [data[5], data[7]])]
        """
        indices = batchify(torch.arange(len(data)), global_batch_size)
        assert indices.size(0) == global_batch_size

        self.my_indices = indices[
            shard_id * bsz_per_shard : (shard_id + 1) * bsz_per_shard
        ]
        assert self.my_indices.size(0) == bsz_per_shard

    def __len__(self):
        return self.my_indices.size(1)

    def __getitem__(self, i) -> Tuple[int, List[torch.Tensor]]:
        return (i, [self.data[idx] for idx in self.my_indices[:, i]])
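The sharding example in the `TruncatedBPTTDataset.__init__` docstring can be checked directly with plain torch. This standalone snippet (illustrative, not part of the commit) reproduces the 16-item / 2-per-shard / 3-shard case:

```python
import torch

def batchify(data, bsz):
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)  # drops the remainder, items 12-15
    return data.view(bsz, -1).contiguous()

bsz_per_shard, num_shards = 2, 3
indices = batchify(torch.arange(16), bsz_per_shard * num_shards)
# tensor([[ 0,  1], [ 2,  3], [ 4,  5], [ 6,  7], [ 8,  9], [10, 11]])

shard_id = 1
my_indices = indices[shard_id * bsz_per_shard : (shard_id + 1) * bsz_per_shard]
# tensor([[4, 5], [6, 7]]); item i of the shard gathers column i, so
# item 0 -> (data[4], data[6]) and item 1 -> (data[5], data[7]),
# matching the docstring.
```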
    	
fairseq/examples/backtranslation/README.md
ADDED
@@ -0,0 +1,297 @@
# Understanding Back-Translation at Scale (Edunov et al., 2018)

This page includes pre-trained models from the paper [Understanding Back-Translation at Scale (Edunov et al., 2018)](https://arxiv.org/abs/1808.09381).

## Pre-trained models

Model | Description | Dataset | Download
---|---|---|---
`transformer.wmt18.en-de` | Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381)) <br> WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) <br> See NOTE in the archive

## Example usage (torch.hub)

We require a few additional Python dependencies for preprocessing:
```bash
pip install subword_nmt sacremoses
```

Then to generate translations from the full model ensemble:
```python
import torch

# List available models
torch.hub.list('pytorch/fairseq')  # [..., 'transformer.wmt18.en-de', ... ]

# Load the WMT'18 En-De ensemble
en2de_ensemble = torch.hub.load(
    'pytorch/fairseq', 'transformer.wmt18.en-de',
    checkpoint_file='wmt18.model1.pt:wmt18.model2.pt:wmt18.model3.pt:wmt18.model4.pt:wmt18.model5.pt',
    tokenizer='moses', bpe='subword_nmt')

# The ensemble contains 5 models
len(en2de_ensemble.models)
# 5

# Translate
en2de_ensemble.translate('Hello world!')
# 'Hallo Welt!'
```
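
The hub interface also forwards generation options as keyword arguments, so you can switch decoding strategies without reloading the ensemble. A small sketch (`beam` and `sampling` are standard fairseq generation arguments, though the exact option set may vary across fairseq versions):
```python
# Wider beam search (deterministic)
en2de_ensemble.translate('Hello world!', beam=10)

# Unrestricted sampling, matching the back-translation setup used below
# (stochastic: repeated calls may return different translations)
en2de_ensemble.translate('Hello world!', sampling=True, beam=1)
```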

## Training your own model (WMT'18 English-German)

The following instructions can be adapted to reproduce the models from the paper.


#### Step 1. Prepare parallel data and optionally train a baseline (English-German) model

First download and preprocess the data:
```bash
# Download and prepare the data
cd examples/backtranslation/
bash prepare-wmt18en2de.sh
cd ../..

# Binarize the data
TEXT=examples/backtranslation/wmt18_en_de
fairseq-preprocess \
    --joined-dictionary \
    --source-lang en --target-lang de \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --destdir data-bin/wmt18_en_de --thresholdtgt 0 --thresholdsrc 0 \
    --workers 20

# Copy the BPE code into the data-bin directory for future use
cp examples/backtranslation/wmt18_en_de/code data-bin/wmt18_en_de/code
```

(Optionally) Train a baseline model (English-German) using just the parallel data:
```bash
CHECKPOINT_DIR=checkpoints_en_de_parallel
fairseq-train --fp16 \
    data-bin/wmt18_en_de \
    --source-lang en --target-lang de \
    --arch transformer_wmt_en_de_big --share-all-embeddings \
    --dropout 0.3 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --max-tokens 3584 --update-freq 16 \
    --max-update 30000 \
    --save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```

Average the last 10 checkpoints:
```bash
python scripts/average_checkpoints.py \
    --inputs $CHECKPOINT_DIR \
    --num-epoch-checkpoints 10 \
    --output $CHECKPOINT_DIR/checkpoint.avg10.pt
```

Evaluate BLEU:
```bash
# tokenized BLEU on newstest2017:
bash examples/backtranslation/tokenized_bleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU4 = 29.57, 60.9/35.4/22.9/15.5 (BP=1.000, ratio=1.014, syslen=63049, reflen=62152)
# compare to 29.46 in Table 1, which is also for tokenized BLEU

# generally it's better to report (detokenized) sacrebleu though:
bash examples/backtranslation/sacrebleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 29.0 60.6/34.7/22.4/14.9 (BP = 1.000 ratio = 1.013 hyp_len = 62099 ref_len = 61287)
```


#### Step 2. Back-translate monolingual German data

Train a reverse model (German-English) to do the back-translation:
```bash
CHECKPOINT_DIR=checkpoints_de_en_parallel
fairseq-train --fp16 \
    data-bin/wmt18_en_de \
    --source-lang de --target-lang en \
    --arch transformer_wmt_en_de_big --share-all-embeddings \
    --dropout 0.3 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --max-tokens 3584 --update-freq 16 \
    --max-update 30000 \
    --save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```

Let's evaluate the back-translation (BT) model to make sure it is well trained:
```bash
bash examples/backtranslation/sacrebleu.sh \
    wmt17 \
    de-en \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint_best.pt
# BLEU+case.mixed+lang.de-en+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 34.9 66.9/41.8/28.5/19.9 (BP = 0.983 ratio = 0.984 hyp_len = 63342 ref_len = 64399)
# compare to the best system from WMT'17 which scored 35.1: http://matrix.statmt.org/matrix/systems_list/1868
```

Next prepare the monolingual data:
```bash
# Download and prepare the monolingual data
# By default the script samples 25M monolingual sentences, which after
# deduplication should be just over 24M sentences. These are split into 25
# shards, each with 1M sentences (except for the last shard).
cd examples/backtranslation/
bash prepare-de-monolingual.sh
cd ../..

# Binarize each shard of the monolingual data
TEXT=examples/backtranslation/wmt18_de_mono
for SHARD in $(seq -f "%02g" 0 24); do \
    fairseq-preprocess \
        --only-source \
        --source-lang de --target-lang en \
        --joined-dictionary \
        --srcdict data-bin/wmt18_en_de/dict.de.txt \
        --testpref $TEXT/bpe.monolingual.dedup.${SHARD} \
        --destdir data-bin/wmt18_de_mono/shard${SHARD} \
        --workers 20; \
    cp data-bin/wmt18_en_de/dict.en.txt data-bin/wmt18_de_mono/shard${SHARD}/; \
done
```

Now we're ready to perform back-translation over the monolingual data. The
following command generates via sampling, but it's possible to use greedy
decoding (`--beam 1`), beam search (`--beam 5`),
top-k sampling (`--sampling --beam 1 --sampling-topk 10`), etc.:
```bash
mkdir backtranslation_output
for SHARD in $(seq -f "%02g" 0 24); do \
    fairseq-generate --fp16 \
        data-bin/wmt18_de_mono/shard${SHARD} \
        --path $CHECKPOINT_DIR/checkpoint_best.pt \
        --skip-invalid-size-inputs-valid-test \
        --max-tokens 4096 \
        --sampling --beam 1 \
    > backtranslation_output/sampling.shard${SHARD}.out; \
done
```

After BT, use the `extract_bt_data.py` script to re-combine the shards, extract
the back-translations and apply length ratio filters:
```bash
python examples/backtranslation/extract_bt_data.py \
    --minlen 1 --maxlen 250 --ratio 1.5 \
    --output backtranslation_output/bt_data --srclang en --tgtlang de \
    backtranslation_output/sampling.shard*.out

# Ensure lengths are the same:
# wc -l backtranslation_output/bt_data.{en,de}
#   21795614 backtranslation_output/bt_data.en
#   21795614 backtranslation_output/bt_data.de
#   43591228 total
```

Binarize the filtered BT data and combine it with the parallel data:
```bash
TEXT=backtranslation_output
fairseq-preprocess \
    --source-lang en --target-lang de \
    --joined-dictionary \
    --srcdict data-bin/wmt18_en_de/dict.en.txt \
    --trainpref $TEXT/bt_data \
    --destdir data-bin/wmt18_en_de_bt \
    --workers 20

# We want to train on the combined data, so we'll symlink the parallel + BT data
# in the wmt18_en_de_para_plus_bt directory. We link the parallel data as "train"
# and the BT data as "train1", so that fairseq will combine them automatically
# and so that we can use the `--upsample-primary` option to upsample the
# parallel data (if desired).
PARA_DATA=$(readlink -f data-bin/wmt18_en_de)
BT_DATA=$(readlink -f data-bin/wmt18_en_de_bt)
COMB_DATA=data-bin/wmt18_en_de_para_plus_bt
mkdir -p $COMB_DATA
for LANG in en de; do \
    ln -s ${PARA_DATA}/dict.$LANG.txt ${COMB_DATA}/dict.$LANG.txt; \
    for EXT in bin idx; do \
        ln -s ${PARA_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train.en-de.$LANG.$EXT; \
        ln -s ${BT_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train1.en-de.$LANG.$EXT; \
        ln -s ${PARA_DATA}/valid.en-de.$LANG.$EXT ${COMB_DATA}/valid.en-de.$LANG.$EXT; \
        ln -s ${PARA_DATA}/test.en-de.$LANG.$EXT ${COMB_DATA}/test.en-de.$LANG.$EXT; \
    done; \
done
```
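
Before training on the combined directory it's worth confirming that the symlinked layout resolves, since fairseq picks up the `train`, `train1`, `train2`, ... prefixes automatically. A minimal sanity check (a sketch, assuming the exact paths created above):
```python
import os

comb = "data-bin/wmt18_en_de_para_plus_bt"
for prefix in ("train", "train1", "valid", "test"):
    for lang in ("en", "de"):
        for ext in ("bin", "idx"):
            path = os.path.join(comb, f"{prefix}.en-de.{lang}.{ext}")
            # os.path.exists follows symlinks, so broken links are caught too
            assert os.path.exists(path), f"missing or broken: {path}"
print("combined data-bin layout looks complete")
```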


#### Step 3. Train an English-German model over the combined parallel + BT data

Finally we can train a model over the parallel + BT data:
```bash
CHECKPOINT_DIR=checkpoints_en_de_parallel_plus_bt
fairseq-train --fp16 \
    data-bin/wmt18_en_de_para_plus_bt \
    --upsample-primary 16 \
    --source-lang en --target-lang de \
    --arch transformer_wmt_en_de_big --share-all-embeddings \
    --dropout 0.3 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --max-tokens 3584 --update-freq 16 \
    --max-update 100000 \
    --save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```

Average the last 10 checkpoints:
```bash
python scripts/average_checkpoints.py \
    --inputs $CHECKPOINT_DIR \
    --num-epoch-checkpoints 10 \
    --output $CHECKPOINT_DIR/checkpoint.avg10.pt
```

Evaluate BLEU:
```bash
# tokenized BLEU on newstest2017:
bash examples/backtranslation/tokenized_bleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU4 = 32.35, 64.4/38.9/26.2/18.3 (BP=0.977, ratio=0.977, syslen=60729, reflen=62152)
# compare to 32.35 in Table 1, which is also for tokenized BLEU

# generally it's better to report (detokenized) sacrebleu:
bash examples/backtranslation/sacrebleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 31.5 64.3/38.2/25.6/17.6 (BP = 0.971 ratio = 0.971 hyp_len = 59515 ref_len = 61287)
```


## Citation
```bibtex
@inproceedings{edunov2018backtranslation,
  title = {Understanding Back-Translation at Scale},
  author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David},
  booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
  year = 2018,
}
```
fairseq/examples/backtranslation/deduplicate_lines.py ADDED @@ -0,0 +1,41 @@
#!/usr/bin/python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import fileinput
import hashlib
import sys
from multiprocessing import Pool


def get_hashes_and_lines(raw_line):
    # Hash the raw bytes so the dedup set stores short digests rather than
    # full lines, keeping memory bounded on corpora with tens of millions
    # of sentences.
    hash = hashlib.md5(raw_line).hexdigest()
    return hash, raw_line


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--workers", type=int, default=10)
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args()

    seen = set()
    with fileinput.input(args.files, mode="rb") as h:
        pool = Pool(args.workers)
        # imap_unordered does not guarantee input order, which is acceptable
        # here because the monolingual data is shuffled upstream anyway.
        results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
        for i, (hash, raw_line) in enumerate(results):
            if hash not in seen:
                seen.add(hash)
                sys.stdout.buffer.write(raw_line)
            if i % 1000000 == 0:
                print(i, file=sys.stderr, end="", flush=True)
            elif i % 100000 == 0:
                print(".", file=sys.stderr, end="", flush=True)
    print(file=sys.stderr, flush=True)


if __name__ == "__main__":
    main()
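
The script streams its inputs as raw bytes and writes only the first occurrence of each line to stdout. The same logic in miniature, without the worker pool (the sample lines are invented):
```python
import hashlib

lines = [b"ein Satz\n", b"noch ein Satz\n", b"ein Satz\n"]
seen = set()
for raw in lines:
    digest = hashlib.md5(raw).hexdigest()
    if digest not in seen:
        seen.add(digest)
        print(raw.decode("utf-8"), end="")
# prints "ein Satz" and "noch ein Satz", each exactly once
```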
    	
fairseq/examples/backtranslation/extract_bt_data.py ADDED @@ -0,0 +1,72 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import fileinput

from tqdm import tqdm


def main():
    parser = argparse.ArgumentParser(
        description=(
            "Extract back-translations from the stdout of fairseq-generate. "
            "If there are multiple hypotheses for a source, we only keep the first one."
        )
    )
    parser.add_argument("--output", required=True, help="output prefix")
    parser.add_argument(
        "--srclang", required=True, help="source language (extracted from H-* lines)"
    )
    parser.add_argument(
        "--tgtlang", required=True, help="target language (extracted from S-* lines)"
    )
    parser.add_argument("--minlen", type=int, help="min length filter")
    parser.add_argument("--maxlen", type=int, help="max length filter")
    parser.add_argument("--ratio", type=float, help="ratio filter")
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args()

    def validate(src, tgt):
        srclen = len(src.split(" ")) if src != "" else 0
        tgtlen = len(tgt.split(" ")) if tgt != "" else 0
        if (
            (args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
            or (
                args.maxlen is not None
                and (srclen > args.maxlen or tgtlen > args.maxlen)
            )
            or (
                args.ratio is not None
                and (
                    # an empty side has effectively infinite length ratio;
                    # checking it first also avoids division by zero below
                    min(srclen, tgtlen) == 0
                    or max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio
                )
            )
        ):
            return False
        return True

    def safe_index(toks, index, default):
        try:
            return toks[index]
        except IndexError:
            return default

    tgt = None  # set when an S-* line is seen; guards against stray H-* lines
    with open(args.output + "." + args.srclang, "w") as src_h, open(
        args.output + "." + args.tgtlang, "w"
    ) as tgt_h:
        for line in tqdm(fileinput.input(args.files)):
            if line.startswith("S-"):
                # S-<id>\t<text>: the original monolingual (target-side) sentence
                tgt = safe_index(line.rstrip().split("\t"), 1, "")
            elif line.startswith("H-"):
                if tgt is not None:
                    # H-<id>\t<score>\t<text>: the back-translated (source-side) text
                    src = safe_index(line.rstrip().split("\t"), 2, "")
                    if validate(src, tgt):
                        print(src, file=src_h)
                        print(tgt, file=tgt_h)
                    tgt = None


if __name__ == "__main__":
    main()
        fairseq/examples/backtranslation/prepare-de-monolingual.sh
    ADDED
    
    | 
         @@ -0,0 +1,98 @@ 
     | 
|
#!/bin/bash

SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
BPEROOT=subword-nmt/subword_nmt


BPE_CODE=wmt18_en_de/code
SUBSAMPLE_SIZE=25000000
LANG=de


OUTDIR=wmt18_${LANG}_mono
orig=orig
tmp=$OUTDIR/tmp
mkdir -p $OUTDIR $tmp


URLS=(
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2007.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2008.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2009.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2010.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2011.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.de.shuffled.gz"
    "http://www.statmt.org/wmt15/training-monolingual-news-crawl-v2/news.2014.de.shuffled.v2.gz"
    "http://data.statmt.org/wmt16/translation-task/news.2015.de.shuffled.gz"
    "http://data.statmt.org/wmt17/translation-task/news.2016.de.shuffled.gz"
    "http://data.statmt.org/wmt18/translation-task/news.2017.de.shuffled.deduped.gz"
)
FILES=(
    "news.2007.de.shuffled.gz"
    "news.2008.de.shuffled.gz"
    "news.2009.de.shuffled.gz"
    "news.2010.de.shuffled.gz"
    "news.2011.de.shuffled.gz"
    "news.2012.de.shuffled.gz"
    "news.2013.de.shuffled.gz"
    "news.2014.de.shuffled.v2.gz"
    "news.2015.de.shuffled.gz"
    "news.2016.de.shuffled.gz"
    "news.2017.de.shuffled.deduped.gz"
)


cd $orig
for ((i=0;i<${#URLS[@]};++i)); do
    file=${FILES[i]}
    if [ -f $file ]; then
        echo "$file already exists, skipping download"
    else
        url=${URLS[i]}
        wget "$url"
    fi
done
cd ..


if [ -f $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
    echo "found monolingual sample, skipping shuffle/sample/tokenize"
else
    gzip -c -d -k $(for FILE in "${FILES[@]}"; do echo $orig/$FILE; done) \
    | shuf -n $SUBSAMPLE_SIZE \
    | perl $NORM_PUNC $LANG \
    | perl $REM_NON_PRINT_CHAR \
    | perl $TOKENIZER -threads 8 -a -l $LANG \
    > $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG}
fi


if [ -f $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
    echo "found BPE monolingual sample, skipping BPE step"
else
    python $BPEROOT/apply_bpe.py -c $BPE_CODE \
        < $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} \
        > $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG}
fi


if [ -f $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} ]; then
    echo "found deduplicated monolingual sample, skipping deduplication step"
else
    python deduplicate_lines.py $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} \
    > $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG}
fi


if [ -f $OUTDIR/bpe.monolingual.dedup.00.de ]; then
    echo "found sharded data, skipping sharding step"
else
    split --lines 1000000 --numeric-suffixes \
        --additional-suffix .${LANG} \
        $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} \
        $OUTDIR/bpe.monolingual.dedup.
fi
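
The `split` step above yields shards named `bpe.monolingual.dedup.NN.de` with two-digit numeric suffixes (00 through 24 for the default 25M-sentence sample), which is exactly what the binarization loop in the README iterates over. A quick check that all shards were produced (a sketch, assuming the script's default paths):
```python
import os

outdir = "wmt18_de_mono"
expected = ["bpe.monolingual.dedup.%02d.de" % n for n in range(25)]
missing = [f for f in expected if not os.path.exists(os.path.join(outdir, f))]
print("all 25 shards present" if not missing else "missing shards: %s" % missing)
```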
    	
fairseq/examples/backtranslation/prepare-wmt18en2de.sh ADDED @@ -0,0 +1,135 @@
| 1 | 
         
            +
            #!/bin/bash
         
     | 
| 2 | 
         
            +
            # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            echo 'Cloning Moses github repository (for tokenization scripts)...'
         
     | 
| 5 | 
         
            +
            git clone https://github.com/moses-smt/mosesdecoder.git
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
         
     | 
| 8 | 
         
            +
            git clone https://github.com/rsennrich/subword-nmt.git
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            SCRIPTS=mosesdecoder/scripts
         
     | 
| 11 | 
         
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
BPEROOT=subword-nmt/subword_nmt
BPE_TOKENS=32000

URLS=(
    "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
    "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
    "http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz"
    "http://data.statmt.org/wmt18/translation-task/rapid2016.tgz"
    "http://data.statmt.org/wmt17/translation-task/dev.tgz"
    "http://statmt.org/wmt14/test-full.tgz"
)
FILES=(
    "training-parallel-europarl-v7.tgz"
    "training-parallel-commoncrawl.tgz"
    "training-parallel-nc-v13.tgz"
    "rapid2016.tgz"
    "dev.tgz"
    "test-full.tgz"
)
CORPORA=(
    "training/europarl-v7.de-en"
    "commoncrawl.de-en"
    "training-parallel-nc-v13/news-commentary-v13.de-en"
    "rapid2016.de-en"
)

if [ ! -d "$SCRIPTS" ]; then
    echo "Please set SCRIPTS variable correctly to point to Moses scripts."
    exit 1
fi

OUTDIR=wmt18_en_de

src=en
tgt=de
lang=en-de
prep=$OUTDIR
tmp=$prep/tmp
orig=orig

mkdir -p $orig $tmp $prep

cd $orig

for ((i=0;i<${#URLS[@]};++i)); do
    file=${FILES[i]}
    if [ -f $file ]; then
        echo "$file already exists, skipping download"
    else
        url=${URLS[i]}
        wget "$url"
        if [ -f $file ]; then
            echo "$url successfully downloaded."
        else
            echo "$url not successfully downloaded."
            exit 1
        fi
        if [ ${file: -4} == ".tgz" ]; then
            tar zxvf $file
        elif [ ${file: -4} == ".tar" ]; then
            tar xvf $file
        fi
    fi
done
cd ..

echo "pre-processing train data..."
for l in $src $tgt; do
    rm -f $tmp/train.tags.$lang.tok.$l
    for f in "${CORPORA[@]}"; do
        cat $orig/$f.$l | \
            perl $NORM_PUNC $l | \
            perl $REM_NON_PRINT_CHAR | \
            perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
    done
done

echo "pre-processing test data..."
for l in $src $tgt; do
    if [ "$l" == "$src" ]; then
        t="src"
    else
        t="ref"
    fi
    grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
        sed -e 's/<seg id="[0-9]*">\s*//g' | \
        sed -e 's/\s*<\/seg>\s*//g' | \
        sed -e "s/\’/\'/g" | \
    perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
    echo ""
done

echo "splitting train and valid..."
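# Hold out every 100th sentence pair as the validation set; the remaining pairs form the training set.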
for l in $src $tgt; do
    awk '{if (NR%100 == 0)  print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
    awk '{if (NR%100 != 0)  print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
done

TRAIN=$tmp/train.de-en
BPE_CODE=$prep/code
rm -f $TRAIN
for l in $src $tgt; do
    cat $tmp/train.$l >> $TRAIN
done

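# Learn a joint source+target BPE vocabulary ($BPE_TOKENS merge operations) on the concatenated training text.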
echo "learn_bpe.py on ${TRAIN}..."
python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE

for L in $src $tgt; do
    for f in train.$L valid.$L test.$L; do
        echo "apply_bpe.py to ${f}..."
        python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
    done
done

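# Drop train/valid pairs whose source/target length ratio exceeds 1.5 or whose
# length falls outside [1, 250] tokens; the test set is copied through unchanged.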
perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250

for L in $src $tgt; do
    cp $tmp/bpe.test.$L $prep/test.$L
done
    	
fairseq/examples/backtranslation/sacrebleu.sh
ADDED
@@ -0,0 +1,37 @@
#!/bin/bash

if [ $# -ne 5 ]; then
    echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
    exit 1
fi


DATASET=$1
LANGPAIR=$2
DATABIN=$3
BPECODE=$4
MODEL=$5

SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)


BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
if [ ! -e $BPEROOT ]; then
    BPEROOT=subword-nmt/subword_nmt
    if [ ! -e $BPEROOT ]; then
        echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
        git clone https://github.com/rsennrich/subword-nmt.git
    fi
fi

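# Translate the sacrebleu source side end-to-end (tokenize, apply BPE, decode,
# detokenize) and score the detokenized output with sacrebleu.
# Hypothetical example invocation (paths are illustrative only):
#   bash sacrebleu.sh wmt14/full en-de data-bin/wmt18_en_de wmt18_en_de/code checkpoint_best.pt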
sacrebleu -t $DATASET -l $LANGPAIR --echo src \
| sacremoses tokenize -a -l $SRCLANG -q \
| python $BPEROOT/apply_bpe.py -c $BPECODE \
| fairseq-interactive $DATABIN --path $MODEL \
    -s $SRCLANG -t $TGTLANG \
    --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
| grep ^H- | cut -f 3- \
| sacremoses detokenize -l $TGTLANG -q \
| sacrebleu -t $DATASET -l $LANGPAIR
    	
fairseq/examples/backtranslation/tokenized_bleu.sh
ADDED
@@ -0,0 +1,46 @@
#!/bin/bash

if [ $# -ne 5 ]; then
    echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
    exit 1
fi


DATASET=$1
LANGPAIR=$2
DATABIN=$3
BPECODE=$4
MODEL=$5

SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)


BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
if [ ! -e $BPEROOT ]; then
    BPEROOT=subword-nmt/subword_nmt
    if [ ! -e $BPEROOT ]; then
        echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
        git clone https://github.com/rsennrich/subword-nmt.git
    fi
fi

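# Unlike sacrebleu.sh, this script reports tokenized BLEU: it builds a
# normalized, tokenized reference in a temp file, then scores tokenized
# hypotheses against it with fairseq-score.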
TMP_REF=$(mktemp)

sacrebleu -t $DATASET -l $LANGPAIR --echo ref -q \
| sacremoses normalize -l $TGTLANG -q \
| sacremoses tokenize -a -l $TGTLANG -q \
> $TMP_REF

sacrebleu -t $DATASET -l $LANGPAIR --echo src -q \
| sacremoses normalize -l $SRCLANG -q \
| sacremoses tokenize -a -l $SRCLANG -q \
| python $BPEROOT/apply_bpe.py -c $BPECODE \
| fairseq-interactive $DATABIN --path $MODEL \
    -s $SRCLANG -t $TGTLANG \
    --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
| grep ^H- | cut -f 3- \
| fairseq-score --ref $TMP_REF

rm -f $TMP_REF
    	
fairseq/examples/bart/README.glue.md
ADDED
@@ -0,0 +1,99 @@
# Fine-tuning BART on GLUE tasks

### 1) Download the data from the GLUE website (https://gluebenchmark.com/tasks) using the following commands:
```bash
wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
python download_glue_data.py --data_dir glue_data --tasks all
```

### 2) Preprocess GLUE task data (same as RoBERTa):
```bash
./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
```
`glue_task_name` is one of the following:
`{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
Use `ALL` to preprocess all the GLUE tasks.

### 3) Fine-tuning on a GLUE task:
Example fine-tuning command for the `RTE` task:
```bash
TOTAL_NUM_UPDATES=2036  # 10 epochs through RTE for bsz 16
WARMUP_UPDATES=61       # 6 percent of the number of updates
LR=1e-05                # Peak LR for polynomial LR scheduler.
NUM_CLASSES=2
MAX_SENTENCES=16        # Batch size.
BART_PATH=/path/to/bart/model.pt

CUDA_VISIBLE_DEVICES=0,1 fairseq-train RTE-bin/ \
    --restore-file $BART_PATH \
    --batch-size $MAX_SENTENCES \
    --max-tokens 4400 \
    --task sentence_prediction \
    --add-prev-output-tokens \
    --layernorm-embedding \
    --share-all-embeddings \
    --share-decoder-input-output-embed \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --init-token 0 \
    --arch bart_large \
    --criterion sentence_prediction \
    --num-classes $NUM_CLASSES \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \
    --clip-norm 0.0 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
    --max-epoch 10 \
    --find-unused-parameters \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
```

For each GLUE task, use the following command-line arguments:

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
`--lr` | 5e-6 | 1e-5 | 1e-5 | 1e-5 | 5e-6 | 2e-5 | 2e-5 | 2e-5
`bsz` | 128 | 32 | 32 | 32 | 128 | 64 | 64 | 32
`--total-num-update` | 30968 | 33112 | 113272 | 1018 | 5233 | 1148 | 1334 | 1799
`--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107

For `STS-B`, additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric` (see the sketch below).
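For example, a minimal sketch of the `STS-B` variant, assuming the same setup as the `RTE` command above and taking its values from the table (only the settings that change are shown):

```bash
NUM_CLASSES=1            # STS-B is a regression task with a single output
LR=2e-05
MAX_SENTENCES=32
TOTAL_NUM_UPDATES=1799
WARMUP_UPDATES=107

# Reuse the fairseq-train command above with these variables, add
#   --regression-target --best-checkpoint-metric loss
# and drop --maximize-best-checkpoint-metric.
```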

**Note:**

a) `--total-num-update` is used by the `polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--batch-size=32/64/128` depending on the task.

b) The above command-line arguments and hyperparameters were tested on an Nvidia `V100` GPU with `32GB` of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--batch-size` (see the sketch after this note).

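A minimal sketch with hypothetical values: halving `--batch-size` while doubling `--update-freq` accumulates gradients over two steps, keeping the effective batch size at 16.

```bash
MAX_SENTENCES=8   # per-step batch size (reduced to fit in memory)
UPDATE_FREQ=2     # gradient accumulation steps: 8 * 2 = 16 effective

# Reuse the fairseq-train command above, adding:
#   --batch-size $MAX_SENTENCES --update-freq $UPDATE_FREQ
```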
### Inference on GLUE task
After training the model as described in the previous step, you can perform inference with the checkpoints in the `checkpoints/` directory using the following Python code snippet:

```python
from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='RTE-bin'
)

label_fn = lambda label: bart.task.label_dictionary.string(
    [label + bart.task.label_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
bart.cuda()
bart.eval()
with open('glue_data/RTE/dev.tsv') as fin:
    fin.readline()  # skip the header line
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        sent1, sent2, target = tokens[1], tokens[2], tokens[3]
        tokens = bart.encode(sent1, sent2)
        prediction = bart.predict('sentence_classification_head', tokens).argmax().item()
        prediction_label = label_fn(prediction)
        ncorrect += int(prediction_label == target)
        nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
```
    	
fairseq/examples/bart/README.md
ADDED
@@ -0,0 +1,228 @@
# BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension

[https://arxiv.org/abs/1910.13461](https://arxiv.org/abs/1910.13461)

## Introduction

BART is a sequence-to-sequence model trained with denoising as the pretraining objective. We show that this pretraining objective is more generic: we can match [RoBERTa](../roberta) results on SQuAD and GLUE and achieve state-of-the-art results on summarization (XSum, CNN/Daily Mail), long-form generative question answering (ELI5), and dialog response generation (ConvAI2). See the associated paper for more details.

## Pre-trained models

Model | Description | # params | Download
---|---|---|---
`bart.base` | BART model with 6 encoder and decoder layers | 140M | [bart.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz)
`bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz)
`bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz)
`bart.large.cnn` | `bart.large` finetuned on `CNN-DM` | 400M | [bart.large.cnn.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz)
`bart.large.xsum` | `bart.large` finetuned on `Xsum` | 400M | [bart.large.xsum.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz)

## Results

**[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)**
_(dev set, single model, single-task finetuning)_

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`bart.large` | 89.9 | 94.9 | 92.5 | 87.0 | 96.6 | 90.4 | 62.8 | 91.2

**[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)**
_(dev set, no additional data used)_

Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
---|---|---
`roberta.large` | 88.9/94.6 | 86.5/89.4
`bart.large` | 88.8/94.6 | 86.1/89.2

**[CNN/Daily Mail](http://nlpprogress.com/english/summarization.html)**
_(test set, no additional data used)_

Model | R1 | R2 | RL
---|---|---|---
`BERTSUMEXTABS` | 42.13 | 19.60 | 39.18
`bart.large` | 44.16 | 21.28 | 40.90

## Example usage

##### Load BART from torch.hub (PyTorch >= 1.1):
```python
import torch
bart = torch.hub.load('pytorch/fairseq', 'bart.large')
bart.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Load BART (for PyTorch 1.0 or custom models):
```bash
# Download bart.large model
wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz
tar -xzvf bart.large.tar.gz
```
```python
# Load the model in fairseq
from fairseq.models.bart import BARTModel
bart = BARTModel.from_pretrained('/path/to/bart.large', checkpoint_file='model.pt')
bart.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Apply Byte-Pair Encoding (BPE) to input text:
```python
tokens = bart.encode('Hello world!')
assert tokens.tolist() == [0, 31414, 232, 328, 2]
bart.decode(tokens)  # 'Hello world!'
```

##### Extract features from BART:
```python
# Extract the last layer's features
last_layer_features = bart.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 5, 1024])

# Extract all layers' features from the decoder (layer 0 is the embedding layer)
all_layers = bart.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 13
assert torch.all(all_layers[-1] == last_layer_features)
```

##### Use BART for sentence-pair classification tasks:
```python
# Download BART already finetuned for MNLI
bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
bart.eval()  # disable dropout for evaluation

# Encode a pair of sentences and make a prediction
tokens = bart.encode('BART is a seq2seq model.', 'BART is not sequence to sequence.')
bart.predict('mnli', tokens).argmax()  # 0: contradiction

# Encode another pair of sentences
tokens = bart.encode('BART is denoising autoencoder.', 'BART is version of autoencoder.')
bart.predict('mnli', tokens).argmax()  # 2: entailment
```

##### Register a new (randomly initialized) classification head:
```python
bart.register_classification_head('new_task', num_classes=3)
logprobs = bart.predict('new_task', tokens)
```

##### Batched prediction:
```python
import torch
from fairseq.data.data_utils import collate_tokens

bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
bart.eval()

batch_of_pairs = [
    ['BART is a seq2seq model.', 'BART is not sequence to sequence.'],
    ['BART is denoising autoencoder.', 'BART is version of autoencoder.'],
]

batch = collate_tokens(
    [bart.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
)

logprobs = bart.predict('mnli', batch)
print(logprobs.argmax(dim=1))
# tensor([0, 2])
```

##### Using the GPU:
```python
bart.cuda()
bart.predict('new_task', tokens)
```

#### Filling masks:

BART can be used to fill multiple `<mask>` tokens in the input.
```python
bart = torch.hub.load('pytorch/fairseq', 'bart.base')
bart.eval()
bart.fill_mask(['The cat <mask> on the <mask>.'], topk=3, beam=10)
# [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))]]
```

Note that by default we enforce the output length to match the input length.
This can be disabled by setting ``match_source_len=False``:
```python
bart.fill_mask(['The cat <mask> on the <mask>.'], topk=3, beam=10, match_source_len=False)
# [[('The cat was on the ground.', tensor(-0.6185)), ('The cat was asleep on the couch.', tensor(-0.6276)), ('The cat was on the floor.', tensor(-0.6800))]]
```

Example code to fill masks for a batch of sentences using the GPU:
```python
bart.cuda()
bart.fill_mask(['The cat <mask> on the <mask>.', 'The dog <mask> on the <mask>.'], topk=3, beam=10)
# [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))], [('The dog was on the ground.', tensor(-0.6190)), ('The dog lay on the ground.', tensor(-0.6711)),
# ('The dog was asleep on the couch', tensor(-0.6796))]]
```

#### Evaluating the `bart.large.mnli` model:

Example Python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
```python
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ncorrect, nsamples = 0, 0
bart.cuda()
bart.eval()
with open('glue_data/MNLI/dev_matched.tsv') as fin:
    fin.readline()  # skip the header line
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
        tokens = bart.encode(sent1, sent2)
        prediction = bart.predict('mnli', tokens).argmax().item()
        prediction_label = label_map[prediction]
        ncorrect += int(prediction_label == target)
        nsamples += 1
        print('| Accuracy: ', float(ncorrect)/float(nsamples))  # running accuracy
# Expected output: 0.9010
```

#### Evaluating the `bart.large.cnn` model:
- Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download and process the data into files such that `test.source` and `test.target` have one line for each non-tokenized sample.
- For simpler preprocessing, you can also `wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz`, although there is no guarantee of identical scores.
- `huggingface/transformers` has a simpler interface that supports [single-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_eval.py) and [multi-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_distributed_eval.py) beam search.
    In `huggingface/transformers`, the BART models' paths are `facebook/bart-large-cnn` and `facebook/bart-large-xsum`.

In `fairseq`, summaries can be generated using:

```bash
cp data-bin/cnn_dm/dict.source.txt checkpoints/
python examples/bart/summarize.py \
  --model-dir pytorch/fairseq \
  --model-file bart.large.cnn \
  --src cnn_dm/test.source \
  --out cnn_dm/test.hypo
```

For calculating ROUGE, install `files2rouge` from [here](https://github.com/pltrdy/files2rouge).

```bash
export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar

# Tokenize hypothesis and target files.
cat test.hypo | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.tokenized
cat test.target | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.target
files2rouge test.hypo.tokenized test.hypo.target
# Expected output: (ROUGE-2 Average_F: 0.21238)
```


## Finetuning

- [Finetuning on GLUE](README.glue.md)
- [Finetuning on CNN-DM](README.summarization.md)

## Citation

```bibtex
@article{lewis2019bart,
    title = {BART: Denoising Sequence-to-Sequence Pre-training for Natural
Language Generation, Translation, and Comprehension},
    author = {Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and
              Abdelrahman Mohamed and Omer Levy and Veselin Stoyanov
              and Luke Zettlemoyer },
    journal={arXiv preprint arXiv:1910.13461},
    year = {2019},
}
```
    	
fairseq/examples/bart/README.summarization.md
ADDED
@@ -0,0 +1,102 @@
# Fine-tuning BART on the CNN-DailyMail summarization task

### 1) Download the CNN and Daily Mail data and preprocess it into data files with non-tokenized cased samples.

Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download the original CNN and Daily Mail datasets. To preprocess the data, refer to the pointers in [this issue](https://github.com/pytorch/fairseq/issues/1391) or check out the code [here](https://github.com/artmatsak/cnn-dailymail).

Follow the instructions [here](https://github.com/EdinburghNLP/XSum) to download the original Extreme Summarization (XSum) dataset, or check out the code [here](https://github.com/EdinburghNLP/XSum/tree/master/XSum-Dataset). Please keep the dataset raw: make sure it has not been tokenized or BPE-encoded.

### 2) BPE preprocess:

```bash
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'

TASK=cnn_dm
for SPLIT in train val
do
  for LANG in source target
  do
    python -m examples.roberta.multiprocessing_bpe_encoder \
    --encoder-json encoder.json \
    --vocab-bpe vocab.bpe \
    --inputs "$TASK/$SPLIT.$LANG" \
    --outputs "$TASK/$SPLIT.bpe.$LANG" \
    --workers 60 \
    --keep-empty;
  done
done
```

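To sanity-check this step, you can round-trip a single line through the same GPT-2 BPE encoder the script builds on. A minimal sketch, assuming fairseq is installed and `encoder.json`/`vocab.bpe` were downloaded as above (`get_encoder` lives in `fairseq.data.encoders.gpt2_bpe_utils`):

```python
# Sketch: verify the GPT-2 BPE round trip on one sample line.
from fairseq.data.encoders.gpt2_bpe_utils import get_encoder

bpe = get_encoder("encoder.json", "vocab.bpe")

line = "Police arrested five protesters on Thursday."
ids = bpe.encode(line)            # list of GPT-2 BPE token ids
print(" ".join(map(str, ids)))    # the format written to the *.bpe.* files
assert bpe.decode(ids) == line    # the encoding should be lossless
```
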
### 3) Binarize dataset:
```bash
fairseq-preprocess \
  --source-lang "source" \
  --target-lang "target" \
  --trainpref "${TASK}/train.bpe" \
  --validpref "${TASK}/val.bpe" \
  --destdir "${TASK}-bin/" \
  --workers 60 \
  --srcdict dict.txt \
  --tgtdict dict.txt;
```

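After binarization, `${TASK}-bin/` should contain the index/data files and the copied dictionaries, roughly as follows (layout recalled from fairseq's conventions, so verify locally):

```
cnn_dm-bin/
├── dict.source.txt
├── dict.target.txt
├── train.source-target.source.{bin,idx}
├── train.source-target.target.{bin,idx}
└── valid.source-target.{source,target}.{bin,idx}
```
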
### 4) Fine-tuning on CNN-DM summarization task:
An example of fine-tuning on CNN-DM:
```bash
TOTAL_NUM_UPDATES=20000
WARMUP_UPDATES=500
LR=3e-05
MAX_TOKENS=2048
UPDATE_FREQ=4
BART_PATH=/path/to/bart/model.pt

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train cnn_dm-bin \
    --restore-file $BART_PATH \
    --max-tokens $MAX_TOKENS \
    --task translation \
    --source-lang source --target-lang target \
    --truncate-source \
    --layernorm-embedding \
    --share-all-embeddings \
    --share-decoder-input-output-embed \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --arch bart_large \
    --criterion label_smoothed_cross_entropy \
    --label-smoothing 0.1 \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
    --clip-norm 0.1 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --update-freq $UPDATE_FREQ \
    --skip-invalid-size-inputs-valid-test \
    --find-unused-parameters;
```
The above is expected to run on `1` node with `8` 32GB V100 GPUs.
Expected training time is about `5 hours`. Training time can be reduced with distributed training on `4` nodes and `--update-freq 1`, which keeps the effective batch size unchanged (8 GPUs x update-freq 4 = 32 GPUs x update-freq 1).

Use `TOTAL_NUM_UPDATES=15000 UPDATE_FREQ=2` for the XSum task.

### Inference for CNN-DM test data using the above trained checkpoint
After training the model as described in the previous step, you can perform inference with checkpoints in the `checkpoints/` directory using `summarize.py` (the `cp` below makes the source dictionary visible to the model loader), for example:

```bash
cp data-bin/cnn_dm/dict.source.txt  checkpoints/
python examples/bart/summarize.py \
  --model-dir checkpoints \
  --model-file checkpoint_best.pt \
  --src cnn_dm/test.source \
  --out cnn_dm/test.hypo
```
For XSum, which uses `beam=6, lenpen=1.0, max_len_b=60, min_len=10`:
```bash
cp data-bin/cnn_dm/dict.source.txt  checkpoints/
python examples/bart/summarize.py \
  --model-dir checkpoints \
  --model-file checkpoint_best.pt \
  --src cnn_dm/test.source \
  --out cnn_dm/test.hypo \
  --xsum-kwargs
```
fairseq/examples/bart/summarize.py
ADDED
@@ -0,0 +1,100 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch
from fairseq.models.bart import BARTModel

# Default generation hyperparameters for the XSum and CNN-DailyMail tasks.
XSUM_KWARGS = dict(beam=6, lenpen=1.0, max_len_b=60, min_len=10, no_repeat_ngram_size=3)
CNN_KWARGS = dict(beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)


@torch.no_grad()
def generate(bart, infile, outfile="bart_hypo.txt", bsz=32, n_obs=None, **eval_kwargs):
    count = 1

    with open(infile) as source, open(outfile, "w") as fout:
        sline = source.readline().strip()
        slines = [sline]
        for sline in source:
            if n_obs is not None and count > n_obs:
                break
            if count % bsz == 0:
                # Batch is full: generate hypotheses and write them out.
                hypotheses_batch = bart.sample(slines, **eval_kwargs)
                for hypothesis in hypotheses_batch:
                    fout.write(hypothesis + "\n")
                    fout.flush()
                slines = []

            slines.append(sline.strip())
            count += 1

        if slines != []:
            # Flush the final (possibly partial) batch.
            hypotheses_batch = bart.sample(slines, **eval_kwargs)
            for hypothesis in hypotheses_batch:
                fout.write(hypothesis + "\n")
                fout.flush()


def main():
    """
    Usage::

         python examples/bart/summarize.py \
            --model-dir $HOME/bart.large.cnn \
            --model-file model.pt \
            --src $HOME/data-bin/cnn_dm/test.source
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-dir",
        required=True,
        type=str,
        help="path containing model file and src_dict.txt",
    )
    parser.add_argument(
        "--model-file",
        default="checkpoint_best.pt",
        help="where in model_dir are weights saved",
    )
    parser.add_argument(
        "--src", default="test.source", help="text to summarize", type=str
    )
    parser.add_argument(
        "--out", default="test.hypo", help="where to save summaries", type=str
    )
    parser.add_argument("--bsz", default=32, help="batch size", type=int)
    parser.add_argument(
        "--n", default=None, help="how many examples to summarize", type=int
    )
    parser.add_argument(
        "--xsum-kwargs",
        action="store_true",
        default=False,
        help="if true use XSUM_KWARGS else CNN_KWARGS",
    )
    args = parser.parse_args()
    eval_kwargs = XSUM_KWARGS if args.xsum_kwargs else CNN_KWARGS
    if args.model_dir == "pytorch/fairseq":
        # Load a released model from the PyTorch hub instead of a local checkpoint.
        bart = torch.hub.load("pytorch/fairseq", args.model_file)
    else:
        bart = BARTModel.from_pretrained(
            args.model_dir,
            checkpoint_file=args.model_file,
            data_name_or_path=args.model_dir,
        )
    bart = bart.eval()
    if torch.cuda.is_available():
        bart = bart.cuda().half()
    generate(
        bart, args.src, bsz=args.bsz, n_obs=args.n, outfile=args.out, **eval_kwargs
    )


if __name__ == "__main__":
    main()
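The same entry points can also be driven programmatically. A minimal sketch (not part of the file above), assuming you run from the fairseq root and have network access for `torch.hub`; file paths are illustrative:

```python
# Sketch: summarize the first 8 test articles with the released bart.large.cnn
# model from torch.hub, reusing generate() from the script above.
import torch
from examples.bart.summarize import CNN_KWARGS, generate

bart = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
if torch.cuda.is_available():
    bart = bart.cuda().half()
generate(bart, "cnn_dm/test.source", outfile="cnn_dm/test.hypo", n_obs=8, **CNN_KWARGS)
```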
    	
fairseq/examples/byte_level_bpe/README.md
ADDED
@@ -0,0 +1,88 @@
# Neural Machine Translation with Byte-Level Subwords

https://arxiv.org/abs/1909.03341

We provide an implementation of byte-level byte-pair encoding (BBPE), taking IWSLT 2017 Fr-En translation as an
example.

## Data
Get the data and generate the fairseq binary dataset:
```bash
bash ./get_data.sh
```

## Model Training
Train a Transformer model with Bi-GRU embedding contextualization (implemented in `gru_transformer.py`). First pick a vocabulary setting:
```bash
# VOCAB=bytes
# VOCAB=chars
VOCAB=bbpe2048
# VOCAB=bpe2048
# VOCAB=bbpe4096
# VOCAB=bpe4096
# VOCAB=bpe16384
```
```bash
fairseq-train "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
    --arch gru_transformer --encoder-layers 2 --decoder-layers 2 --dropout 0.3 --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --log-format 'simple' --log-interval 100 --save-dir "checkpoints/${VOCAB}" \
    --batch-size 100 --max-update 100000 --update-freq 2
```

## Generation
`fairseq-generate` requires the bytes (BBPE) decoder to convert the byte-level representation back to characters:
```bash
# BPE="--bpe bytes"
# BPE="--bpe characters"
BPE="--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe2048.model"
# BPE="--bpe sentencepiece --sentencepiece-model data/spm_bpe2048.model"
# BPE="--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe4096.model"
# BPE="--bpe sentencepiece --sentencepiece-model data/spm_bpe4096.model"
# BPE="--bpe sentencepiece --sentencepiece-model data/spm_bpe16384.model"
```

```bash
fairseq-generate "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
    --source-lang fr --gen-subset test --sacrebleu --path "checkpoints/${VOCAB}/checkpoint_last.pt" \
    --tokenizer moses --moses-target-lang en ${BPE}
```
When using `fairseq-interactive`, the bytes (BBPE) encoder/decoder is required to tokenize the input data and detokenize the model predictions:
```bash
fairseq-interactive "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
    --path "checkpoints/${VOCAB}/checkpoint_last.pt" --input data/test.fr --tokenizer moses --moses-source-lang fr \
    --moses-target-lang en ${BPE} --buffer-size 1000 --max-tokens 10000
```

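To see what the byte-level representation looks like before and after decoding, you can round-trip a sentence through fairseq's byte utilities. A minimal sketch; `byte_encode`/`byte_decode` live in `fairseq.data.encoders.byte_utils` (the same helper that `get_bitext.py` below imports):

```python
# Sketch: byte-level round trip. Non-ASCII characters are expanded into
# printable byte symbols on encode and collapsed back on decode.
from fairseq.data.encoders.byte_utils import byte_encode, byte_decode

sent = "Où est la gare ?"
encoded = byte_encode(sent)
print(encoded)  # accented characters appear as multi-symbol byte sequences
assert byte_decode(encoded) == sent
```
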
## Results
| Vocabulary    | Model  | BLEU |
|:-------------:|:-------------:|:-------------:|
| Joint BPE 16k ([Kudo, 2018](https://arxiv.org/abs/1804.10959)) | 512d LSTM 2+2 | 33.81 |
| Joint BPE 16k | Transformer base 2+2 (w/ GRU) | 36.64 (36.72) |
| Joint BPE 4k | Transformer base 2+2 (w/ GRU) | 35.49 (36.10) |
| Joint BBPE 4k | Transformer base 2+2 (w/ GRU) | 35.61 (35.82) |
| Joint BPE 2k | Transformer base 2+2 (w/ GRU) | 34.87 (36.13) |
| Joint BBPE 2k | Transformer base 2+2 (w/ GRU) | 34.98 (35.43) |
| Characters | Transformer base 2+2 (w/ GRU) | 31.78 (33.30) |
| Bytes | Transformer base 2+2 (w/ GRU) | 31.57 (33.62) |

## Citation
```
@misc{wang2019neural,
    title={Neural Machine Translation with Byte-Level Subwords},
    author={Changhan Wang and Kyunghyun Cho and Jiatao Gu},
    year={2019},
    eprint={1909.03341},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```

## Contact
Changhan Wang ([[email protected]](mailto:[email protected])),
Kyunghyun Cho ([[email protected]](mailto:[email protected])),
Jiatao Gu ([[email protected]](mailto:[email protected]))
    	
fairseq/examples/byte_level_bpe/get_bitext.py
ADDED
@@ -0,0 +1,254 @@
| 1 | 
         
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # This source code is licensed under the MIT license found in the
         
     | 
| 4 | 
         
            +
            # LICENSE file in the root directory of this source tree.
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            import argparse
         
     | 
| 8 | 
         
            +
            import os
         
     | 
| 9 | 
         
            +
            import os.path as op
         
     | 
| 10 | 
         
            +
            from collections import namedtuple
         
     | 
| 11 | 
         
            +
            from multiprocessing import cpu_count
         
     | 
| 12 | 
         
            +
            from typing import List, Optional
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
            import sentencepiece as sp
         
     | 
| 15 | 
         
            +
            from fairseq.data.encoders.byte_bpe import ByteBPE
         
     | 
| 16 | 
         
            +
            from fairseq.data.encoders.byte_utils import byte_encode
         
     | 
| 17 | 
         
            +
            from fairseq.data.encoders.bytes import Bytes
         
     | 
| 18 | 
         
            +
            from fairseq.data.encoders.characters import Characters
         
     | 
| 19 | 
         
            +
            from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
         
     | 
| 20 | 
         
            +
            from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
            SPLITS = ["train", "valid", "test"]
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
            def _convert_xml(in_path: str, out_path: str):
         
     | 
| 27 | 
         
            +
                with open(in_path) as f, open(out_path, "w") as f_o:
         
     | 
| 28 | 
         
            +
                    for s in f:
         
     | 
| 29 | 
         
            +
                        ss = s.strip()
         
     | 
| 30 | 
         
            +
                        if not ss.startswith("<seg"):
         
     | 
| 31 | 
         
            +
                            continue
         
     | 
| 32 | 
         
            +
                        ss = ss.replace("</seg>", "").split('">')
         
     | 
| 33 | 
         
            +
                        assert len(ss) == 2
         
     | 
| 34 | 
         
            +
                        f_o.write(ss[1].strip() + "\n")
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
            def _convert_train(in_path: str, out_path: str):
         
     | 
| 38 | 
         
            +
                with open(in_path) as f, open(out_path, "w") as f_o:
         
     | 
| 39 | 
         
            +
                    for s in f:
         
     | 
| 40 | 
         
            +
                        ss = s.strip()
         
     | 
| 41 | 
         
            +
                        if ss.startswith("<"):
         
     | 
| 42 | 
         
            +
                            continue
         
     | 
| 43 | 
         
            +
                        f_o.write(ss.strip() + "\n")
         
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
            def _get_bytes(in_path: str, out_path: str):
         
     | 
| 47 | 
         
            +
                with open(in_path) as f, open(out_path, "w") as f_o:
         
     | 
| 48 | 
         
            +
                    for s in f:
         
     | 
| 49 | 
         
            +
                        f_o.write(Bytes.encode(s.strip()) + "\n")
         
     | 
| 50 | 
         
            +
             
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
            def _get_chars(in_path: str, out_path: str):
         
     | 
| 53 | 
         
            +
                with open(in_path) as f, open(out_path, "w") as f_o:
         
     | 
| 54 | 
         
            +
                    for s in f:
         
     | 
| 55 | 
         
            +
                        f_o.write(Characters.encode(s.strip()) + "\n")
         
     | 
| 56 | 
         
            +
             
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
            def pretokenize(in_path: str, out_path: str, src: str, tgt: str):
         
     | 
| 59 | 
         
            +
                Args = namedtuple(
         
     | 
| 60 | 
         
            +
                    "Args",
         
     | 
| 61 | 
         
            +
                    [
         
     | 
| 62 | 
         
            +
                        "moses_source_lang",
         
     | 
| 63 | 
         
            +
                        "moses_target_lang",
         
     | 
| 64 | 
         
            +
                        "moses_no_dash_splits",
         
     | 
| 65 | 
         
            +
                        "moses_no_escape",
         
     | 
| 66 | 
         
            +
                    ],
         
     | 
| 67 | 
         
            +
                )
         
     | 
| 68 | 
         
            +
                args = Args(
         
     | 
| 69 | 
         
            +
                    moses_source_lang=src,
         
     | 
| 70 | 
         
            +
                    moses_target_lang=tgt,
         
     | 
| 71 | 
         
            +
                    moses_no_dash_splits=False,
         
     | 
| 72 | 
         
            +
                    moses_no_escape=False,
         
     | 
| 73 | 
         
            +
                )
         
     | 
| 74 | 
         
            +
                pretokenizer = MosesTokenizer(args)
         
     | 
| 75 | 
         
            +
                with open(in_path) as f, open(out_path, "w") as f_o:
         
     | 
| 76 | 
         
            +
                    for s in f:
         
     | 
| 77 | 
         
            +
                        f_o.write(pretokenizer.encode(s.strip()) + "\n")
         
     | 
| 78 | 
         
            +
             
     | 
| 79 | 
         
            +
             
     | 
| 80 | 
         
            +
            def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str):
         
     | 
| 81 | 
         
            +
                with open(out_path, "w") as f_o:
         
     | 
| 82 | 
         
            +
                    for lang in [src, tgt]:
         
     | 
| 83 | 
         
            +
                        with open(f"{in_path_prefix}.{lang}") as f:
         
     | 
| 84 | 
         
            +
                            for s in f:
         
     | 
| 85 | 
         
            +
                                f_o.write(byte_encode(s.strip()) + "\n")
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
             
     | 
| 88 | 
         
            +
            def _get_bpe(in_path: str, model_prefix: str, vocab_size: int):
         
     | 
| 89 | 
         
            +
                arguments = [
         
     | 
| 90 | 
         
            +
                    f"--input={in_path}",
         
     | 
| 91 | 
         
            +
                    f"--model_prefix={model_prefix}",
         
     | 
| 92 | 
         
            +
                    f"--model_type=bpe",
         
     | 
| 93 | 
         
            +
                    f"--vocab_size={vocab_size}",
         
     | 
| 94 | 
         
            +
                    "--character_coverage=1.0",
         
     | 
| 95 | 
         
            +
                    "--normalization_rule_name=identity",
         
     | 
| 96 | 
         
            +
                    f"--num_threads={cpu_count()}",
         
     | 
| 97 | 
         
            +
                ]
         
     | 
| 98 | 
         
            +
                sp.SentencePieceTrainer.Train(" ".join(arguments))
         
     | 
| 99 | 
         
            +
             
     | 
| 100 | 
         
            +
             
     | 
| 101 | 
         
            +
            def _apply_bbpe(model_path: str, in_path: str, out_path: str):
         
     | 
| 102 | 
         
            +
                Args = namedtuple("Args", ["sentencepiece_model_path"])
         
     | 
| 103 | 
         
            +
                args = Args(sentencepiece_model_path=model_path)
         
     | 
| 104 | 
         
            +
                tokenizer = ByteBPE(args)
         
     | 
| 105 | 
         
            +
                with open(in_path) as f, open(out_path, "w") as f_o:
         
     | 
| 106 | 
         
            +
                    for s in f:
         
     | 
| 107 | 
         
            +
                        f_o.write(tokenizer.encode(s.strip()) + "\n")
         
     | 
| 108 | 
         
            +
             
     | 
| 109 | 
         
            +
             
     | 
| 110 | 
         
            +
            def _apply_bpe(model_path: str, in_path: str, out_path: str):
         
     | 
| 111 | 
         
            +
                Args = namedtuple("Args", ["sentencepiece_model"])
         
     | 
| 112 | 
         
            +
                args = Args(sentencepiece_model=model_path)
         
     | 
| 113 | 
         
            +
                tokenizer = SentencepieceBPE(args)
         
     | 
| 114 | 
         
            +
                with open(in_path) as f, open(out_path, "w") as f_o:
         
     | 
| 115 | 
         
            +
                    for s in f:
         
     | 
| 116 | 
         
            +
                        f_o.write(tokenizer.encode(s.strip()) + "\n")
         
     | 
| 117 | 
         
            +
             
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
            def _concat_files(in_paths: List[str], out_path: str):
         
     | 
| 120 | 
         
            +
                with open(out_path, "w") as f_o:
         
     | 
| 121 | 
         
            +
                    for p in in_paths:
         
     | 
| 122 | 
         
            +
                        with open(p) as f:
         
     | 
| 123 | 
         
            +
                            for r in f:
         
     | 
| 124 | 
         
            +
                                f_o.write(r)
         
     | 
| 125 | 
         
            +
             
     | 
| 126 | 
         
            +
             
     | 
| 127 | 
         
            +
            def preprocess_iwslt17(
         
     | 
| 128 | 
         
            +
                root: str,
         
     | 
| 129 | 
         
            +
                src: str,
         
     | 
| 130 | 
         
            +
                tgt: str,
         
     | 
| 131 | 
         
            +
                bpe_size: Optional[int],
         
     | 
| 132 | 
         
            +
                need_chars: bool,
         
     | 
| 133 | 
         
            +
                bbpe_size: Optional[int],
         
     | 
| 134 | 
         
            +
                need_bytes: bool,
         
     | 
| 135 | 
         
            +
            ):
         
     | 
| 136 | 
         
            +
                # extract bitext
         
     | 
| 137 | 
         
            +
                in_root = op.join(root, f"{src}-{tgt}")
         
     | 
| 138 | 
         
            +
                for lang in [src, tgt]:
         
     | 
| 139 | 
         
            +
                    _convert_train(
         
     | 
| 140 | 
         
            +
                        op.join(in_root, f"train.tags.{src}-{tgt}.{lang}"),
         
     | 
| 141 | 
         
            +
                        op.join(root, f"train.{lang}"),
         
     | 
| 142 | 
         
            +
                    )
         
     | 
| 143 | 
         
            +
                    _convert_xml(
         
     | 
| 144 | 
         
            +
                        op.join(in_root, f"IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml"),
         
     | 
| 145 | 
         
            +
                        op.join(root, f"valid.{lang}"),
         
     | 
| 146 | 
         
            +
                    )
         
     | 
| 147 | 
         
            +
                    _convert_xml(
         
     | 
| 148 | 
         
            +
                        op.join(in_root, f"IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml"),
         
     | 
| 149 | 
         
            +
                        op.join(root, f"test.{lang}"),
         
     | 
| 150 | 
         
            +
                    )
         
     | 
| 151 | 
         
            +
                # pre-tokenize
         
     | 
| 152 | 
         
            +
                for lang in [src, tgt]:
         
     | 
| 153 | 
         
            +
                    for split in SPLITS:
         
     | 
| 154 | 
         
            +
                        pretokenize(
         
     | 
| 155 | 
         
            +
                            op.join(root, f"{split}.{lang}"),
         
     | 
| 156 | 
         
            +
                            op.join(root, f"{split}.moses.{lang}"),
         
     | 
| 157 | 
         
            +
                            src,
         
     | 
| 158 | 
         
            +
                            tgt,
         
     | 
| 159 | 
         
            +
                        )
         
     | 
| 160 | 
         
            +
                # tokenize with BPE vocabulary
         
     | 
| 161 | 
         
            +
                if bpe_size is not None:
         
     | 
| 162 | 
         
            +
                    # learn vocabulary
         
     | 
| 163 | 
         
            +
                    concated_train_path = op.join(root, "train.all")
         
     | 
| 164 | 
         
            +
                    _concat_files(
         
     | 
| 165 | 
         
            +
                        [op.join(root, "train.moses.fr"), op.join(root, "train.moses.en")],
         
     | 
| 166 | 
         
            +
                        concated_train_path,
         
     | 
| 167 | 
         
            +
                    )
         
     | 
| 168 | 
         
            +
                    bpe_model_prefix = op.join(root, f"spm_bpe{bpe_size}")
         
     | 
| 169 | 
         
            +
                    _get_bpe(concated_train_path, bpe_model_prefix, bpe_size)
         
     | 
| 170 | 
         
            +
                    os.remove(concated_train_path)
         
     | 
| 171 | 
         
            +
                    # apply
         
     | 
| 172 | 
         
            +
                    for lang in [src, tgt]:
         
     | 
| 173 | 
         
            +
                        for split in SPLITS:
         
     | 
| 174 | 
         
            +
                            _apply_bpe(
         
     | 
| 175 | 
         
            +
                                bpe_model_prefix + ".model",
         
     | 
| 176 | 
         
            +
                                op.join(root, f"{split}.moses.{lang}"),
         
     | 
| 177 | 
         
            +
                                op.join(root, f"{split}.moses.bpe{bpe_size}.{lang}"),
         
     | 
| 178 | 
         
            +
                            )
         
     | 
| 179 | 
         
            +
                # tokenize with bytes vocabulary
         
     | 
| 180 | 
         
            +
                if need_bytes:
         
     | 
| 181 | 
         
            +
                    for lang in [src, tgt]:
         
     | 
| 182 | 
         
            +
                        for split in SPLITS:
         
     | 
| 183 | 
         
            +
                            _get_bytes(
         
     | 
| 184 | 
         
            +
                                op.join(root, f"{split}.moses.{lang}"),
         
     | 
| 185 | 
         
            +
                                op.join(root, f"{split}.moses.bytes.{lang}"),
         
     | 
| 186 | 
         
            +
                            )
         
     | 
| 187 | 
         
            +
                # tokenize with characters vocabulary
         
    if need_chars:
        for lang in [src, tgt]:
            for split in SPLITS:
                _get_chars(
                    op.join(root, f"{split}.moses.{lang}"),
                    op.join(root, f"{split}.moses.chars.{lang}"),
                )
    # tokenize with byte-level BPE vocabulary
    if bbpe_size is not None:
        # learn vocabulary
        bchar_path = op.join(root, "train.bchar")
        _convert_to_bchar(op.join(root, "train.moses"), src, tgt, bchar_path)
        bbpe_model_prefix = op.join(root, f"spm_bbpe{bbpe_size}")
        _get_bpe(bchar_path, bbpe_model_prefix, bbpe_size)
        os.remove(bchar_path)
        # apply
        for lang in [src, tgt]:
            for split in SPLITS:
                _apply_bbpe(
                    bbpe_model_prefix + ".model",
                    op.join(root, f"{split}.moses.{lang}"),
                    op.join(root, f"{split}.moses.bbpe{bbpe_size}.{lang}"),
                )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="data")
    parser.add_argument(
        "--bpe-vocab",
        default=None,
        type=int,
        help="Generate tokenized bitext with BPE of size K. "
        "Default to None (disabled).",
    )
    parser.add_argument(
        "--bbpe-vocab",
        default=None,
        type=int,
        help="Generate tokenized bitext with BBPE of size K. "
        "Default to None (disabled).",
    )
    parser.add_argument(
        "--byte-vocab",
        action="store_true",
        help="Generate tokenized bitext with bytes vocabulary",
    )
    parser.add_argument(
        "--char-vocab",
        action="store_true",
        help="Generate tokenized bitext with chars vocabulary",
    )
    args = parser.parse_args()

    preprocess_iwslt17(
        args.root,
        "fr",
        "en",
        args.bpe_vocab,
        args.char_vocab,
        args.bbpe_vocab,
        args.byte_vocab,
    )


if __name__ == "__main__":
    main()
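
The BBPE branch above learns its vocabulary on a temporary byte-char file and deletes it afterwards. As a rough illustration of that vocabulary-learning step, here is a minimal sketch using the sentencepiece trainer directly (that `_get_bpe` wraps the trainer with exactly these flags is an assumption; the paths are the defaults this script produces):

```python
# Minimal sketch of the vocabulary-learning step, assuming _get_bpe wraps
# the sentencepiece BPE trainer roughly like this (flags are an assumption).
import sentencepiece as spm

spm.SentencePieceTrainer.Train(
    "--input=data/train.bchar --model_prefix=data/spm_bbpe2048 "
    "--vocab_size=2048 --model_type=bpe"
)
```
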
    	
        fairseq/examples/byte_level_bpe/get_data.sh
    ADDED
    
@@ -0,0 +1,47 @@
#!/bin/bash

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

PY_BIN_ROOT=

# PyPI dependency
${PY_BIN_ROOT}pip install sentencepiece sacremoses

# Get data
if [ ! -d "data" ]; then
  mkdir data
fi

if [ ! -f "data/fr-en.tgz" ]; then
  wget https://wit3.fbk.eu/archive/2017-01-trnted/texts/fr/en/fr-en.tgz -P data
  tar xvf data/fr-en.tgz -C data
fi
${PY_BIN_ROOT}python get_bitext.py --bpe-vocab 16384 --byte-vocab --char-vocab
for VOCAB_SIZE in 2048 4096; do
  ${PY_BIN_ROOT}python get_bitext.py --bpe-vocab ${VOCAB_SIZE} --bbpe-vocab ${VOCAB_SIZE}
done
rm -r data/fr-en data/fr-en.tgz

# Generate binary dataset
${PY_BIN_ROOT}fairseq-preprocess --source-lang fr --target-lang en --destdir data/bin_bpe16384 --joined-dictionary \
  --workers "$(nproc)" --trainpref data/train.moses.bpe16384 --validpref data/valid.moses.bpe16384 \
  --testpref data/test.moses.bpe16384

${PY_BIN_ROOT}fairseq-preprocess --source-lang fr --target-lang en --destdir data/bin_bytes --joined-dictionary \
  --workers "$(nproc)" --trainpref data/train.moses.bytes --validpref data/valid.moses.bytes \
  --testpref data/test.moses.bytes

${PY_BIN_ROOT}fairseq-preprocess --source-lang fr --target-lang en --destdir data/bin_chars --joined-dictionary \
  --workers "$(nproc)" --trainpref data/train.moses.chars --validpref data/valid.moses.chars \
  --testpref data/test.moses.chars

for VOCAB_SIZE in 2048 4096; do
  for TYPE in bbpe bpe; do
    ${PY_BIN_ROOT}fairseq-preprocess --source-lang fr --target-lang en --destdir "data/bin_${TYPE}${VOCAB_SIZE}" \
      --joined-dictionary --workers "$(nproc)" --trainpref "data/train.moses.${TYPE}${VOCAB_SIZE}" \
      --validpref "data/valid.moses.${TYPE}${VOCAB_SIZE}" --testpref "data/test.moses.${TYPE}${VOCAB_SIZE}"
  done
done
    	
        fairseq/examples/byte_level_bpe/gru_transformer.py
    ADDED
    
@@ -0,0 +1,107 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import TransformerEncoder, TransformerModel


@register_model("gru_transformer")
class GRUTransformerModel(TransformerModel):
    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return GRUTransformerEncoder(args, src_dict, embed_tokens)


class GRUTransformerEncoder(TransformerEncoder):
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)
        # bidirectional GRU that contextualizes the (byte/char-level) token
        # embeddings before they enter the Transformer layers
        self.emb_ctx = nn.GRU(
            input_size=embed_tokens.embedding_dim,
            hidden_size=embed_tokens.embedding_dim // 2,
            num_layers=1,
            bidirectional=True,
        )

    def forward_embedding(self, src_tokens):
        # embed tokens and positions
        x = embed = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)

        # contextualize embeddings
        x = x.transpose(0, 1)
        x = self.dropout_module(x)
        x, _ = self.emb_ctx.forward(x)
        x = x.transpose(0, 1)

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        return x, embed


@register_model_architecture("gru_transformer", "gru_transformer")
def gru_transformer_base_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.no_cross_attention = getattr(args, "no_cross_attention", False)
    args.cross_self_attention = getattr(args, "cross_self_attention", False)
    args.layer_wise_attention = getattr(args, "layer_wise_attention", False)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)


@register_model_architecture("gru_transformer", "gru_transformer_big")
def gru_transformer_big(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.dropout = getattr(args, "dropout", 0.3)
    gru_transformer_base_architecture(args)
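
Because `@register_model` / `@register_model_architecture` run at import time, the architecture becomes visible to `fairseq-train --arch gru_transformer` once this module is loaded (e.g. via `--user-dir examples/byte_level_bpe`). A minimal sketch of verifying the registration programmatically (assumes it is run from the fairseq repository root; the import path is an assumption):

```python
# Importing the module runs the registration decorators, after which the
# architecture names resolve in fairseq's public registry.
import importlib

from fairseq.models import ARCH_MODEL_REGISTRY

importlib.import_module("examples.byte_level_bpe.gru_transformer")  # path is an assumption

assert "gru_transformer" in ARCH_MODEL_REGISTRY
assert "gru_transformer_big" in ARCH_MODEL_REGISTRY
```
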
    	
        fairseq/examples/camembert/README.md
    ADDED
    
@@ -0,0 +1,75 @@
# CamemBERT: a Tasty French Language Model

## Introduction

[CamemBERT](https://arxiv.org/abs/1911.03894) is a pretrained language model trained on 138GB of French text based on RoBERTa.

Also available in [github.com/huggingface/transformers](https://github.com/huggingface/transformers/).

## Pre-trained models

| Model                          | #params | Download                                                                                                                 | Arch. | Training data                     |
|--------------------------------|---------|--------------------------------------------------------------------------------------------------------------------------|-------|-----------------------------------|
| `camembert` / `camembert-base` | 110M    | [camembert-base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz)                             | Base  | OSCAR (138 GB of text)            |
| `camembert-large`              | 335M    | [camembert-large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-large.tar.gz)                           | Large | CCNet (135 GB of text)            |
| `camembert-base-ccnet`         | 110M    | [camembert-base-ccnet.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet.tar.gz)                 | Base  | CCNet (135 GB of text)            |
| `camembert-base-wikipedia-4gb` | 110M    | [camembert-base-wikipedia-4gb.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-wikipedia-4gb.tar.gz) | Base  | Wikipedia (4 GB of text)          |
| `camembert-base-oscar-4gb`     | 110M    | [camembert-base-oscar-4gb.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-oscar-4gb.tar.gz)         | Base  | Subsample of OSCAR (4 GB of text) |
| `camembert-base-ccnet-4gb`     | 110M    | [camembert-base-ccnet-4gb.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet-4gb.tar.gz)         | Base  | Subsample of CCNet (4 GB of text) |

## Example usage

### fairseq
##### Load CamemBERT from torch.hub (PyTorch >= 1.1):
```python
import torch
camembert = torch.hub.load('pytorch/fairseq', 'camembert')
camembert.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Load CamemBERT (for PyTorch 1.0 or custom models):
```python
# Download camembert model
wget https://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz
tar -xzvf camembert-base.tar.gz

# Load the model in fairseq
from fairseq.models.roberta import CamembertModel
camembert = CamembertModel.from_pretrained('/path/to/camembert')
camembert.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Filling masks:
```python
masked_line = 'Le camembert est <mask> :)'
camembert.fill_mask(masked_line, topk=3)
# [('Le camembert est délicieux :)', 0.4909118115901947, ' délicieux'),
#  ('Le camembert est excellent :)', 0.10556942224502563, ' excellent'),
#  ('Le camembert est succulent :)', 0.03453322499990463, ' succulent')]
```

##### Extract features from CamemBERT:
```python
# Extract the last layer's features
line = "J'aime le camembert !"
tokens = camembert.encode(line)
last_layer_features = camembert.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 10, 768])

# Extract all layers' features (layer 0 is the embedding layer)
all_layers = camembert.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 13
assert torch.all(all_layers[-1] == last_layer_features)
```
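
For a single fixed-size sentence representation, the token features can be pooled. Mean pooling is one common choice; it is shown here only as an illustration, not something the CamemBERT API prescribes:

```python
# Mean-pool the token features into one sentence embedding
# (the pooling strategy is our choice, not part of the model API).
sentence_embedding = last_layer_features.mean(dim=1)
assert sentence_embedding.size() == torch.Size([1, 768])
```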

## Citation
If you use our work, please cite:

```bibtex
@inproceedings{martin2020camembert,
  title={CamemBERT: a Tasty French Language Model},
  author={Martin, Louis and Muller, Benjamin and Su{\'a}rez, Pedro Javier Ortiz and Dupont, Yoann and Romary, Laurent and de la Clergerie, {\'E}ric Villemonte and Seddah, Djam{\'e} and Sagot, Beno{\^\i}t},
  booktitle={Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics},
  year={2020}
}
```
    	
        fairseq/examples/constrained_decoding/README.md
    ADDED
    
@@ -0,0 +1,123 @@
# (Vectorized) Lexically constrained decoding with dynamic beam allocation

This page provides instructions for how to use lexically constrained decoding in Fairseq.
Fairseq implements the code described in the following papers:

* [Fast Lexically Constrained Decoding With Dynamic Beam Allocation](https://www.aclweb.org/anthology/N18-1119/) (Post & Vilar, 2018)
* [Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting](https://www.aclweb.org/anthology/N19-1090/) (Hu et al., 2019)

## Quick start

Constrained search is enabled by adding the command-line argument `--constraints` to `fairseq-interactive`.
Constraints are appended to each line of input, separated by tabs. Each constraint (one or more tokens)
is a separate field.

The following command, using [Fairseq's WMT19 German--English model](https://github.com/pytorch/fairseq/blob/main/examples/wmt19/README.md),
translates the sentence *Die maschinelle Übersetzung ist schwer zu kontrollieren.* with the constraints
"hard" and "to influence".

    echo -e "Die maschinelle Übersetzung ist schwer zu kontrollieren.\thard\tto influence" \
    | normalize.py | tok.py \
    | fairseq-interactive /path/to/model \
      --path /path/to/model/model1.pt \
      --bpe fastbpe \
      --bpe-codes /path/to/model/bpecodes \
      --constraints \
      -s de -t en \
      --beam 10

(tok.py and normalize.py can be found in the same directory as this README; they are just shortcuts around Fairseq's WMT19 preprocessing).
This will generate the following output:

    [snip]
    S-0     Die masch@@ in@@ elle Über@@ setzung ist schwer zu kontrollieren .
    W-0     1.844   seconds
    C-0     hard
    C-0     influence
    H-0     -1.5333266258239746     Mach@@ ine trans@@ lation is hard to influence .
    D-0     -1.5333266258239746     Machine translation is hard to influence .
    P-0     -0.5434 -0.1423 -0.1930 -0.1415 -0.2346 -1.8031 -0.1701 -11.7727 -0.1815 -0.1511

By default, constraints are generated in the order supplied, with any number (zero or more) of tokens generated
between constraints. If you wish for the decoder to order the constraints, then use `--constraints unordered`.
Note that you may want to use a larger beam.

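Because the constraints are just extra tab-separated fields, input lines are easy to build
programmatically. A minimal sketch in plain Python (no fairseq-specific API involved):

```python
# Build one input line for `fairseq-interactive --constraints`:
# the source sentence first, then one field per constraint.
sentence = "Die maschinelle Übersetzung ist schwer zu kontrollieren."
constraints = ["hard", "to influence"]
line = "\t".join([sentence] + constraints)
print(line)  # pipe this through the normalize/tokenize/decode pipeline above
```
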
## Implementation details

The heart of the implementation is in `fairseq/search.py`, which adds a `LexicallyConstrainedBeamSearch` instance.
This instance of beam search tracks the progress of each hypothesis in the beam through the set of constraints
provided for each input sentence. It does this using one of two classes, both found in `fairseq/token_generation_constraints.py`:

* OrderedConstraintState: assumes the `C` input constraints will be generated in the provided order (a toy sketch of this idea follows below)
* UnorderedConstraintState: tries to apply `C` (phrasal) constraints in all `C!` orders

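To give a feel for the bookkeeping, here is a deliberately simplified, illustrative sketch of ordered
tracking (a toy, not fairseq's actual class, which is vectorized and also unwinds partial phrase matches):

```python
# Toy illustration of ordered constraint tracking (not fairseq's implementation):
# a cursor over the flattened constraint tokens that advances only on a match.
class ToyOrderedState:
    def __init__(self, constraint_tokens, position=0):
        self.tokens = constraint_tokens  # flattened token IDs of all constraints
        self.position = position         # constraint tokens generated so far

    def advance(self, token):
        """Return the state after the hypothesis emits `token`."""
        matched = self.position < len(self.tokens) and token == self.tokens[self.position]
        return ToyOrderedState(self.tokens, self.position + (1 if matched else 0))

    @property
    def finished(self):
        return self.position == len(self.tokens)


state = ToyOrderedState([17, 52, 9])   # hypothetical token IDs
for token in [3, 17, 52, 4, 9]:        # free tokens may appear between matches
    state = state.advance(token)
assert state.finished
```
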
## Differences from Sockeye

There are a number of [differences from Sockeye's implementation](https://awslabs.github.io/sockeye/inference.html#lexical-constraints).

* Generating constraints in the order supplied (the default option here) is not available in Sockeye.
* Due to an improved beam allocation method, there is no need to prune the beam.
* Again due to better allocation, beam sizes as low as 10 or even 5 are often sufficient.
* [The vector extensions described in Hu et al.](https://github.com/edwardjhu/sockeye/tree/trie_constraints) (NAACL 2019) were never merged
  into the main Sockeye branch.

## Citation

The paper first describing lexical constraints for seq2seq decoding is:

```bibtex
@inproceedings{hokamp-liu-2017-lexically,
  title = "Lexically Constrained Decoding for Sequence Generation Using Grid Beam Search",
  author = "Hokamp, Chris  and
    Liu, Qun",
  booktitle = "Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
  month = jul,
  year = "2017",
  address = "Vancouver, Canada",
  publisher = "Association for Computational Linguistics",
  url = "https://www.aclweb.org/anthology/P17-1141",
  doi = "10.18653/v1/P17-1141",
  pages = "1535--1546",
}
```

The fairseq implementation uses the extensions described in

```bibtex
@inproceedings{post-vilar-2018-fast,
    title = "Fast Lexically Constrained Decoding with Dynamic Beam Allocation for Neural Machine Translation",
    author = "Post, Matt  and
      Vilar, David",
    booktitle = "Proceedings of the 2018 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers)",
    month = jun,
    year = "2018",
    address = "New Orleans, Louisiana",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/N18-1119",
    doi = "10.18653/v1/N18-1119",
    pages = "1314--1324",
}
```

and

```bibtex
@inproceedings{hu-etal-2019-improved,
  title = "Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting",
  author = "Hu, J. Edward  and
    Khayrallah, Huda  and
    Culkin, Ryan  and
    Xia, Patrick  and
    Chen, Tongfei  and
    Post, Matt  and
    Van Durme, Benjamin",
  booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)",
  month = jun,
  year = "2019",
  address = "Minneapolis, Minnesota",
  publisher = "Association for Computational Linguistics",
  url = "https://www.aclweb.org/anthology/N19-1090",
  doi = "10.18653/v1/N19-1090",
  pages = "839--850",
}
```
    	
        fairseq/examples/constrained_decoding/normalize.py
    ADDED
    
@@ -0,0 +1,27 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys

from sacremoses.normalize import MosesPunctNormalizer


def main(args):
    normalizer = MosesPunctNormalizer(lang=args.lang, penn=args.penn)
    for line in sys.stdin:
        print(normalizer.normalize(line.rstrip()), flush=True)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--lang", "-l", default="en")
    parser.add_argument("--penn", "-p", action="store_true")
    args = parser.parse_args()

    main(args)
    	
        fairseq/examples/constrained_decoding/tok.py
    ADDED
    
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys

import sacremoses


def main(args):
    """Tokenizes, preserving tabs"""
    mt = sacremoses.MosesTokenizer(lang=args.lang)

    def tok(s):
        return mt.tokenize(s, return_str=True)

    for line in sys.stdin:
        parts = list(map(tok, line.split("\t")))
        print(*parts, sep="\t", flush=True)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--lang", "-l", default="en")
    # note: --penn and --fields are parsed but not used by main()
    parser.add_argument("--penn", "-p", action="store_true")
    parser.add_argument("--fields", "-f", help="fields to tokenize")
    args = parser.parse_args()

    main(args)
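
Taken together, `normalize.py | tok.py` applies Moses punctuation normalization and then Moses
tokenization to every tab-separated field. A minimal single-field sketch of the same transformation,
using the sacremoses classes the two scripts wrap (language code chosen for the German example in the
constrained-decoding README above):

```python
from sacremoses import MosesPunctNormalizer, MosesTokenizer

norm = MosesPunctNormalizer(lang="de")
tok = MosesTokenizer(lang="de")

line = "Die maschinelle Übersetzung ist schwer zu kontrollieren."
print(tok.tokenize(norm.normalize(line), return_str=True))
```
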
    	
        fairseq/examples/conv_seq2seq/README.md
    ADDED
    
    | 
         @@ -0,0 +1,25 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
+# Convolutional Sequence to Sequence Learning (Gehring et al., 2017)
+
+## Pre-trained models
+
+Description | Dataset | Model | Test set(s)
+---|---|---|---
+Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2) <br> newstest2012/2013: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2)
+Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2)
+Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2)
+
+## Example usage
+
+See the [translation README](../translation/README.md) for instructions on reproducing results for WMT'14 En-De and
+WMT'14 En-Fr using the `fconv_wmt_en_de` and `fconv_wmt_en_fr` model architectures.
+
+## Citation
+
+```bibtex
+@inproceedings{gehring2017convs2s,
+  title = {Convolutional Sequence to Sequence Learning},
+  author = {Gehring, Jonas and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N},
+  booktitle = {Proc. of ICML},
+  year = 2017,
+}
+```
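For a quick smoke test, the pre-trained convolutional checkpoints can also be loaded through PyTorch Hub. A minimal sketch, assuming `conv.wmt14.en-fr` is among the identifiers fairseq publishes (run `torch.hub.list('pytorch/fairseq')` to see the names actually available):

```python
# Hedged sketch: 'conv.wmt14.en-fr' is assumed to map to the WMT14 En-Fr
# convolutional checkpoint listed in the table above.
import torch

en2fr = torch.hub.load(
    'pytorch/fairseq', 'conv.wmt14.en-fr',
    tokenizer='moses', bpe='subword_nmt',
)
en2fr.eval()
print(en2fr.translate('Hello world!'))
```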
    	
        fairseq/examples/criss/README.md
    ADDED
    
@@ -0,0 +1,61 @@
+# Cross-lingual Retrieval for Iterative Self-Supervised Training
+
+https://arxiv.org/pdf/2006.09526.pdf
+
+## Introduction
+
+CRISS is a multilingual sequence-to-sequence pretraining method in which mining and training are applied iteratively, improving cross-lingual alignment and translation ability at the same time.
+
+## Requirements
+
+* faiss: https://github.com/facebookresearch/faiss
+* mosesdecoder: https://github.com/moses-smt/mosesdecoder
+* flores: https://github.com/facebookresearch/flores
+* LASER: https://github.com/facebookresearch/LASER
+
+## Unsupervised Machine Translation
+##### 1. Download and decompress CRISS checkpoints
+```
+cd examples/criss
+wget https://dl.fbaipublicfiles.com/criss/criss_3rd_checkpoints.tar.gz
+tar -xf criss_3rd_checkpoints.tar.gz
+```
+##### 2. Download and preprocess Flores test dataset
+Make sure to run all scripts from the `examples/criss` directory.
+```
+bash download_and_preprocess_flores_test.sh
+```
+
+##### 3. Run Evaluation on Sinhala-English
+```
+bash unsupervised_mt/eval.sh
+```
+
+## Sentence Retrieval
+##### 1. Download and preprocess Tatoeba dataset
+```
+bash download_and_preprocess_tatoeba.sh
+```
+
+##### 2. Run Sentence Retrieval on Tatoeba Kazakh-English
+```
+bash sentence_retrieval/sentence_retrieval_tatoeba.sh
+```
+
+## Mining
+##### 1. Install faiss
+Follow the instructions at https://github.com/facebookresearch/faiss/blob/master/INSTALL.md
+##### 2. Mine pseudo-parallel data between Kazakh and English
+```
+bash mining/mine_example.sh
+```
+
+## Citation
+```bibtex
+@article{tran2020cross,
+  title={Cross-lingual retrieval for iterative self-supervised training},
+  author={Tran, Chau and Tang, Yuqing and Li, Xian and Gu, Jiatao},
+  journal={arXiv preprint arXiv:2006.09526},
+  year={2020}
+}
+```
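To sanity-check the output of the Mining step above, the `all.*` files written by `mining/mine.py` can be inspected directly: it emits one margin score per mined pair, sorted in descending order. A small sketch (the `kk_KZ`/`en_XX` paths follow the defaults in `mining/mine_example.sh`):

```python
# Print the five highest-scoring mined pairs; mine.py writes pairs in
# descending margin-score order, so the first lines are the most confident.
src, tgt = "kk_KZ", "en_XX"
out_dir = f"mining/{src}_{tgt}_mined"
with open(f"{out_dir}/all.scores", encoding="utf-8") as f_score, \
     open(f"{out_dir}/all.{src}", encoding="utf-8") as f_src, \
     open(f"{out_dir}/all.{tgt}", encoding="utf-8") as f_tgt:
    for i, (score, s, t) in enumerate(zip(f_score, f_src, f_tgt)):
        if i == 5:
            break
        print(score.strip(), s.strip(), t.strip(), sep="\t")
```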
    	
        fairseq/examples/criss/download_and_preprocess_flores_test.sh
    ADDED
    
@@ -0,0 +1,64 @@
+#!/bin/bash
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+SPM_ENCODE=flores/scripts/spm_encode.py
+DATA=data_tmp
+SPM_MODEL=criss_checkpoints/sentence.bpe.model
+DICT=criss_checkpoints/dict.txt
+
+download_data() {
+  CORPORA=$1
+  URL=$2
+
+  if [ -f $CORPORA ]; then
+    echo "$CORPORA already exists, skipping download"
+  else
+    echo "Downloading $URL"
+    wget $URL -O $CORPORA --no-check-certificate || rm -f $CORPORA
+    if [ -f $CORPORA ]; then
+      echo "$URL successfully downloaded."
+    else
+      echo "$URL not successfully downloaded."
+      rm -f $CORPORA
+    fi
+  fi
+}
+
+# use -d: flores is cloned as a directory
+if [[ -d flores ]]; then
+  echo "flores already cloned"
+else
+  git clone https://github.com/facebookresearch/flores
+fi
+
+mkdir -p $DATA
+download_data $DATA/wikipedia_en_ne_si_test_sets.tgz "https://github.com/facebookresearch/flores/raw/master/data/wikipedia_en_ne_si_test_sets.tgz"
+pushd $DATA
+pwd
+tar -vxf wikipedia_en_ne_si_test_sets.tgz
+popd
+
+for lang in ne_NP si_LK; do
+  datadir=$DATA/${lang}-en_XX-flores
+  rm -rf $datadir
+  mkdir -p $datadir
+  TEST_PREFIX=$DATA/wikipedia_en_ne_si_test_sets/wikipedia.test
+  # ${lang:0:2} strips the locale suffix, e.g. ne_NP -> ne
+  python $SPM_ENCODE \
+    --model ${SPM_MODEL} \
+    --output_format=piece \
+    --inputs ${TEST_PREFIX}.${lang:0:2}-en.${lang:0:2} ${TEST_PREFIX}.${lang:0:2}-en.en \
+    --outputs $datadir/test.bpe.${lang}-en_XX.${lang} $datadir/test.bpe.${lang}-en_XX.en_XX
+
+  # binarize data
+  fairseq-preprocess \
+    --source-lang ${lang} --target-lang en_XX \
+    --testpref $datadir/test.bpe.${lang}-en_XX \
+    --destdir $datadir \
+    --srcdict ${DICT} \
+    --joined-dictionary \
+    --workers 4
+done
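The `$SPM_ENCODE` step above is plain SentencePiece encoding with the CRISS BPE model. A Python equivalent of the same step (a sketch, assuming the `sentencepiece` package is installed) can help verify the model file before running the full script:

```python
import sentencepiece as spm

# Load the same BPE model the preprocessing scripts use.
sp = spm.SentencePieceProcessor(model_file="criss_checkpoints/sentence.bpe.model")
pieces = sp.encode("This is a test sentence.", out_type=str)
print(pieces)  # subword pieces, e.g. ['▁This', '▁is', ...]
```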
    	
        fairseq/examples/criss/download_and_preprocess_tatoeba.sh
    ADDED
    
@@ -0,0 +1,46 @@
+#!/bin/bash
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+SPM_ENCODE=flores/scripts/spm_encode.py
+DATA=data_tmp
+SPM_MODEL=criss_checkpoints/sentence.bpe.model
+DICT=criss_checkpoints/dict.txt
+
+# use -d: both repos are cloned as directories
+if [[ -d flores ]]; then
+  echo "flores already cloned"
+else
+  git clone https://github.com/facebookresearch/flores
+fi
+if [[ -d LASER ]]; then
+  echo "LASER already cloned"
+else
+  git clone https://github.com/facebookresearch/LASER
+fi
+mkdir -p $DATA
+# map fairseq language codes to the ISO 639-3 codes used by Tatoeba
+declare -A lang_tatoeba_map=( ["ar_AR"]="ara" ["de_DE"]="deu" ["es_XX"]="spa" ["et_EE"]="est" ["fi_FI"]="fin" ["fr_XX"]="fra" ["hi_IN"]="hin" ["it_IT"]="ita" ["ja_XX"]="jpn" ["ko_KR"]="kor" ["kk_KZ"]="kaz" ["nl_XX"]="nld" ["ru_RU"]="rus" ["tr_TR"]="tur" ["vi_VN"]="vie" ["zh_CN"]="cmn")
+for lang in ar_AR de_DE es_XX et_EE fi_FI fr_XX hi_IN it_IT ja_XX kk_KZ ko_KR nl_XX ru_RU tr_TR vi_VN zh_CN; do
+  lang_tatoeba=${lang_tatoeba_map[$lang]}
+  echo $lang_tatoeba
+  datadir=$DATA/${lang}-en_XX-tatoeba
+  rm -rf $datadir
+  mkdir -p $datadir
+  TEST_PREFIX=LASER/data/tatoeba/v1/tatoeba
+  python $SPM_ENCODE \
+    --model ${SPM_MODEL} \
+    --output_format=piece \
+    --inputs ${TEST_PREFIX}.${lang_tatoeba}-eng.${lang_tatoeba} ${TEST_PREFIX}.${lang_tatoeba}-eng.eng \
+    --outputs $datadir/test.bpe.${lang}-en_XX.${lang} $datadir/test.bpe.${lang}-en_XX.en_XX
+
+  # binarize data
+  fairseq-preprocess \
+    --source-lang ${lang} --target-lang en_XX \
+    --testpref $datadir/test.bpe.${lang}-en_XX \
+    --destdir $datadir \
+    --srcdict ${DICT} \
+    --joined-dictionary \
+    --workers 4
+done
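Once the Tatoeba pairs are encoded and binarized, the retrieval evaluation reduces to nearest-neighbor search over L2-normalized sentence embeddings. The following is an illustrative sketch of that metric, not the repo's `sentence_retrieval` script:

```python
import numpy as np

def retrieval_accuracy(src_emb: np.ndarray, tgt_emb: np.ndarray) -> float:
    """Rows are aligned: src_emb[i]'s gold translation is tgt_emb[i].
    With L2-normalized rows, the inner product is cosine similarity."""
    sims = src_emb @ tgt_emb.T
    predictions = sims.argmax(axis=1)
    return float((predictions == np.arange(len(src_emb))).mean())
```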
    	
        fairseq/examples/criss/mining/mine.py
    ADDED
    
@@ -0,0 +1,240 @@
+#!/usr/bin/env python3 -u
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import argparse
+import glob
+from subprocess import check_call
+
+try:
+    import faiss
+
+    has_faiss = True
+except ImportError:
+    has_faiss = False
+import numpy as np
+
+
+GB = 1024 * 1024 * 1024
+
+
+def call(cmd):
+    print(cmd)
+    check_call(cmd, shell=True)
+
+
+def get_batches(directory, lang, prefix="all_avg_pool"):
+    print(f"Finding in {directory}/{prefix}.{lang}*")
+    files = glob.glob(f"{directory}/{prefix}.{lang}*")
+    emb_files = []
+    txt_files = []
+    for emb_fi in files:
+        emb_files.append(emb_fi)
+        txt_fi = emb_fi.replace(prefix, "sentences")
+        txt_files.append(txt_fi)
+    return emb_files, txt_files
+
+
+def load_batch(emb_file, dim):
+    # embeddings are stored as a raw float32 buffer; reshape to (n, dim)
+    embeddings = np.fromfile(emb_file, dtype=np.float32)
+    num_rows = int(embeddings.shape[0] / dim)
+    embeddings = embeddings.reshape((num_rows, dim))
+    faiss.normalize_L2(embeddings)
+    return embeddings
+
+
+def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"):
+    """Sharded k-NN search: for every x shard, search all y shards on GPU
+    and keep the top-k inner-product neighbors overall."""
+    if not has_faiss:
+        raise ImportError("Please install Faiss")
+    sims = []
+    inds = []
+    xfrom = 0
+    xto = 0
+    for x_batch_f in x_batches_f:
+        yfrom = 0
+        yto = 0
+        x_batch = load_batch(x_batch_f, dim)
+        xto = xfrom + x_batch.shape[0]
+        bsims, binds = [], []
+        for y_batch_f in y_batches_f:
+            y_batch = load_batch(y_batch_f, dim)
+            neighbor_size = min(k, y_batch.shape[0])
+            yto = yfrom + y_batch.shape[0]
+            print("{}-{}  ->  {}-{}".format(xfrom, xto, yfrom, yto))
+            idx = faiss.IndexFlatIP(dim)
+            idx = faiss.index_cpu_to_all_gpus(idx)
+            idx.add(y_batch)
+            bsim, bind = idx.search(x_batch, neighbor_size)
+
+            bsims.append(bsim)
+            binds.append(bind + yfrom)  # offset indices into the global y order
+            yfrom += y_batch.shape[0]
+            del idx
+            del y_batch
+        bsims = np.concatenate(bsims, axis=1)
+        binds = np.concatenate(binds, axis=1)
+        aux = np.argsort(-bsims, axis=1)
+        sim_batch = np.zeros((x_batch.shape[0], k), dtype=np.float32)
+        ind_batch = np.zeros((x_batch.shape[0], k), dtype=np.int64)
+        for i in range(x_batch.shape[0]):
+            for j in range(k):
+                sim_batch[i, j] = bsims[i, aux[i, j]]
+                ind_batch[i, j] = binds[i, aux[i, j]]
+        sims.append(sim_batch)
+        inds.append(ind_batch)
+        xfrom += x_batch.shape[0]
+        del x_batch
+    sim = np.concatenate(sims, axis=0)
+    ind = np.concatenate(inds, axis=0)
+    return sim, ind
+
+
+def score(sim, fwd_mean, bwd_mean, margin):
+    # ratio margin: candidate similarity relative to the average similarity
+    # of the forward and backward neighborhoods
+    return margin(sim, (fwd_mean + bwd_mean) / 2)
+
+
+def score_candidates(
+    sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False
+):
+    print(" - scoring {:d} candidates".format(sim_mat.shape[0]))
+    scores = np.zeros(candidate_inds.shape)
+    for i in range(scores.shape[0]):
+        for j in range(scores.shape[1]):
+            k = int(candidate_inds[i, j])
+            scores[i, j] = score(sim_mat[i, j], fwd_mean[i], bwd_mean[k], margin)
+    return scores
+
+
+def load_text(files):
+    all_sentences = []
+    for fi in files:
+        with open(fi) as sentence_fi:
+            for line in sentence_fi:
+                all_sentences.append(line.strip())
+    print(f"Read {len(all_sentences)} sentences")
+    return all_sentences
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Mine bitext")
+    parser.add_argument("--src-lang", help="Source language")
+    parser.add_argument("--tgt-lang", help="Target language")
+    parser.add_argument(
+        "--dict-path", help="Path to dictionary file", default="dict.txt"
+    )
+    parser.add_argument(
+        "--spm-path", help="Path to SPM model file", default="sentence.bpe.model"
+    )
+    parser.add_argument("--dim", type=int, default=1024, help="Embedding dimension")
+    parser.add_argument("--mem", type=int, default=5, help="Memory in GB")
+    parser.add_argument("--src-dir", help="Source directory")
+    parser.add_argument("--tgt-dir", help="Target directory")
+    parser.add_argument("--output", help="Output path")
+    parser.add_argument(
+        "--neighborhood", type=int, default=4, help="Number of nearest neighbors per direction"
+    )
+    parser.add_argument(
+        "--threshold", type=float, default=1.06, help="Threshold on mined bitext"
+    )
+    parser.add_argument(
+        "--valid-size",
+        type=int,
+        default=2000,
+        help="Number of sentences used for validation set",
+    )
+    parser.add_argument(
+        "--min-count",
+        type=int,
+        default=50000,
+        help="Min num sentences used for each language",
+    )
+    args = parser.parse_args()
+
+    x_batches_f, x_sents_f = get_batches(args.src_dir, args.src_lang)
+    y_batches_f, y_sents_f = get_batches(args.tgt_dir, args.tgt_lang)
+    margin = lambda a, b: a / b  # ratio margin
+    y2x_sim, y2x_ind = knnGPU_sharded(
+        y_batches_f, x_batches_f, args.dim, args.neighborhood, direction="y2x"
+    )
+    x2y_sim, x2y_ind = knnGPU_sharded(
+        x_batches_f, y_batches_f, args.dim, args.neighborhood, direction="x2y"
+    )
+
+    x2y_mean = x2y_sim.mean(axis=1)
+    y2x_mean = y2x_sim.mean(axis=1)
+    fwd_scores = score_candidates(x2y_sim, x2y_ind, x2y_mean, y2x_mean, margin)
+    bwd_scores = score_candidates(y2x_sim, y2x_ind, y2x_mean, x2y_mean, margin)
+    fwd_best = x2y_ind[np.arange(x2y_sim.shape[0]), fwd_scores.argmax(axis=1)]
+    bwd_best = y2x_ind[np.arange(y2x_sim.shape[0]), bwd_scores.argmax(axis=1)]
+    indices = np.stack(
+        (
+            np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
+            np.concatenate((fwd_best, np.arange(y2x_ind.shape[0]))),
+        ),
+        axis=1,
+    )
+    scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1)))
+
+    x_sentences = load_text(x_sents_f)
+    y_sentences = load_text(y_sents_f)
+
+    threshold = args.threshold
+    min_count = args.min_count
+    seen_src, seen_trg = set(), set()
+    directory = args.output
+    call(f"mkdir -p {directory}")
+    src_out = open(
+        f"{directory}/all.{args.src_lang}",
+        mode="w",
+        encoding="utf-8",
+        errors="surrogateescape",
+    )
+    tgt_out = open(
+        f"{directory}/all.{args.tgt_lang}",
+        mode="w",
+        encoding="utf-8",
+        errors="surrogateescape",
+    )
+    scores_out = open(
+        f"{directory}/all.scores", mode="w", encoding="utf-8", errors="surrogateescape"
+    )
+    count = 0
+    # greedily accept pairs in descending score order, keeping each sentence
+    # at most once, until the threshold (or the minimum count) is exhausted
+    for i in np.argsort(-scores):
+        src_ind, trg_ind = indices[i]
+        if src_ind not in seen_src and trg_ind not in seen_trg:
+            seen_src.add(src_ind)
+            seen_trg.add(trg_ind)
+            if scores[i] > threshold or count < min_count:
+                if x_sentences[src_ind]:
+                    print(scores[i], file=scores_out)
+                    print(x_sentences[src_ind], file=src_out)
+                    print(y_sentences[trg_ind], file=tgt_out)
+                    count += 1
+                else:
+                    print(f"Ignoring sentence: {x_sentences[src_ind]}")
+    src_out.close()
+    tgt_out.close()
+    scores_out.close()
+
+    print(f"Found {count} pairs for threshold={threshold}")
+    with open(f"{directory}/all.{args.src_lang}") as all_s, open(
+        f"{directory}/all.{args.tgt_lang}"
+    ) as all_t, open(f"{directory}/valid.{args.src_lang}", "w") as valid_s, open(
+        f"{directory}/valid.{args.tgt_lang}", "w"
+    ) as valid_t, open(
+        f"{directory}/train.{args.src_lang}", "w"
+    ) as train_s, open(
+        f"{directory}/train.{args.tgt_lang}", "w"
+    ) as train_t:
+        count = 0
+        for s_line, t_line in zip(all_s, all_t):
+            s_line = s_line.split("\t")[1]
+            t_line = t_line.split("\t")[1]
+            if count >= args.valid_size:
+                train_s.write(s_line)
+                train_t.write(t_line)
+            else:
+                valid_s.write(s_line)
+                valid_t.write(t_line)
+                count += 1
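A self-contained toy run of the ratio-margin criterion above (illustrative only and CPU-only; the real script shards the search and moves the index to GPU with `faiss.index_cpu_to_all_gpus`): each candidate's cosine similarity is divided by the average similarity of the forward and backward neighborhoods, and a pair is kept when the score exceeds the `--threshold` default of 1.06.

```python
import faiss
import numpy as np

rng = np.random.RandomState(0)
dim, k = 8, 4
x = rng.rand(32, dim).astype(np.float32)  # stand-in source embeddings
y = rng.rand(32, dim).astype(np.float32)  # stand-in target embeddings
faiss.normalize_L2(x)
faiss.normalize_L2(y)

def knn(queries, keys, k):
    index = faiss.IndexFlatIP(keys.shape[1])  # inner product == cosine here
    index.add(keys)
    return index.search(queries, k)           # (similarities, indices)

x2y_sim, x2y_ind = knn(x, y, k)
y2x_sim, _ = knn(y, x, k)
fwd_mean = x2y_sim.mean(axis=1)  # mean similarity of each x to its k neighbors
bwd_mean = y2x_sim.mean(axis=1)  # same for each y

best = x2y_ind[:, 0]             # each x's single best candidate
margin_scores = x2y_sim[:, 0] / ((fwd_mean + bwd_mean[best]) / 2)
print(f"{(margin_scores > 1.06).sum()} of {len(x)} pairs pass the threshold")
```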
    	
        fairseq/examples/criss/mining/mine_example.sh
    ADDED
    
@@ -0,0 +1,103 @@
| 1 | 
         
            +
            #!/bin/bash
         
     | 
| 2 | 
         
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         
     | 
| 3 | 
         
            +
            # All rights reserved.
         
     | 
| 4 | 
         
            +
            #
         
     | 
| 5 | 
         
            +
            # This source code is licensed under the license found in the
         
     | 
| 6 | 
         
            +
            # LICENSE file in the root directory of this source tree.
         
     | 
| 7 | 
         
            +
            #
         
     | 
| 8 | 
         
            +
            source_lang=kk_KZ
         
     | 
| 9 | 
         
            +
            target_lang=en_XX
         
     | 
| 10 | 
         
            +
            MODEL=criss_checkpoints/criss.3rd.pt
         
     | 
| 11 | 
         
            +
            SPM=criss_checkpoints/sentence.bpe.model
         
     | 
| 12 | 
         
            +
            SPLIT=test
         
     | 
| 13 | 
         
            +
            LANG_DICT=criss_checkpoints/lang_dict.txt
         
     | 
| 14 | 
         
            +
            SPM_ENCODE=flores/scripts/spm_encode.py
         
     | 
| 15 | 
         
            +
            SAVE_ENCODER=save_encoder.py
         
     | 
| 16 | 
         
            +
            ENCODER_SAVE_ROOT=sentence_embeddings/$MODEL
         
     | 
| 17 | 
         
            +
            DICT=criss_checkpoints/dict.txt
         
     | 
| 18 | 
         
            +
            THRESHOLD=1.02
         
     | 
| 19 | 
         
            +
            MIN_COUNT=500
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            DATA_DIR=data_tmp
         
     | 
| 22 | 
         
            +
            SAVE_DIR=mining/${source_lang}_${target_lang}_mined
         
     | 
| 23 | 
         
            +
            ENCODER_SAVE_DIR=${ENCODER_SAVE_ROOT}/${source_lang}-${target_lang}
         
     | 
| 24 | 
         
            +
            INPUT_DIR=$DATA_DIR/${source_lang}-${target_lang}-tatoeba
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
            mkdir -p $ENCODER_SAVE_DIR/${target_lang}
         
     | 
| 27 | 
         
            +
            mkdir -p $ENCODER_SAVE_DIR/${source_lang}
         
     | 
| 28 | 
         
            +
            mkdir -p $SAVE_DIR
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            ## Save encoder outputs
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
            # Save encoder outputs for source sentences
         
     | 
| 33 | 
         
            +
            python $SAVE_ENCODER \
         
     | 
| 34 | 
         
            +
              ${INPUT_DIR} \
         
     | 
| 35 | 
         
            +
              --path ${MODEL} \
         
     | 
| 36 | 
         
            +
              --task translation_multi_simple_epoch \
         
     | 
| 37 | 
         
            +
              --lang-pairs ${source_lang}-${target_lang} \
         
     | 
| 38 | 
         
            +
              --lang-dict ${LANG_DICT} \
         
     | 
| 39 | 
         
            +
              --gen-subset ${SPLIT} \
         
     | 
| 40 | 
         
            +
              --bpe 'sentencepiece' \
         
     | 
| 41 | 
         
            +
              -s ${source_lang} -t ${target_lang} \
         
     | 
| 42 | 
         
            +
              --sentencepiece-model ${SPM} \
         
     | 
| 43 | 
         
            +
              --remove-bpe 'sentencepiece' \
         
     | 
| 44 | 
         
            +
              --beam 1 \
         
     | 
| 45 | 
         
            +
              --lang-tok-style mbart \
         
     | 
| 46 | 
         
            +
              --encoder-save-dir ${ENCODER_SAVE_DIR}/${source_lang}
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
            ## Save encoder outputs for target sentences
         
     | 
| 49 | 
         
            +
            python $SAVE_ENCODER \
         
     | 
| 50 | 
         
            +
              ${INPUT_DIR} \
         
     | 
| 51 | 
         
            +
              --path ${MODEL} \
         
     | 
| 52 | 
         
            +
              --lang-pairs ${source_lang}-${target_lang} \
         
     | 
| 53 | 
         
            +
              --lang-dict ${LANG_DICT} \
         
     | 
| 54 | 
         
            +
              --task translation_multi_simple_epoch \
         
     | 
| 55 | 
         
            +
              --gen-subset ${SPLIT} \
         
     | 
| 56 | 
         
            +
              --bpe 'sentencepiece' \
         
     | 
| 57 | 
         
            +
              -t ${source_lang} -s ${target_lang} \
         
+  --sentencepiece-model ${SPM} \
+  --remove-bpe 'sentencepiece' \
+  --beam 1 \
+  --lang-tok-style mbart \
+  --encoder-save-dir ${ENCODER_SAVE_DIR}/${target_lang}
+
+## Mining
+python mining/mine.py \
+  --src-lang ${source_lang} \
+  --tgt-lang ${target_lang} \
+  --dim 1024 \
+  --mem 10 \
+  --neighborhood 4 \
+  --src-dir ${ENCODER_SAVE_DIR}/${source_lang} \
+  --tgt-dir ${ENCODER_SAVE_DIR}/${target_lang} \
+  --output $SAVE_DIR \
+  --threshold ${THRESHOLD} \
+  --min-count ${MIN_COUNT} \
+  --valid-size 100 \
+  --dict-path ${DICT} \
+  --spm-path ${SPM}
+
+
+## Process and binarize mined data
+python $SPM_ENCODE \
+  --model ${SPM} \
+  --output_format=piece \
+  --inputs mining/${source_lang}_${target_lang}_mined/train.${source_lang} mining/${source_lang}_${target_lang}_mined/train.${target_lang} \
+  --outputs mining/${source_lang}_${target_lang}_mined/train.bpe.${source_lang} mining/${source_lang}_${target_lang}_mined/train.bpe.${target_lang}
+
+python $SPM_ENCODE \
+  --model ${SPM} \
+  --output_format=piece \
+  --inputs mining/${source_lang}_${target_lang}_mined/valid.${source_lang} mining/${source_lang}_${target_lang}_mined/valid.${target_lang} \
+  --outputs mining/${source_lang}_${target_lang}_mined/valid.bpe.${source_lang} mining/${source_lang}_${target_lang}_mined/valid.bpe.${target_lang}
+
+
+fairseq-preprocess \
+  --source-lang ${source_lang} \
+  --target-lang ${target_lang} \
+  --trainpref mining/${source_lang}_${target_lang}_mined/train.bpe \
+  --validpref mining/${source_lang}_${target_lang}_mined/valid.bpe \
+  --destdir mining/${source_lang}_${target_lang}_mined \
+  --srcdict ${DICT} \
+  --joined-dictionary \
+  --workers 8
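Between the encoding and mining stages above, save_encoder.py and mine.py exchange flat files: each language directory holds raw float32 shards named all_avg_pool.<lang>.<shard> plus matching tab-separated sentences.<lang>.<shard> files. A minimal sketch of reading one shard back (the encoder_out path and the ar_AR code are illustrative; the width of 1024 matches --dim above):

import numpy as np

DIM = 1024  # must match the --dim passed to mining/mine.py

# one row per sentence, stored as raw float32 by save_encoder.py
embeddings = np.fromfile("encoder_out/ar_AR/all_avg_pool.ar_AR.0", dtype=np.float32)
embeddings = embeddings.reshape((-1, DIM))

# matching "id<TAB>text" lines written alongside the embeddings
with open("encoder_out/ar_AR/sentences.ar_AR.0") as f:
    sentences = [line.rstrip("\n").split("\t", 1) for line in f]

assert len(sentences) == embeddings.shape[0]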
    	
fairseq/examples/criss/save_encoder.py
ADDED
@@ -0,0 +1,214 @@
+#!/usr/bin/env python3 -u
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Save average-pooled encoder outputs for pre-processed data with a trained model.
+"""
+
+import numpy as np
+import torch
+from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
+from fairseq.sequence_generator import EnsembleModel
+from fairseq.utils import safe_hasattr
+
+
+def get_avg_pool(
+    models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False
+):
+    model = EnsembleModel(models)
+
+    # model.forward normally channels prev_output_tokens into the decoder
+    # separately, but SequenceGenerator directly calls model.encoder
+    encoder_input = {
+        k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
+    }
+
+    # compute the encoder output for each beam
+    encoder_outs = model.forward_encoder(encoder_input)
+    np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
+    encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(
+        np.float32
+    )
+    encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
+    if has_langtok:
+        encoder_mask = encoder_mask[1:, :, :]
+        np_encoder_outs = np_encoder_outs[1:, :, :]
+    masked_encoder_outs = encoder_mask * np_encoder_outs
+    avg_pool = (masked_encoder_outs / encoder_mask.sum(axis=0)).sum(axis=0)
+    return avg_pool
+
+
+def main(args):
+    assert args.path is not None, "--path required for generation!"
+    assert (
+        not args.sampling or args.nbest == args.beam
+    ), "--sampling requires --nbest to be equal to --beam"
+    assert (
+        args.replace_unk is None or args.raw_text
+    ), "--replace-unk requires a raw text dataset (--raw-text)"
+
+    args.beam = 1
+    utils.import_user_module(args)
+
+    if args.max_tokens is None:
+        args.max_tokens = 12000
+    print(args)
+    use_cuda = torch.cuda.is_available() and not args.cpu
+
+    # Load dataset splits
+    task = tasks.setup_task(args)
+    task.load_dataset(args.gen_subset)
+
+    # Set dictionaries
+    try:
+        src_dict = getattr(task, "source_dictionary", None)
+    except NotImplementedError:
+        src_dict = None
+    tgt_dict = task.target_dictionary
+
+    # Load ensemble
+    print("| loading model(s) from {}".format(args.path))
+    models, _model_args = checkpoint_utils.load_model_ensemble(
+        args.path.split(":"),
+        arg_overrides=eval(args.model_overrides),
+        task=task,
+    )
+
+    # Optimize ensemble for generation
+    for model in models:
+        model.make_generation_fast_(
+            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
+            need_attn=args.print_alignment,
+        )
+        if args.fp16:
+            model.half()
+        if use_cuda:
+            model.cuda()
+
+    # Load alignment dictionary for unknown word replacement
+    # (None if no unknown word replacement, empty if no path to align dictionary)
+    align_dict = utils.load_align_dict(args.replace_unk)
+
+    # Load dataset (possibly sharded)
+    itr = task.get_batch_iterator(
+        dataset=task.dataset(args.gen_subset),
+        max_tokens=args.max_tokens,
+        max_positions=utils.resolve_max_positions(
+            task.max_positions(),
+        ),
+        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
+        required_batch_size_multiple=args.required_batch_size_multiple,
+        num_shards=args.num_shards,
+        shard_id=args.shard_id,
+        num_workers=args.num_workers,
+    ).next_epoch_itr(shuffle=False)
+
+    num_sentences = 0
+    source_sentences = []
+    shard_id = 0
+    all_avg_pool = None
+    encoder_has_langtok = (
+        safe_hasattr(task.args, "encoder_langtok")
+        and task.args.encoder_langtok is not None
+        and safe_hasattr(task.args, "lang_tok_replacing_bos_eos")
+        and not task.args.lang_tok_replacing_bos_eos
+    )
+    with progress_bar.build_progress_bar(args, itr) as t:
+        for sample in t:
+            if sample is None:
+                print("Skipping None")
+                continue
+            sample = utils.move_to_cuda(sample) if use_cuda else sample
+            if "net_input" not in sample:
+                continue
+
+            prefix_tokens = None
+            if args.prefix_size > 0:
+                prefix_tokens = sample["target"][:, : args.prefix_size]
+
+            with torch.no_grad():
+                avg_pool = get_avg_pool(
+                    models,
+                    sample,
+                    prefix_tokens,
+                    src_dict,
+                    args.post_process,
+                    has_langtok=encoder_has_langtok,
+                )
+                if all_avg_pool is not None:
+                    all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
+                else:
+                    all_avg_pool = avg_pool
+
+            if not isinstance(sample["id"], list):
+                sample_ids = sample["id"].tolist()
+            else:
+                sample_ids = sample["id"]
+            for i, sample_id in enumerate(sample_ids):
+                # Remove padding
+                src_tokens = utils.strip_pad(
+                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
+                )
+
+                # Either retrieve the original sentences or regenerate them from tokens.
+                if align_dict is not None:
+                    src_str = task.dataset(args.gen_subset).src.get_original_text(
+                        sample_id
+                    )
+                else:
+                    if src_dict is not None:
+                        src_str = src_dict.string(src_tokens, args.post_process)
+                    else:
+                        src_str = ""
+
+                if not args.quiet:
+                    if src_dict is not None:
+                        print("S-{}\t{}".format(sample_id, src_str))
+
+                source_sentences.append(f"{sample_id}\t{src_str}")
+
+            num_sentences += sample["nsentences"]
+            if all_avg_pool.shape[0] >= 1000000:
+                with open(
+                    f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}",
+                    "w",
+                ) as avg_pool_file:
+                    all_avg_pool.tofile(avg_pool_file)
+                with open(
+                    f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}",
+                    "w",
+                ) as sentence_file:
+                    sentence_file.writelines(f"{line}\n" for line in source_sentences)
+                all_avg_pool = None
+                source_sentences = []
+                shard_id += 1
+
+    if all_avg_pool is not None:
+        with open(
+            f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w"
+        ) as avg_pool_file:
+            all_avg_pool.tofile(avg_pool_file)
+        with open(
+            f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w"
+        ) as sentence_file:
+            sentence_file.writelines(f"{line}\n" for line in source_sentences)
+    return None
+
+
+def cli_main():
+    parser = options.get_generation_parser()
+    parser.add_argument(
+        "--encoder-save-dir",
+        default="",
+        type=str,
+        metavar="N",
+        help="directory to save encoder outputs",
+    )
+    args = options.parse_args_and_arch(parser)
+    main(args)
+
+
+if __name__ == "__main__":
+    cli_main()
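The pooling in get_avg_pool above turns the (time, batch, channels) encoder states into one vector per sentence by averaging over non-padding positions only. A self-contained numpy sketch of the same arithmetic, on dummy shapes:

import numpy as np

T, B, C = 5, 2, 4                                   # time steps, batch, channels
outs = np.random.rand(T, B, C).astype(np.float32)   # stand-in encoder states
pad = np.zeros((B, T), dtype=np.float32)            # 1.0 marks a padded position
pad[1, 3:] = 1.0                                    # second sentence is shorter

mask = np.expand_dims((1 - pad).T, axis=2)          # (T, B, 1), 1.0 on real tokens
masked = mask * outs
avg_pool = (masked / mask.sum(axis=0)).sum(axis=0)  # (B, C), as in get_avg_pool

# same result written the direct way: sum of real tokens / number of real tokens
assert np.allclose(avg_pool, masked.sum(axis=0) / mask.sum(axis=0))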
    	
fairseq/examples/criss/sentence_retrieval/encoder_analysis.py
ADDED
@@ -0,0 +1,92 @@
+#!/usr/bin/env python3 -u
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import argparse
+import glob
+
+import numpy as np
+
+
+DIM = 1024
+
+
+def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False):
+    target_ids = [tid for tid in target_embs]
+    source_mat = np.stack(source_embs.values(), axis=0)
+    normalized_source_mat = source_mat / np.linalg.norm(
+        source_mat, axis=1, keepdims=True
+    )
+    target_mat = np.stack(target_embs.values(), axis=0)
+    normalized_target_mat = target_mat / np.linalg.norm(
+        target_mat, axis=1, keepdims=True
+    )
+    sim_mat = normalized_source_mat.dot(normalized_target_mat.T)
+    if return_sim_mat:
+        return sim_mat
+    neighbors_map = {}
+    for i, sentence_id in enumerate(source_embs):
+        idx = np.argsort(sim_mat[i, :])[::-1][:k]
+        neighbors_map[sentence_id] = [target_ids[tid] for tid in idx]
+    return neighbors_map
+
+
+def load_embeddings(directory, LANGS):
+    sentence_embeddings = {}
+    sentence_texts = {}
+    for lang in LANGS:
+        sentence_embeddings[lang] = {}
+        sentence_texts[lang] = {}
+        lang_dir = f"{directory}/{lang}"
+        embedding_files = glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*")
+        for embed_file in embedding_files:
+            shard_id = embed_file.split(".")[-1]
+            embeddings = np.fromfile(embed_file, dtype=np.float32)
+            num_rows = embeddings.shape[0] // DIM
+            embeddings = embeddings.reshape((num_rows, DIM))
+
+            with open(f"{lang_dir}/sentences.{lang}.{shard_id}") as sentence_file:
+                for idx, line in enumerate(sentence_file):
+                    sentence_id, sentence = line.strip().split("\t")
+                    sentence_texts[lang][sentence_id] = sentence
+                    sentence_embeddings[lang][sentence_id] = embeddings[idx, :]
+
+    return sentence_embeddings, sentence_texts
+
+
+def compute_accuracy(directory, LANGS):
+    sentence_embeddings, sentence_texts = load_embeddings(directory, LANGS)
+
+    top_1_accuracy = {}
+
+    top1_str = " ".join(LANGS) + "\n"
+    for source_lang in LANGS:
+        top_1_accuracy[source_lang] = {}
+        top1_str += f"{source_lang} "
+        for target_lang in LANGS:
+            top1 = 0
+            top5 = 0
+            neighbors_map = compute_dist(
+                sentence_embeddings[source_lang], sentence_embeddings[target_lang]
+            )
+            for sentence_id, neighbors in neighbors_map.items():
+                if sentence_id == neighbors[0]:
+                    top1 += 1
+                if sentence_id in neighbors[:5]:
+                    top5 += 1
+            n = len(sentence_embeddings[target_lang])
+            top1_str += f"{top1/n} "
+        top1_str += "\n"
+
+    print(top1_str)
+    print(top1_str, file=open(f"{directory}/accuracy", "w"))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Analyze encoder outputs")
+    parser.add_argument("directory", help="Source language corpus")
+    parser.add_argument("--langs", help="List of langs")
+    args = parser.parse_args()
+    langs = args.langs.split(",")
+    compute_accuracy(args.directory, langs)
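encoder_analysis.py runs over the directory tree written by save_encoder.py, e.g. python encoder_analysis.py ${ENCODER_SAVE_DIR} --langs ar_AR,de_DE (the language codes are illustrative). Since compute_dist is plain cosine retrieval over L2-normalized rows, a quick sanity check is to query a set of embeddings against itself; a sketch, assuming it is run next to encoder_analysis.py:

import numpy as np
from encoder_analysis import compute_dist

# when source == target, every sentence must retrieve itself first
embs = {f"s{i}": np.random.randn(8).astype(np.float32) for i in range(10)}
neighbors = compute_dist(embs, embs, k=5)
assert all(sid == hits[0] for sid, hits in neighbors.items())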