Antonio Cheong committed on
Commit 4f15858 · 1 Parent(s): 4f7a928
mm-cot/CODE_OF_CONDUCT.md DELETED
@@ -1,4 +0,0 @@
1
- ## Code of Conduct
2
- This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3
- For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4
- [email protected] with any additional questions or comments.
 
mm-cot/CONTRIBUTING.md DELETED
@@ -1,59 +0,0 @@
1
- # Contributing Guidelines
2
-
3
- Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4
- documentation, we greatly value feedback and contributions from our community.
5
-
6
- Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7
- information to effectively respond to your bug report or contribution.
8
-
9
-
10
- ## Reporting Bugs/Feature Requests
11
-
12
- We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13
-
14
- When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15
- reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16
-
17
- * A reproducible test case or series of steps
18
- * The version of our code being used
19
- * Any modifications you've made relevant to the bug
20
- * Anything unusual about your environment or deployment
21
-
22
-
23
- ## Contributing via Pull Requests
24
- Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25
-
26
- 1. You are working against the latest source on the *main* branch.
27
- 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28
- 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29
-
30
- To send us a pull request, please:
31
-
32
- 1. Fork the repository.
33
- 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34
- 3. Ensure local tests pass.
35
- 4. Commit to your fork using clear commit messages.
36
- 5. Send us a pull request, answering any default questions in the pull request interface.
37
- 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38
-
39
- GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40
- [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41
-
42
-
43
- ## Finding contributions to work on
44
- Looking at the existing issues is a great way to find something to contribute to. As our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45
-
46
-
47
- ## Code of Conduct
48
- This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49
- For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50
- [email protected] with any additional questions or comments.
51
-
52
-
53
- ## Security issue notifications
54
- If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55
-
56
-
57
- ## Licensing
58
-
59
- See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
 
mm-cot/LICENSE DELETED
@@ -1,175 +0,0 @@
1
-
2
- Apache License
3
- Version 2.0, January 2004
4
- http://www.apache.org/licenses/
5
-
6
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
-
8
- 1. Definitions.
9
-
10
- "License" shall mean the terms and conditions for use, reproduction,
11
- and distribution as defined by Sections 1 through 9 of this document.
12
-
13
- "Licensor" shall mean the copyright owner or entity authorized by
14
- the copyright owner that is granting the License.
15
-
16
- "Legal Entity" shall mean the union of the acting entity and all
17
- other entities that control, are controlled by, or are under common
18
- control with that entity. For the purposes of this definition,
19
- "control" means (i) the power, direct or indirect, to cause the
20
- direction or management of such entity, whether by contract or
21
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
- outstanding shares, or (iii) beneficial ownership of such entity.
23
-
24
- "You" (or "Your") shall mean an individual or Legal Entity
25
- exercising permissions granted by this License.
26
-
27
- "Source" form shall mean the preferred form for making modifications,
28
- including but not limited to software source code, documentation
29
- source, and configuration files.
30
-
31
- "Object" form shall mean any form resulting from mechanical
32
- transformation or translation of a Source form, including but
33
- not limited to compiled object code, generated documentation,
34
- and conversions to other media types.
35
-
36
- "Work" shall mean the work of authorship, whether in Source or
37
- Object form, made available under the License, as indicated by a
38
- copyright notice that is included in or attached to the work
39
- (an example is provided in the Appendix below).
40
-
41
- "Derivative Works" shall mean any work, whether in Source or Object
42
- form, that is based on (or derived from) the Work and for which the
43
- editorial revisions, annotations, elaborations, or other modifications
44
- represent, as a whole, an original work of authorship. For the purposes
45
- of this License, Derivative Works shall not include works that remain
46
- separable from, or merely link (or bind by name) to the interfaces of,
47
- the Work and Derivative Works thereof.
48
-
49
- "Contribution" shall mean any work of authorship, including
50
- the original version of the Work and any modifications or additions
51
- to that Work or Derivative Works thereof, that is intentionally
52
- submitted to Licensor for inclusion in the Work by the copyright owner
53
- or by an individual or Legal Entity authorized to submit on behalf of
54
- the copyright owner. For the purposes of this definition, "submitted"
55
- means any form of electronic, verbal, or written communication sent
56
- to the Licensor or its representatives, including but not limited to
57
- communication on electronic mailing lists, source code control systems,
58
- and issue tracking systems that are managed by, or on behalf of, the
59
- Licensor for the purpose of discussing and improving the Work, but
60
- excluding communication that is conspicuously marked or otherwise
61
- designated in writing by the copyright owner as "Not a Contribution."
62
-
63
- "Contributor" shall mean Licensor and any individual or Legal Entity
64
- on behalf of whom a Contribution has been received by Licensor and
65
- subsequently incorporated within the Work.
66
-
67
- 2. Grant of Copyright License. Subject to the terms and conditions of
68
- this License, each Contributor hereby grants to You a perpetual,
69
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
- copyright license to reproduce, prepare Derivative Works of,
71
- publicly display, publicly perform, sublicense, and distribute the
72
- Work and such Derivative Works in Source or Object form.
73
-
74
- 3. Grant of Patent License. Subject to the terms and conditions of
75
- this License, each Contributor hereby grants to You a perpetual,
76
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
- (except as stated in this section) patent license to make, have made,
78
- use, offer to sell, sell, import, and otherwise transfer the Work,
79
- where such license applies only to those patent claims licensable
80
- by such Contributor that are necessarily infringed by their
81
- Contribution(s) alone or by combination of their Contribution(s)
82
- with the Work to which such Contribution(s) was submitted. If You
83
- institute patent litigation against any entity (including a
84
- cross-claim or counterclaim in a lawsuit) alleging that the Work
85
- or a Contribution incorporated within the Work constitutes direct
86
- or contributory patent infringement, then any patent licenses
87
- granted to You under this License for that Work shall terminate
88
- as of the date such litigation is filed.
89
-
90
- 4. Redistribution. You may reproduce and distribute copies of the
91
- Work or Derivative Works thereof in any medium, with or without
92
- modifications, and in Source or Object form, provided that You
93
- meet the following conditions:
94
-
95
- (a) You must give any other recipients of the Work or
96
- Derivative Works a copy of this License; and
97
-
98
- (b) You must cause any modified files to carry prominent notices
99
- stating that You changed the files; and
100
-
101
- (c) You must retain, in the Source form of any Derivative Works
102
- that You distribute, all copyright, patent, trademark, and
103
- attribution notices from the Source form of the Work,
104
- excluding those notices that do not pertain to any part of
105
- the Derivative Works; and
106
-
107
- (d) If the Work includes a "NOTICE" text file as part of its
108
- distribution, then any Derivative Works that You distribute must
109
- include a readable copy of the attribution notices contained
110
- within such NOTICE file, excluding those notices that do not
111
- pertain to any part of the Derivative Works, in at least one
112
- of the following places: within a NOTICE text file distributed
113
- as part of the Derivative Works; within the Source form or
114
- documentation, if provided along with the Derivative Works; or,
115
- within a display generated by the Derivative Works, if and
116
- wherever such third-party notices normally appear. The contents
117
- of the NOTICE file are for informational purposes only and
118
- do not modify the License. You may add Your own attribution
119
- notices within Derivative Works that You distribute, alongside
120
- or as an addendum to the NOTICE text from the Work, provided
121
- that such additional attribution notices cannot be construed
122
- as modifying the License.
123
-
124
- You may add Your own copyright statement to Your modifications and
125
- may provide additional or different license terms and conditions
126
- for use, reproduction, or distribution of Your modifications, or
127
- for any such Derivative Works as a whole, provided Your use,
128
- reproduction, and distribution of the Work otherwise complies with
129
- the conditions stated in this License.
130
-
131
- 5. Submission of Contributions. Unless You explicitly state otherwise,
132
- any Contribution intentionally submitted for inclusion in the Work
133
- by You to the Licensor shall be under the terms and conditions of
134
- this License, without any additional terms or conditions.
135
- Notwithstanding the above, nothing herein shall supersede or modify
136
- the terms of any separate license agreement you may have executed
137
- with Licensor regarding such Contributions.
138
-
139
- 6. Trademarks. This License does not grant permission to use the trade
140
- names, trademarks, service marks, or product names of the Licensor,
141
- except as required for reasonable and customary use in describing the
142
- origin of the Work and reproducing the content of the NOTICE file.
143
-
144
- 7. Disclaimer of Warranty. Unless required by applicable law or
145
- agreed to in writing, Licensor provides the Work (and each
146
- Contributor provides its Contributions) on an "AS IS" BASIS,
147
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
- implied, including, without limitation, any warranties or conditions
149
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
- PARTICULAR PURPOSE. You are solely responsible for determining the
151
- appropriateness of using or redistributing the Work and assume any
152
- risks associated with Your exercise of permissions under this License.
153
-
154
- 8. Limitation of Liability. In no event and under no legal theory,
155
- whether in tort (including negligence), contract, or otherwise,
156
- unless required by applicable law (such as deliberate and grossly
157
- negligent acts) or agreed to in writing, shall any Contributor be
158
- liable to You for damages, including any direct, indirect, special,
159
- incidental, or consequential damages of any character arising as a
160
- result of this License or out of the use or inability to use the
161
- Work (including but not limited to damages for loss of goodwill,
162
- work stoppage, computer failure or malfunction, or any and all
163
- other commercial damages or losses), even if such Contributor
164
- has been advised of the possibility of such damages.
165
-
166
- 9. Accepting Warranty or Additional Liability. While redistributing
167
- the Work or Derivative Works thereof, You may choose to offer,
168
- and charge a fee for, acceptance of support, warranty, indemnity,
169
- or other liability obligations and/or rights consistent with this
170
- License. However, in accepting such obligations, You may act only
171
- on Your own behalf and on Your sole responsibility, not on behalf
172
- of any other Contributor, and only if You agree to indemnify,
173
- defend, and hold each Contributor harmless for any liability
174
- incurred by, or claims asserted against, such Contributor by reason
175
- of your accepting any such warranty or additional liability.
 
mm-cot/NOTICE DELETED
@@ -1 +0,0 @@
1
- Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 
 
mm-cot/README.md DELETED
@@ -1,93 +0,0 @@
1
- # Multimodal Chain-of-Thought Reasoning in Language Models
2
-
3
- <h5 align="center"><i>"Imagine learning a textbook without figures or tables."</i></h5>
4
-
5
- Multimodal-CoT incorporates vision features in a decoupled training framework. The framework consists of two training stages: (i) rationale generation and (ii) answer inference. Both stages share the same model architecture but differ in the input and output.
6
-
7
- ![](vision_features/mm-cot.png)
8
-
9
-
10
- ## Requirements
11
-
12
- Install all required Python dependencies:
13
-
14
- ```
15
- pip install -r requirements.txt
16
- ```
17
-
18
- ## Datasets
19
-
20
- Download the dataset from the following repository:
21
-
22
- ```
23
- https://github.com/lupantech/ScienceQA/tree/main/data
24
- ```
25
-
26
- Download the extracted vision features from [vision_features](https://drive.google.com/file/d/13B0hc_F_45-UlqPLKSgRz-ALtFQ8kIJr/view?usp=share_link) and unzip the files under `vision_features`.
27
-
28
- ## Instructions
29
-
30
- ### Training
31
-
32
- ```
33
- # rationale generation
34
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
35
- --model allenai/unifiedqa-t5-base \
36
- --user_msg rationale --img_type detr \
37
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 512 \
38
- --final_eval --prompt_format QCM-LE
39
-
40
- # answer inference
41
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
42
- --model allenai/unifiedqa-t5-base \
43
- --user_msg answer --img_type detr \
44
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 64 \
45
- --final_eval --prompt_format QCMG-A \
46
- --eval_le experiments/rationale_allenai-unifiedqa-t5-base_detr_QCM-LE_lr5e-05_bs16_op512_ep20/predictions_ans_eval.json \
47
- --test_le experiments/rationale_allenai-unifiedqa-t5-base_detr_QCM-LE_lr5e-05_bs16_op512_ep20/predictions_ans_test.json
48
- ```
49
-
50
- ### Inference
51
-
52
- Our trained models are available at [models](https://drive.google.com/file/d/1FtTYOJPHnWnFfCxNC6M3gar4RAX5E21b/view?usp=share_link). To use our trained models, please put them under the `models` folder.
53
-
54
- ```
55
- # rationale generation
56
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
57
- --model allenai/unifiedqa-t5-base \
58
- --user_msg rationale --img_type detr \
59
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 512 \
60
- --final_eval --prompt_format QCM-LE \
61
- --evaluate_dir models/MM-CoT-UnifiedQA-base-Rationale
62
-
63
- # answer inference
64
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
65
- --model allenai/unifiedqa-t5-base \
66
- --user_msg answer --img_type detr \
67
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 64 \
68
- --final_eval --prompt_format QCMG-A \
69
- --eval_le models/rationale/predictions_ans_eval.json \
70
- --test_le models/rationale/predictions_ans_test.json \
71
- --evaluate_dir models/MM-CoT-UnifiedQA-base-Answer
72
- ```
73
-
74
- ## Citing MM-CoT
75
-
76
- ```
77
- @article{zhang2023multicot,
78
- title={Multimodal Chain-of-Thought Reasoning in Language Models},
79
- author={Zhang, Zhuosheng and Zhang, Aston and Li, Mu and Zhao, Hai and Karypis, George and Smola, Alex},
80
- journal={arXiv preprint arXiv:2302.00923},
81
- year={2023}
82
- }
83
- ```
84
-
85
- ## License
86
-
87
- This project is licensed under the Apache-2.0 License.
88
-
89
- ## Acknowledgement
90
-
91
- Part of our code is adapted from [ScienceQA](https://github.com/lupantech/ScienceQA) and [Transformers](https://github.com/huggingface/transformers).
92
-
93
- We thank Pan Lu for providing the parameter sizes of the ScienceQA baselines.
 
mm-cot/evaluations.py DELETED
@@ -1,100 +0,0 @@
1
- '''
2
- Adapted from https://github.com/lupantech/ScienceQA
3
- '''
4
-
5
- import re
6
- from rouge import Rouge
7
- from nltk.translate.bleu_score import sentence_bleu
8
- from sentence_transformers import util
9
-
10
- ########################
11
- ## BLEU
12
- ########################
13
- def tokenize(text):
14
- tokens = re.split(r'\s|\.', text)
15
- tokens = [t for t in tokens if len(t) > 0]
16
- return tokens
17
-
18
-
19
- def bleu_score(reference, hypothesis, gram):
20
- reference_tokens = tokenize(reference)
21
- hypothesis_tokens = tokenize(hypothesis)
22
-
23
- if gram == 1:
24
- bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1., )) # BLEU-1
25
- elif gram == 2:
26
- bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 2., 1. / 2.)) # BLEU-2
27
- elif gram == 3:
28
- bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 3., 1. / 3., 1. / 3.)) # BLEU-3
29
- elif gram == 4:
30
- bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 4., 1. / 4., 1. / 4., 1. / 4.)) # BLEU-4
31
-
32
- return bleu
33
-
34
-
35
- def caculate_bleu(results, data, gram):
36
- bleus = []
37
- for qid, output in results.items():
38
- prediction = output
39
- target = data[qid]
40
- target = target.strip()
41
- if target == "":
42
- continue
43
- bleu = bleu_score(target, prediction, gram)
44
- bleus.append(bleu)
45
-
46
- avg_bleu = sum(bleus) / len(bleus)
47
-
48
- return avg_bleu
49
-
50
-
51
- ########################
52
- ## Rouge-L
53
- ########################
54
- def score_rouge(str1, str2):
55
- rouge = Rouge(metrics=["rouge-l"])
56
- scores = rouge.get_scores(str1, str2, avg=True)
57
- rouge_l = scores['rouge-l']['f']
58
- return rouge_l
59
-
60
-
61
- def caculate_rouge(results, data):
62
- rouges = []
63
- for qid, output in results.items():
64
- prediction = output
65
- target = data[qid]
66
- target = target.strip()
67
- if prediction == "":
68
- continue
69
- if target == "":
70
- continue
71
- rouge = score_rouge(target, prediction)
72
- rouges.append(rouge)
73
-
74
- avg_rouge = sum(rouges) / len(rouges)
75
- return avg_rouge
76
-
77
-
78
- ########################
79
- ## Sentence Similarity
80
- ########################
81
- def similariry_score(str1, str2, model):
82
- # compute embedding for both lists
83
- embedding_1 = model.encode(str1, convert_to_tensor=True)
84
- embedding_2 = model.encode(str2, convert_to_tensor=True)
85
- score = util.pytorch_cos_sim(embedding_1, embedding_2).item()
86
- return score
87
-
88
-
89
- def caculate_similariry(results, data, model):
90
- scores = []
91
- for qid, output in results.items():
92
- prediction = output
93
- target = data[qid]
94
- target = target.strip()
95
-
96
- score = similariry_score(target, prediction, model)
97
- scores.append(score)
98
-
99
- avg_score = sum(scores) / len(scores)
100
- return avg_score
 
mm-cot/main.py DELETED
@@ -1,383 +0,0 @@
1
- import os
2
- import numpy as np
3
- import torch
4
- import os
5
- import re
6
- import json
7
- import argparse
8
- import random
9
- from transformers import T5Tokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, T5ForConditionalGeneration
10
- from model import T5ForConditionalGeneration, T5ForMultimodalGeneration
11
- from utils_data import img_shape, load_data_std, load_data_img, ScienceQADatasetStd, ScienceQADatasetImg
12
- from utils_prompt import *
13
- from utils_evaluate import get_scores
14
- from rich.table import Column, Table
15
- from rich import box
16
- from rich.console import Console
17
- console = Console(record=True)
18
- from torch import cuda
19
- import nltk
20
- import evaluate
21
-
22
-
23
- def parse_args():
24
- parser = argparse.ArgumentParser()
25
- parser.add_argument('--data_root', type=str, default='data')
26
- parser.add_argument('--output_dir', type=str, default='experiments')
27
- parser.add_argument('--model', type=str, default='allenai/unifiedqa-t5-base')
28
- parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
29
- parser.add_argument('--epoch', type=int, default=20)
30
- parser.add_argument('--lr', type=float, default=5e-5)
31
- parser.add_argument('--bs', type=int, default=16)
32
- parser.add_argument('--input_len', type=int, default=512)
33
- parser.add_argument('--output_len', type=int, default=64)
34
- parser.add_argument('--eval_bs', type=int, default=16)
35
- parser.add_argument('--eval_acc', type=int, default=None, help='evaluate accumulation step')
36
- parser.add_argument('--train_split', type=str, default='train', choices=['train', 'trainval', 'minitrain'])
37
- parser.add_argument('--val_split', type=str, default='val', choices=['test', 'val', 'minival'])
38
- parser.add_argument('--test_split', type=str, default='test', choices=['test', 'minitest'])
39
-
40
- parser.add_argument('--use_generate', action='store_true', help='only for baseline to improve inference speed')
41
- parser.add_argument('--final_eval', action='store_true', help='only evaluate the model at the final epoch')
42
- parser.add_argument('--user_msg', type=str, default="baseline", help='experiment type in the save_dir')
43
- parser.add_argument('--img_type', type=str, default=None, choices=['detr', 'clip', 'resnet'], help='type of image features')
44
- parser.add_argument('--eval_le', type=str, default=None, help='generated rationale for the dev set')
45
- parser.add_argument('--test_le', type=str, default=None, help='generated rationale for the test set')
46
- parser.add_argument('--evaluate_dir', type=str, default=None, help='the directory of model for evaluation')
47
- parser.add_argument('--caption_file', type=str, default='data/captions.json')
48
- parser.add_argument('--use_caption', action='store_true', help='use image captions or not')
49
- parser.add_argument('--prompt_format', type=str, default='QCM-A', help='prompt format template',
50
- choices=['QCM-A', 'QCM-LE', 'QCMG-A', 'QCM-LEA', 'QCM-ALE'])
51
- parser.add_argument('--seed', type=int, default=42, help='random seed')
52
-
53
- args = parser.parse_args()
54
- return args
55
-
56
- def T5Trainer(
57
- dataframe, args,
58
- ):
59
- torch.manual_seed(args.seed) # pytorch random seed
60
- np.random.seed(args.seed) # numpy random seed
61
- torch.backends.cudnn.deterministic = True
62
-
63
- if args.evaluate_dir is not None:
64
- args.model = args.evaluate_dir
65
-
66
- tokenizer = T5Tokenizer.from_pretrained(args.model)
67
-
68
- console.log(f"""[Model]: Loading {args.model}...\n""")
69
- console.log(f"[Data]: Reading data...\n")
70
- problems = dataframe['problems']
71
- qids = dataframe['qids']
72
- train_qids = qids['train']
73
- test_qids = qids['test']
74
- val_qids = qids['val']
75
-
76
- if args.evaluate_dir is not None:
77
- save_dir = args.evaluate_dir
78
- else:
79
- model_name = args.model.replace("/","-")
80
- gpu_count = torch.cuda.device_count()
81
- save_dir = f"{args.output_dir}/{args.user_msg}_{model_name}_{args.img_type}_{args.prompt_format}_lr{args.lr}_bs{args.bs * gpu_count}_op{args.output_len}_ep{args.epoch}"
82
- if not os.path.exists(save_dir):
83
- os.mkdir(save_dir)
84
-
85
- padding_idx = tokenizer._convert_token_to_id(tokenizer.pad_token)
86
- if args.img_type is not None:
87
- patch_size = img_shape[args.img_type]
88
- model = T5ForMultimodalGeneration.from_pretrained(args.model, patch_size=patch_size, padding_idx=padding_idx, save_dir=save_dir)
89
- name_maps = dataframe['name_maps']
90
- image_features = dataframe['image_features']
91
- train_set = ScienceQADatasetImg(
92
- problems,
93
- train_qids,
94
- name_maps,
95
- tokenizer,
96
- args.input_len,
97
- args.output_len,
98
- args,
99
- image_features,
100
- )
101
- eval_set = ScienceQADatasetImg(
102
- problems,
103
- val_qids,
104
- name_maps,
105
- tokenizer,
106
- args.input_len,
107
- args.output_len,
108
- args,
109
- image_features,
110
- args.eval_le,
111
- )
112
- test_set = ScienceQADatasetImg(
113
- problems,
114
- test_qids,
115
- name_maps,
116
- tokenizer,
117
- args.input_len,
118
- args.output_len,
119
- args,
120
- image_features,
121
- args.test_le,
122
- )
123
- else:
124
- model = T5ForConditionalGeneration.from_pretrained(args.model)
125
- train_set = ScienceQADatasetStd(
126
- problems,
127
- train_qids,
128
- tokenizer,
129
- args.input_len,
130
- args.output_len,
131
- args,
132
- )
133
- eval_set = ScienceQADatasetStd(
134
- problems,
135
- val_qids,
136
- tokenizer,
137
- args.input_len,
138
- args.output_len,
139
- args,
140
- args.eval_le,
141
- )
142
-
143
- test_set = ScienceQADatasetStd(
144
- problems,
145
- test_qids,
146
- tokenizer,
147
- args.input_len,
148
- args.output_len,
149
- args,
150
- args.test_le,
151
- )
152
-
153
- datacollator = DataCollatorForSeq2Seq(tokenizer)
154
- print("model parameters: ", model.num_parameters())
155
- def extract_ans(ans):
156
- pattern = re.compile(r'The answer is \(([A-Z])\)')
157
- res = pattern.findall(ans)
158
-
159
- if len(res) == 1:
160
- answer = res[0] # 'A', 'B', ...
161
- else:
162
- answer = "FAILED"
163
- return answer
164
-
165
- # accuracy for answer inference
166
- def compute_metrics_acc(eval_preds):
167
- if args.use_generate:
168
- preds, targets = eval_preds
169
- if isinstance(preds, tuple):
170
- preds = preds[0]
171
- else:
172
- preds = eval_preds.predictions[0]
173
- targets = eval_preds.label_ids
174
- preds = preds.argmax(axis=2)
175
- preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
176
- targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)
177
- correct = 0
178
- assert len(preds) == len(targets)
179
- for idx, pred in enumerate(preds):
180
- reference = targets[idx]
181
- reference = extract_ans(reference)
182
- extract_pred = extract_ans(pred)
183
- best_option = extract_pred
184
- if reference == best_option:
185
- correct +=1
186
- return {'accuracy': 1.0*correct/len(targets)}
187
-
188
- # rougel for rationale generation
189
- metric = evaluate.load("rouge")
190
- def postprocess_text(preds, labels):
191
- preds = [pred.strip() for pred in preds]
192
- labels = [label.strip() for label in labels]
193
- preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
194
- labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
195
- return preds, labels
196
-
197
- def compute_metrics_rougel(eval_preds):
198
- if args.use_generate:
199
- preds, targets = eval_preds
200
- if isinstance(preds, tuple):
201
- preds = preds[0]
202
- else:
203
- preds = eval_preds.predictions[0]
204
- targets = eval_preds.label_ids
205
- preds = preds.argmax(axis=2)
206
- preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
207
- targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)
208
-
209
- decoded_preds, decoded_labels = postprocess_text(preds, targets)
210
-
211
- result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
212
- result = {k: round(v * 100, 4) for k, v in result.items()}
213
- prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
214
- result["gen_len"] = np.mean(prediction_lens)
215
- return result
216
-
217
- # only use the last model for evaluation to save time
218
- if args.final_eval:
219
- training_args = Seq2SeqTrainingArguments(
220
- save_dir,
221
- do_train=True if args.evaluate_dir is None else False,
222
- do_eval=False,
223
- evaluation_strategy="no",
224
- logging_strategy="steps",
225
- save_strategy="epoch",
226
- save_total_limit = 2,
227
- learning_rate= args.lr,
228
- eval_accumulation_steps=args.eval_acc,
229
- per_device_train_batch_size=args.bs,
230
- per_device_eval_batch_size=args.eval_bs,
231
- weight_decay=0.01,
232
- num_train_epochs=args.epoch,
233
- predict_with_generate=args.use_generate,
234
- report_to="none",
235
- )
236
- # evaluate at each epoch
237
- else:
238
- training_args = Seq2SeqTrainingArguments(
239
- save_dir,
240
- do_train=True if args.evaluate_dir is None else False,
241
- do_eval=True,
242
- evaluation_strategy="epoch",
243
- logging_strategy="steps",
244
- save_strategy="epoch",
245
- save_total_limit = 2,
246
- learning_rate= args.lr,
247
- eval_accumulation_steps=args.eval_acc,
248
- per_device_train_batch_size=args.bs,
249
- per_device_eval_batch_size=args.eval_bs,
250
- weight_decay=0.01,
251
- num_train_epochs=args.epoch,
252
- metric_for_best_model="accuracy" if args.prompt_format != "QCM-LE" else "rougeL",
253
- predict_with_generate=args.use_generate,
254
- load_best_model_at_end=True,
255
- report_to="none",
256
- )
257
-
258
- trainer = Seq2SeqTrainer(
259
- model=model,
260
- args=training_args,
261
- train_dataset=train_set,
262
- eval_dataset=eval_set,
263
- data_collator=datacollator,
264
- tokenizer=tokenizer,
265
- compute_metrics = compute_metrics_acc if args.prompt_format != "QCM-LE" else compute_metrics_rougel
266
- )
267
-
268
- if args.evaluate_dir is None:
269
- trainer.train()
270
- trainer.save_model(save_dir)
271
-
272
- metrics = trainer.evaluate(eval_dataset = test_set)
273
- trainer.log_metrics("test", metrics)
274
- trainer.save_metrics("test", metrics)
275
-
276
- predict_results = trainer.predict(test_dataset=test_set, max_length=args.output_len)
277
- if trainer.is_world_process_zero():
278
- if args.use_generate:
279
- preds, targets = predict_results.predictions, predict_results.label_ids
280
- else:
281
- preds = predict_results.predictions[0]
282
- targets = predict_results.label_ids
283
- preds = preds.argmax(axis=2)
284
-
285
- preds = tokenizer.batch_decode(
286
- preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
287
- )
288
- targets = tokenizer.batch_decode(
289
- targets, skip_special_tokens=True, clean_up_tokenization_spaces=True
290
- )
291
-
292
- results_ans = {}
293
- results_rationale = {}
294
- results_reference = {}
295
-
296
- num_fail = 0
297
- for idx, qid in enumerate(test_qids):
298
- pred = preds[int(idx)]
299
- ref = targets[int(idx)]
300
- extract_pred = extract_ans(pred)
301
- if extract_pred != "FAILED":
302
- if extract_pred in args.options:
303
- extract_pred = args.options.index(extract_pred)
304
- else:
305
- extract_pred = random.choice(range(0,len(args.options)))
306
- else:
307
- num_fail += 1
308
- extract_pred = random.choice(range(len(args.options))) # random choose one option
309
- results_ans[str(qid)] = extract_pred
310
- results_rationale[str(qid)] = pred
311
- results_reference[str(qid)] = ref
312
-
313
- scores = get_scores(results_ans, results_rationale, results_reference, os.path.join(args.data_root, "scienceqa/problems.json"))
314
- preds = [pred.strip() for pred in preds]
315
- output_data = {
316
- "num_fail": num_fail,
317
- "scores": scores,
318
- "preds": preds,
319
- "labels": targets}
320
- output_prediction_file = os.path.join(save_dir,"predictions_ans_test.json")
321
- with open(output_prediction_file, "w") as writer:
322
- writer.write(json.dumps(output_data, indent=4))
323
-
324
- # generate the rationale for the eval set
325
- if args.prompt_format == "QCM-LE":
326
- torch.cuda.empty_cache()
327
- del predict_results, preds, targets
328
- predict_results = trainer.predict(test_dataset=eval_set, max_length=args.output_len)
329
- if trainer.is_world_process_zero():
330
- if args.use_generate:
331
- preds, targets = predict_results.predictions, predict_results.label_ids
332
- else:
333
- preds = predict_results.predictions[0]
334
- targets = predict_results.label_ids
335
- preds = preds.argmax(axis=2)
336
-
337
- preds = tokenizer.batch_decode(
338
- preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
339
- )
340
- targets = tokenizer.batch_decode(
341
- targets, skip_special_tokens=True, clean_up_tokenization_spaces=True
342
- )
343
- preds = [pred.strip() for pred in preds]
344
- output_data = {"preds": preds,
345
- "labels": targets}
346
- output_prediction_file = os.path.join(save_dir,"predictions_ans_eval.json")
347
- with open(output_prediction_file, "w") as writer:
348
- writer.write(json.dumps(output_data, indent=4))
349
-
350
-
351
- if __name__ == '__main__':
352
-
353
- # training logger to log training progress
354
- training_logger = Table(
355
- Column("Epoch", justify="center"),
356
- Column("Steps", justify="center"),
357
- Column("Loss", justify="center"),
358
- title="Training Status",
359
- pad_edge=False,
360
- box=box.ASCII,
361
- )
362
-
363
- args = parse_args()
364
- print("args",args)
365
- print('====Input Arguments====')
366
- print(json.dumps(vars(args), indent=2, sort_keys=False))
367
-
368
- random.seed(args.seed)
369
-
370
- if not os.path.exists(args.output_dir):
371
- os.mkdir(args.output_dir)
372
-
373
- if args.img_type is not None:
374
- problems, qids, name_maps, image_features = load_data_img(args) # problems, test question ids, shot example ids
375
- dataframe = {'problems':problems, 'qids':qids, 'name_maps': name_maps, 'image_features': image_features}
376
- else:
377
- problems, qids = load_data_std(args) # problems, test question ids, shot example ids
378
- dataframe = {'problems':problems, 'qids':qids}
379
-
380
- T5Trainer(
381
- dataframe=dataframe,
382
- args = args
383
- )
 
mm-cot/mm-cot/CODE_OF_CONDUCT.md DELETED
@@ -1,4 +0,0 @@
1
- ## Code of Conduct
2
- This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3
- For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4
- [email protected] with any additional questions or comments.
 
mm-cot/mm-cot/CONTRIBUTING.md DELETED
@@ -1,59 +0,0 @@
1
- # Contributing Guidelines
2
-
3
- Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4
- documentation, we greatly value feedback and contributions from our community.
5
-
6
- Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7
- information to effectively respond to your bug report or contribution.
8
-
9
-
10
- ## Reporting Bugs/Feature Requests
11
-
12
- We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13
-
14
- When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15
- reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16
-
17
- * A reproducible test case or series of steps
18
- * The version of our code being used
19
- * Any modifications you've made relevant to the bug
20
- * Anything unusual about your environment or deployment
21
-
22
-
23
- ## Contributing via Pull Requests
24
- Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25
-
26
- 1. You are working against the latest source on the *main* branch.
27
- 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28
- 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29
-
30
- To send us a pull request, please:
31
-
32
- 1. Fork the repository.
33
- 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34
- 3. Ensure local tests pass.
35
- 4. Commit to your fork using clear commit messages.
36
- 5. Send us a pull request, answering any default questions in the pull request interface.
37
- 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38
-
39
- GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40
- [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41
-
42
-
43
- ## Finding contributions to work on
44
- Looking at the existing issues is a great way to find something to contribute to. As our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45
-
46
-
47
- ## Code of Conduct
48
- This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49
- For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50
- [email protected] with any additional questions or comments.
51
-
52
-
53
- ## Security issue notifications
54
- If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55
-
56
-
57
- ## Licensing
58
-
59
- See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
 
mm-cot/mm-cot/LICENSE DELETED
@@ -1,175 +0,0 @@
1
-
2
- Apache License
3
- Version 2.0, January 2004
4
- http://www.apache.org/licenses/
5
-
6
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
-
8
- 1. Definitions.
9
-
10
- "License" shall mean the terms and conditions for use, reproduction,
11
- and distribution as defined by Sections 1 through 9 of this document.
12
-
13
- "Licensor" shall mean the copyright owner or entity authorized by
14
- the copyright owner that is granting the License.
15
-
16
- "Legal Entity" shall mean the union of the acting entity and all
17
- other entities that control, are controlled by, or are under common
18
- control with that entity. For the purposes of this definition,
19
- "control" means (i) the power, direct or indirect, to cause the
20
- direction or management of such entity, whether by contract or
21
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
- outstanding shares, or (iii) beneficial ownership of such entity.
23
-
24
- "You" (or "Your") shall mean an individual or Legal Entity
25
- exercising permissions granted by this License.
26
-
27
- "Source" form shall mean the preferred form for making modifications,
28
- including but not limited to software source code, documentation
29
- source, and configuration files.
30
-
31
- "Object" form shall mean any form resulting from mechanical
32
- transformation or translation of a Source form, including but
33
- not limited to compiled object code, generated documentation,
34
- and conversions to other media types.
35
-
36
- "Work" shall mean the work of authorship, whether in Source or
37
- Object form, made available under the License, as indicated by a
38
- copyright notice that is included in or attached to the work
39
- (an example is provided in the Appendix below).
40
-
41
- "Derivative Works" shall mean any work, whether in Source or Object
42
- form, that is based on (or derived from) the Work and for which the
43
- editorial revisions, annotations, elaborations, or other modifications
44
- represent, as a whole, an original work of authorship. For the purposes
45
- of this License, Derivative Works shall not include works that remain
46
- separable from, or merely link (or bind by name) to the interfaces of,
47
- the Work and Derivative Works thereof.
48
-
49
- "Contribution" shall mean any work of authorship, including
50
- the original version of the Work and any modifications or additions
51
- to that Work or Derivative Works thereof, that is intentionally
52
- submitted to Licensor for inclusion in the Work by the copyright owner
53
- or by an individual or Legal Entity authorized to submit on behalf of
54
- the copyright owner. For the purposes of this definition, "submitted"
55
- means any form of electronic, verbal, or written communication sent
56
- to the Licensor or its representatives, including but not limited to
57
- communication on electronic mailing lists, source code control systems,
58
- and issue tracking systems that are managed by, or on behalf of, the
59
- Licensor for the purpose of discussing and improving the Work, but
60
- excluding communication that is conspicuously marked or otherwise
61
- designated in writing by the copyright owner as "Not a Contribution."
62
-
63
- "Contributor" shall mean Licensor and any individual or Legal Entity
64
- on behalf of whom a Contribution has been received by Licensor and
65
- subsequently incorporated within the Work.
66
-
67
- 2. Grant of Copyright License. Subject to the terms and conditions of
68
- this License, each Contributor hereby grants to You a perpetual,
69
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
- copyright license to reproduce, prepare Derivative Works of,
71
- publicly display, publicly perform, sublicense, and distribute the
72
- Work and such Derivative Works in Source or Object form.
73
-
74
- 3. Grant of Patent License. Subject to the terms and conditions of
75
- this License, each Contributor hereby grants to You a perpetual,
76
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
- (except as stated in this section) patent license to make, have made,
78
- use, offer to sell, sell, import, and otherwise transfer the Work,
79
- where such license applies only to those patent claims licensable
80
- by such Contributor that are necessarily infringed by their
81
- Contribution(s) alone or by combination of their Contribution(s)
82
- with the Work to which such Contribution(s) was submitted. If You
83
- institute patent litigation against any entity (including a
84
- cross-claim or counterclaim in a lawsuit) alleging that the Work
85
- or a Contribution incorporated within the Work constitutes direct
86
- or contributory patent infringement, then any patent licenses
87
- granted to You under this License for that Work shall terminate
88
- as of the date such litigation is filed.
89
-
90
- 4. Redistribution. You may reproduce and distribute copies of the
91
- Work or Derivative Works thereof in any medium, with or without
92
- modifications, and in Source or Object form, provided that You
93
- meet the following conditions:
94
-
95
- (a) You must give any other recipients of the Work or
96
- Derivative Works a copy of this License; and
97
-
98
- (b) You must cause any modified files to carry prominent notices
99
- stating that You changed the files; and
100
-
101
- (c) You must retain, in the Source form of any Derivative Works
102
- that You distribute, all copyright, patent, trademark, and
103
- attribution notices from the Source form of the Work,
104
- excluding those notices that do not pertain to any part of
105
- the Derivative Works; and
106
-
107
- (d) If the Work includes a "NOTICE" text file as part of its
108
- distribution, then any Derivative Works that You distribute must
109
- include a readable copy of the attribution notices contained
110
- within such NOTICE file, excluding those notices that do not
111
- pertain to any part of the Derivative Works, in at least one
112
- of the following places: within a NOTICE text file distributed
113
- as part of the Derivative Works; within the Source form or
114
- documentation, if provided along with the Derivative Works; or,
115
- within a display generated by the Derivative Works, if and
116
- wherever such third-party notices normally appear. The contents
117
- of the NOTICE file are for informational purposes only and
118
- do not modify the License. You may add Your own attribution
119
- notices within Derivative Works that You distribute, alongside
120
- or as an addendum to the NOTICE text from the Work, provided
121
- that such additional attribution notices cannot be construed
122
- as modifying the License.
123
-
124
- You may add Your own copyright statement to Your modifications and
125
- may provide additional or different license terms and conditions
126
- for use, reproduction, or distribution of Your modifications, or
127
- for any such Derivative Works as a whole, provided Your use,
128
- reproduction, and distribution of the Work otherwise complies with
129
- the conditions stated in this License.
130
-
131
- 5. Submission of Contributions. Unless You explicitly state otherwise,
132
- any Contribution intentionally submitted for inclusion in the Work
133
- by You to the Licensor shall be under the terms and conditions of
134
- this License, without any additional terms or conditions.
135
- Notwithstanding the above, nothing herein shall supersede or modify
136
- the terms of any separate license agreement you may have executed
137
- with Licensor regarding such Contributions.
138
-
139
- 6. Trademarks. This License does not grant permission to use the trade
140
- names, trademarks, service marks, or product names of the Licensor,
141
- except as required for reasonable and customary use in describing the
142
- origin of the Work and reproducing the content of the NOTICE file.
143
-
144
- 7. Disclaimer of Warranty. Unless required by applicable law or
145
- agreed to in writing, Licensor provides the Work (and each
146
- Contributor provides its Contributions) on an "AS IS" BASIS,
147
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
- implied, including, without limitation, any warranties or conditions
149
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
- PARTICULAR PURPOSE. You are solely responsible for determining the
151
- appropriateness of using or redistributing the Work and assume any
152
- risks associated with Your exercise of permissions under this License.
153
-
154
- 8. Limitation of Liability. In no event and under no legal theory,
155
- whether in tort (including negligence), contract, or otherwise,
156
- unless required by applicable law (such as deliberate and grossly
157
- negligent acts) or agreed to in writing, shall any Contributor be
158
- liable to You for damages, including any direct, indirect, special,
159
- incidental, or consequential damages of any character arising as a
160
- result of this License or out of the use or inability to use the
161
- Work (including but not limited to damages for loss of goodwill,
162
- work stoppage, computer failure or malfunction, or any and all
163
- other commercial damages or losses), even if such Contributor
164
- has been advised of the possibility of such damages.
165
-
166
- 9. Accepting Warranty or Additional Liability. While redistributing
167
- the Work or Derivative Works thereof, You may choose to offer,
168
- and charge a fee for, acceptance of support, warranty, indemnity,
169
- or other liability obligations and/or rights consistent with this
170
- License. However, in accepting such obligations, You may act only
171
- on Your own behalf and on Your sole responsibility, not on behalf
172
- of any other Contributor, and only if You agree to indemnify,
173
- defend, and hold each Contributor harmless for any liability
174
- incurred by, or claims asserted against, such Contributor by reason
175
- of your accepting any such warranty or additional liability.
 
mm-cot/mm-cot/NOTICE DELETED
@@ -1 +0,0 @@
1
- Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 
 
mm-cot/mm-cot/README.md DELETED
@@ -1,93 +0,0 @@
1
- # Multimodal Chain-of-Thought Reasoning in Language Models
2
-
3
- <h5 align="center"><i>"Imagine learning a textbook without figures or tables."</i></h5>
4
-
5
- Multimodal-CoT incorporates vision features in a decoupled training framework. The framework consists of two training stages: (i) rationale generation and (ii) answer inference. Both stages share the same model architecture but differ in the input and output.
6
-
7
- ![](vision_features/mm-cot.png)
8
-
9
-
10
- ## Requirements
11
-
12
- Install all required Python dependencies:
13
-
14
- ```
15
- pip install -r requirements.txt
16
- ```
17
-
18
- ## Datasets
19
-
20
- Download the dataset from the following repository:
21
-
22
- ```
23
- https://github.com/lupantech/ScienceQA/tree/main/data
24
- ```
25
-
26
- Download the extracted vision features from [vision_features](https://drive.google.com/file/d/13B0hc_F_45-UlqPLKSgRz-ALtFQ8kIJr/view?usp=share_link) and unzip the files under `vision_features`.
27
-
28
- ## Instructions
29
-
30
- ### Training
31
-
32
- ```
33
- # rationale generation
34
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
35
- --model allenai/unifiedqa-t5-base \
36
- --user_msg rationale --img_type detr \
37
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 512 \
38
- --final_eval --prompt_format QCM-LE
39
-
40
- # answer inference
41
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
42
- --model allenai/unifiedqa-t5-base \
43
- --user_msg answer --img_type detr \
44
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 64 \
45
- --final_eval --prompt_format QCMG-A \
46
- --eval_le experiments/rationale_allenai-unifiedqa-t5-base_detr_QCM-LE_lr5e-05_bs16_op512_ep20/predictions_ans_eval.json \
47
- --test_le experiments/rationale_allenai-unifiedqa-t5-base_detr_QCM-LE_lr5e-05_bs16_op512_ep20/predictions_ans_test.json
48
- ```
49
-
50
- ### Inference
51
-
52
- Our trained models are available at [models](https://drive.google.com/file/d/1FtTYOJPHnWnFfCxNC6M3gar4RAX5E21b/view?usp=share_link). To use our trained models, please put them under the `models` folder.
53
-
54
- ```
55
- # rationale generation
56
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
57
- --model allenai/unifiedqa-t5-base \
58
- --user_msg rationale --img_type detr \
59
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 512 \
60
- --final_eval --prompt_format QCM-LE \
61
- --evaluate_dir models/MM-CoT-UnifiedQA-base-Rationale
62
-
63
- # answer inference
64
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
65
- --model allenai/unifiedqa-t5-base \
66
- --user_msg answer --img_type detr \
67
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 64 \
68
- --final_eval --prompt_format QCMG-A \
69
- --eval_le models/rationale/predictions_ans_eval.json \
70
- --test_le models/rationale/predictions_ans_test.json \
71
- --evaluate_dir models/MM-CoT-UnifiedQA-base-Answer
72
- ```
73
-
74
- ## Citing MM-CoT
75
-
76
- ```
77
- @article{zhang2023multicot,
78
- title={Multimodal Chain-of-Thought Reasoning in Language Models},
79
- author={Zhang, Zhuosheng and Zhang, Aston and Li, Mu and Zhao, Hai and Karypis, George and Smola, Alex},
80
- journal={arXiv preprint arXiv:2302.00923},
81
- year={2023}
82
- }
83
- ```
84
-
85
- ## License
86
-
87
- This project is licensed under the Apache-2.0 License.
88
-
89
- ## Acknowledgement
90
-
91
- Part of our code is adapted from [ScienceQA](https://github.com/lupantech/ScienceQA) and [Transformers](https://github.com/huggingface/transformers).
92
-
93
- We thank Pan Lu for providing the parameter sizes of the ScienceQA baselines.
 
mm-cot/mm-cot/evaluations.py DELETED
@@ -1,100 +0,0 @@
1
- '''
2
- Adapted from https://github.com/lupantech/ScienceQA
3
- '''
4
-
5
- import re
6
- from rouge import Rouge
7
- from nltk.translate.bleu_score import sentence_bleu
8
- from sentence_transformers import util
9
-
10
- ########################
11
- ## BLEU
12
- ########################
13
- def tokenize(text):
14
- tokens = re.split(r'\s|\.', text)
15
- tokens = [t for t in tokens if len(t) > 0]
16
- return tokens
17
-
18
-
19
- def bleu_score(reference, hypothesis, gram):
20
- reference_tokens = tokenize(reference)
21
- hypothesis_tokens = tokenize(hypothesis)
22
-
23
- if gram == 1:
24
- bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1., )) # BLEU-1
25
- elif gram == 2:
26
- bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 2., 1. / 2.)) # BLEU-2
27
- elif gram == 3:
28
- bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 3., 1. / 3., 1. / 3.)) # BLEU-3
29
- elif gram == 4:
30
- bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 4., 1. / 4., 1. / 4., 1. / 4.)) # BLEU-4
31
-
32
- return bleu
33
-
34
-
35
- def caculate_bleu(results, data, gram):
36
- bleus = []
37
- for qid, output in results.items():
38
- prediction = output
39
- target = data[qid]
40
- target = target.strip()
41
- if target == "":
42
- continue
43
- bleu = bleu_score(target, prediction, gram)
44
- bleus.append(bleu)
45
-
46
- avg_bleu = sum(bleus) / len(bleus)
47
-
48
- return avg_bleu
49
-
50
-
51
- ########################
52
- ## Rouge-L
53
- ########################
54
- def score_rouge(str1, str2):
55
- rouge = Rouge(metrics=["rouge-l"])
56
- scores = rouge.get_scores(str1, str2, avg=True)
57
- rouge_l = scores['rouge-l']['f']
58
- return rouge_l
59
-
60
-
61
- def caculate_rouge(results, data):
62
- rouges = []
63
- for qid, output in results.items():
64
- prediction = output
65
- target = data[qid]
66
- target = target.strip()
67
- if prediction == "":
68
- continue
69
- if target == "":
70
- continue
71
- rouge = score_rouge(target, prediction)
72
- rouges.append(rouge)
73
-
74
- avg_rouge = sum(rouges) / len(rouges)
75
- return avg_rouge
76
-
77
-
78
- ########################
79
- ## Sentence Similarity
80
- ########################
81
- def similariry_score(str1, str2, model):
82
- # compute embedding for both lists
83
- embedding_1 = model.encode(str1, convert_to_tensor=True)
84
- embedding_2 = model.encode(str2, convert_to_tensor=True)
85
- score = util.pytorch_cos_sim(embedding_1, embedding_2).item()
86
- return score
87
-
88
-
89
- def caculate_similariry(results, data, model):
90
- scores = []
91
- for qid, output in results.items():
92
- prediction = output
93
- target = data[qid]
94
- target = target.strip()
95
-
96
- score = similariry_score(target, prediction, model)
97
- scores.append(score)
98
-
99
- avg_score = sum(scores) / len(scores)
100
- return avg_score
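
The helpers above all follow the same pattern: they walk a dict of predictions keyed by question id, look up the matching reference, and average a per-example score. A small usage sketch, assuming `evaluations.py` is importable and using invented example strings:

```python
# Usage sketch for the metric helpers above (example strings are invented).
from sentence_transformers import SentenceTransformer
from evaluations import caculate_bleu, caculate_rouge, caculate_similariry

preds = {"1": "The answer is (A).", "2": "Plants make their own food by photosynthesis."}
refs  = {"1": "The answer is (A).", "2": "Photosynthesis lets plants make their own food."}

bleu1  = caculate_bleu(preds, refs, gram=1)        # averaged unigram BLEU
rougel = caculate_rouge(preds, refs)               # averaged ROUGE-L F1
model  = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
cosine = caculate_similariry(preds, refs, model)   # mean embedding cosine similarity
print(bleu1, rougel, cosine)
```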
 
mm-cot/mm-cot/main.py DELETED
@@ -1,383 +0,0 @@
1
- import os
2
- import numpy as np
3
- import torch
4
- import os
5
- import re
6
- import json
7
- import argparse
8
- import random
9
- from transformers import T5Tokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, T5ForConditionalGeneration
10
- from model import T5ForConditionalGeneration, T5ForMultimodalGeneration
11
- from utils_data import img_shape, load_data_std, load_data_img, ScienceQADatasetStd, ScienceQADatasetImg
12
- from utils_prompt import *
13
- from utils_evaluate import get_scores
14
- from rich.table import Column, Table
15
- from rich import box
16
- from rich.console import Console
17
- console = Console(record=True)
18
- from torch import cuda
19
- import nltk
20
- import evaluate
21
-
22
-
23
- def parse_args():
24
- parser = argparse.ArgumentParser()
25
- parser.add_argument('--data_root', type=str, default='data')
26
- parser.add_argument('--output_dir', type=str, default='experiments')
27
- parser.add_argument('--model', type=str, default='allenai/unifiedqa-t5-base')
28
- parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
29
- parser.add_argument('--epoch', type=int, default=20)
30
- parser.add_argument('--lr', type=float, default=5e-5)
31
- parser.add_argument('--bs', type=int, default=16)
32
- parser.add_argument('--input_len', type=int, default=512)
33
- parser.add_argument('--output_len', type=int, default=64)
34
- parser.add_argument('--eval_bs', type=int, default=16)
35
- parser.add_argument('--eval_acc', type=int, default=None, help='evaluate accumulation step')
36
- parser.add_argument('--train_split', type=str, default='train', choices=['train', 'trainval', 'minitrain'])
37
- parser.add_argument('--val_split', type=str, default='val', choices=['test', 'val', 'minival'])
38
- parser.add_argument('--test_split', type=str, default='test', choices=['test', 'minitest'])
39
-
40
- parser.add_argument('--use_generate', action='store_true', help='only for baseline to improve inference speed')
41
- parser.add_argument('--final_eval', action='store_true', help='only evaluate the model at the final epoch')
42
- parser.add_argument('--user_msg', type=str, default="baseline", help='experiment type in the save_dir')
43
- parser.add_argument('--img_type', type=str, default=None, choices=['detr', 'clip', 'resnet'], help='type of image features')
44
- parser.add_argument('--eval_le', type=str, default=None, help='generated rationale for the dev set')
45
- parser.add_argument('--test_le', type=str, default=None, help='generated rationale for the test set')
46
- parser.add_argument('--evaluate_dir', type=str, default=None, help='the directory of model for evaluation')
47
- parser.add_argument('--caption_file', type=str, default='data/captions.json')
48
- parser.add_argument('--use_caption', action='store_true', help='use image captions or not')
49
- parser.add_argument('--prompt_format', type=str, default='QCM-A', help='prompt format template',
50
- choices=['QCM-A', 'QCM-LE', 'QCMG-A', 'QCM-LEA', 'QCM-ALE'])
51
- parser.add_argument('--seed', type=int, default=42, help='random seed')
52
-
53
- args = parser.parse_args()
54
- return args
55
-
56
- def T5Trainer(
57
- dataframe, args,
58
- ):
59
- torch.manual_seed(args.seed) # pytorch random seed
60
- np.random.seed(args.seed) # numpy random seed
61
- torch.backends.cudnn.deterministic = True
62
-
63
- if args.evaluate_dir is not None:
64
- args.model = args.evaluate_dir
65
-
66
- tokenizer = T5Tokenizer.from_pretrained(args.model)
67
-
68
- console.log(f"""[Model]: Loading {args.model}...\n""")
69
- console.log(f"[Data]: Reading data...\n")
70
- problems = dataframe['problems']
71
- qids = dataframe['qids']
72
- train_qids = qids['train']
73
- test_qids = qids['test']
74
- val_qids = qids['val']
75
-
76
- if args.evaluate_dir is not None:
77
- save_dir = args.evaluate_dir
78
- else:
79
- model_name = args.model.replace("/","-")
80
- gpu_count = torch.cuda.device_count()
81
- save_dir = f"{args.output_dir}/{args.user_msg}_{model_name}_{args.img_type}_{args.prompt_format}_lr{args.lr}_bs{args.bs * gpu_count}_op{args.output_len}_ep{args.epoch}"
82
- if not os.path.exists(save_dir):
83
- os.mkdir(save_dir)
84
-
85
- padding_idx = tokenizer._convert_token_to_id(tokenizer.pad_token)
86
- if args.img_type is not None:
87
- patch_size = img_shape[args.img_type]
88
- model = T5ForMultimodalGeneration.from_pretrained(args.model, patch_size=patch_size, padding_idx=padding_idx, save_dir=save_dir)
89
- name_maps = dataframe['name_maps']
90
- image_features = dataframe['image_features']
91
- train_set = ScienceQADatasetImg(
92
- problems,
93
- train_qids,
94
- name_maps,
95
- tokenizer,
96
- args.input_len,
97
- args.output_len,
98
- args,
99
- image_features,
100
- )
101
- eval_set = ScienceQADatasetImg(
102
- problems,
103
- val_qids,
104
- name_maps,
105
- tokenizer,
106
- args.input_len,
107
- args.output_len,
108
- args,
109
- image_features,
110
- args.eval_le,
111
- )
112
- test_set = ScienceQADatasetImg(
113
- problems,
114
- test_qids,
115
- name_maps,
116
- tokenizer,
117
- args.input_len,
118
- args.output_len,
119
- args,
120
- image_features,
121
- args.test_le,
122
- )
123
- else:
124
- model = T5ForConditionalGeneration.from_pretrained(args.model)
125
- train_set = ScienceQADatasetStd(
126
- problems,
127
- train_qids,
128
- tokenizer,
129
- args.input_len,
130
- args.output_len,
131
- args,
132
- )
133
- eval_set = ScienceQADatasetStd(
134
- problems,
135
- val_qids,
136
- tokenizer,
137
- args.input_len,
138
- args.output_len,
139
- args,
140
- args.eval_le,
141
- )
142
-
143
- test_set = ScienceQADatasetStd(
144
- problems,
145
- test_qids,
146
- tokenizer,
147
- args.input_len,
148
- args.output_len,
149
- args,
150
- args.test_le,
151
- )
152
-
153
- datacollator = DataCollatorForSeq2Seq(tokenizer)
154
- print("model parameters: ", model.num_parameters())
155
- def extract_ans(ans):
156
- pattern = re.compile(r'The answer is \(([A-Z])\)')
157
- res = pattern.findall(ans)
158
-
159
- if len(res) == 1:
160
- answer = res[0] # 'A', 'B', ...
161
- else:
162
- answer = "FAILED"
163
- return answer
164
-
165
- # accuracy for answer inference
166
- def compute_metrics_acc(eval_preds):
167
- if args.use_generate:
168
- preds, targets = eval_preds
169
- if isinstance(preds, tuple):
170
- preds = preds[0]
171
- else:
172
- preds = eval_preds.predictions[0]
173
- targets = eval_preds.label_ids
174
- preds = preds.argmax(axis=2)
175
- preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
176
- targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)
177
- correct = 0
178
- assert len(preds) == len(targets)
179
- for idx, pred in enumerate(preds):
180
- reference = targets[idx]
181
- reference = extract_ans(reference)
182
- extract_pred = extract_ans(pred)
183
- best_option = extract_pred
184
- if reference == best_option:
185
- correct +=1
186
- return {'accuracy': 1.0*correct/len(targets)}
187
-
188
- # rougel for rationale generation
189
- metric = evaluate.load("rouge")
190
- def postprocess_text(preds, labels):
191
- preds = [pred.strip() for pred in preds]
192
- labels = [label.strip() for label in labels]
193
- preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
194
- labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
195
- return preds, labels
196
-
197
- def compute_metrics_rougel(eval_preds):
198
- if args.use_generate:
199
- preds, targets = eval_preds
200
- if isinstance(preds, tuple):
201
- preds = preds[0]
202
- else:
203
- preds = eval_preds.predictions[0]
204
- targets = eval_preds.label_ids
205
- preds = preds.argmax(axis=2)
206
- preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
207
- targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)
208
-
209
- decoded_preds, decoded_labels = postprocess_text(preds, targets)
210
-
211
- result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
212
- result = {k: round(v * 100, 4) for k, v in result.items()}
213
- prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
214
- result["gen_len"] = np.mean(prediction_lens)
215
- return result
216
-
217
- # only use the last model for evaluation to save time
218
- if args.final_eval:
219
- training_args = Seq2SeqTrainingArguments(
220
- save_dir,
221
- do_train=True if args.evaluate_dir is None else False,
222
- do_eval=False,
223
- evaluation_strategy="no",
224
- logging_strategy="steps",
225
- save_strategy="epoch",
226
- save_total_limit = 2,
227
- learning_rate= args.lr,
228
- eval_accumulation_steps=args.eval_acc,
229
- per_device_train_batch_size=args.bs,
230
- per_device_eval_batch_size=args.eval_bs,
231
- weight_decay=0.01,
232
- num_train_epochs=args.epoch,
233
- predict_with_generate=args.use_generate,
234
- report_to="none",
235
- )
236
- # evaluate at each epoch
237
- else:
238
- training_args = Seq2SeqTrainingArguments(
239
- save_dir,
240
- do_train=True if args.evaluate_dir is None else False,
241
- do_eval=True,
242
- evaluation_strategy="epoch",
243
- logging_strategy="steps",
244
- save_strategy="epoch",
245
- save_total_limit = 2,
246
- learning_rate= args.lr,
247
- eval_accumulation_steps=args.eval_acc,
248
- per_device_train_batch_size=args.bs,
249
- per_device_eval_batch_size=args.eval_bs,
250
- weight_decay=0.01,
251
- num_train_epochs=args.epoch,
252
- metric_for_best_model="accuracy" if args.prompt_format != "QCM-LE" else "rougeL",
253
- predict_with_generate=args.use_generate,
254
- load_best_model_at_end=True,
255
- report_to="none",
256
- )
257
-
258
- trainer = Seq2SeqTrainer(
259
- model=model,
260
- args=training_args,
261
- train_dataset=train_set,
262
- eval_dataset=eval_set,
263
- data_collator=datacollator,
264
- tokenizer=tokenizer,
265
- compute_metrics = compute_metrics_acc if args.prompt_format != "QCM-LE" else compute_metrics_rougel
266
- )
267
-
268
- if args.evaluate_dir is None:
269
- trainer.train()
270
- trainer.save_model(save_dir)
271
-
272
- metrics = trainer.evaluate(eval_dataset = test_set)
273
- trainer.log_metrics("test", metrics)
274
- trainer.save_metrics("test", metrics)
275
-
276
- predict_results = trainer.predict(test_dataset=test_set, max_length=args.output_len)
277
- if trainer.is_world_process_zero():
278
- if args.use_generate:
279
- preds, targets = predict_results.predictions, predict_results.label_ids
280
- else:
281
- preds = predict_results.predictions[0]
282
- targets = predict_results.label_ids
283
- preds = preds.argmax(axis=2)
284
-
285
- preds = tokenizer.batch_decode(
286
- preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
287
- )
288
- targets = tokenizer.batch_decode(
289
- targets, skip_special_tokens=True, clean_up_tokenization_spaces=True
290
- )
291
-
292
- results_ans = {}
293
- results_rationale = {}
294
- results_reference = {}
295
-
296
- num_fail = 0
297
- for idx, qid in enumerate(test_qids):
298
- pred = preds[int(idx)]
299
- ref = targets[int(idx)]
300
- extract_pred = extract_ans(pred)
301
- if extract_pred != "FAILED":
302
- if extract_pred in args.options:
303
- extract_pred = args.options.index(extract_pred)
304
- else:
305
- extract_pred = random.choice(range(0,len(args.options)))
306
- else:
307
- num_fail += 1
308
- extract_pred = random.choice(range(len(args.options))) # random choose one option
309
- results_ans[str(qid)] = extract_pred
310
- results_rationale[str(qid)] = pred
311
- results_reference[str(qid)] = ref
312
-
313
- scores = get_scores(results_ans, results_rationale, results_reference, os.path.join(args.data_root, "scienceqa/problems.json"))
314
- preds = [pred.strip() for pred in preds]
315
- output_data = {
316
- "num_fail": num_fail,
317
- "scores": scores,
318
- "preds": preds,
319
- "labels": targets}
320
- output_prediction_file = os.path.join(save_dir,"predictions_ans_test.json")
321
- with open(output_prediction_file, "w") as writer:
322
- writer.write(json.dumps(output_data, indent=4))
323
-
324
- # generate the rationale for the eval set
325
- if args.prompt_format == "QCM-LE":
326
- torch.cuda.empty_cache()
327
- del predict_results, preds, targets
328
- predict_results = trainer.predict(test_dataset=eval_set, max_length=args.output_len)
329
- if trainer.is_world_process_zero():
330
- if args.use_generate:
331
- preds, targets = predict_results.predictions, predict_results.label_ids
332
- else:
333
- preds = predict_results.predictions[0]
334
- targets = predict_results.label_ids
335
- preds = preds.argmax(axis=2)
336
-
337
- preds = tokenizer.batch_decode(
338
- preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
339
- )
340
- targets = tokenizer.batch_decode(
341
- targets, skip_special_tokens=True, clean_up_tokenization_spaces=True
342
- )
343
- preds = [pred.strip() for pred in preds]
344
- output_data = {"preds": preds,
345
- "labels": targets}
346
- output_prediction_file = os.path.join(save_dir,"predictions_ans_eval.json")
347
- with open(output_prediction_file, "w") as writer:
348
- writer.write(json.dumps(output_data, indent=4))
349
-
350
-
351
- if __name__ == '__main__':
352
-
353
- # training logger to log training progress
354
- training_logger = Table(
355
- Column("Epoch", justify="center"),
356
- Column("Steps", justify="center"),
357
- Column("Loss", justify="center"),
358
- title="Training Status",
359
- pad_edge=False,
360
- box=box.ASCII,
361
- )
362
-
363
- args = parse_args()
364
- print("args",args)
365
- print('====Input Arguments====')
366
- print(json.dumps(vars(args), indent=2, sort_keys=False))
367
-
368
- random.seed(args.seed)
369
-
370
- if not os.path.exists(args.output_dir):
371
- os.mkdir(args.output_dir)
372
-
373
- if args.img_type is not None:
374
- problems, qids, name_maps, image_features = load_data_img(args) # probelms, test question ids, shot example ids
375
- dataframe = {'problems':problems, 'qids':qids, 'name_maps': name_maps, 'image_features': image_features}
376
- else:
377
- problems, qids = load_data_std(args) # probelms, test question ids, shot example ids
378
- dataframe = {'problems':problems, 'qids':qids}
379
-
380
- T5Trainer(
381
- dataframe=dataframe,
382
- args = args
383
- )
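
Two details of the training loop above are easy to miss. First, answers are extracted from generated text with a fixed regular expression, and anything that fails to match is counted in `num_fail` and replaced by a randomly chosen option. Second, when `--use_generate` is not set, predictions come from an `argmax` over the logits rather than from generation. A standalone sketch of the answer-extraction step (behaviour unchanged, lightly simplified):

```python
# Sketch of the answer-extraction logic used by compute_metrics_acc and the
# test-set post-processing in main.py above.
import re

def extract_ans(ans: str) -> str:
    pattern = re.compile(r"The answer is \(([A-Z])\)")
    res = pattern.findall(ans)
    return res[0] if len(res) == 1 else "FAILED"

print(extract_ans("Answer: The answer is (B)."))  # -> 'B'
print(extract_ans("No option was produced."))     # -> 'FAILED'
```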
 
mm-cot/mm-cot/model.py DELETED
@@ -1,194 +0,0 @@
1
- '''
2
- Adapted from https://github.com/huggingface/transformers
3
- '''
4
-
5
- from transformers import T5Config, T5ForConditionalGeneration
6
- from transformers.models.t5.modeling_t5 import T5Stack, __HEAD_MASK_WARNING_MSG, T5EncoderModel
7
- import copy
8
- import math
9
- import os
10
- import warnings
11
- from typing import Optional, Tuple, Union
12
- import torch
13
- from torch import nn
14
- from torch.nn import CrossEntropyLoss
15
- from transformers.modeling_outputs import (
16
- BaseModelOutput,
17
- Seq2SeqLMOutput,
18
- )
19
-
20
- class T5ForMultimodalGeneration(T5ForConditionalGeneration):
21
- _keys_to_ignore_on_load_missing = [
22
- r"encoder.embed_tokens.weight",
23
- r"decoder.embed_tokens.weight",
24
- r"lm_head.weight",
25
- ]
26
- _keys_to_ignore_on_load_unexpected = [
27
- r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
28
- ]
29
-
30
- def __init__(self, config: T5Config, patch_size, padding_idx, save_dir):
31
- super().__init__(config)
32
- self.model_dim = config.d_model
33
-
34
- self.padding_idx = padding_idx
35
- self.out = open(os.path.join(save_dir, 'gate.txt'), 'w')
36
-
37
- self.shared = nn.Embedding(config.vocab_size, config.d_model)
38
- self.patch_num, self.patch_dim = patch_size
39
-
40
- self.image_dense = nn.Linear(self.patch_dim, config.d_model)
41
- self.mha_layer = torch.nn.MultiheadAttention(embed_dim=config.hidden_size, kdim=config.hidden_size, vdim=config.hidden_size, num_heads=1, batch_first=True)
42
- self.gate_dense = nn.Linear(2*config.hidden_size, config.hidden_size)
43
- self.sigmoid = nn.Sigmoid()
44
-
45
- encoder_config = copy.deepcopy(config)
46
- encoder_config.is_decoder = False
47
- encoder_config.use_cache = False
48
- encoder_config.is_encoder_decoder = False
49
- self.encoder = T5Stack(encoder_config, self.shared)
50
-
51
- decoder_config = copy.deepcopy(config)
52
- decoder_config.is_decoder = True
53
- decoder_config.is_encoder_decoder = False
54
- decoder_config.num_layers = config.num_decoder_layers
55
- self.decoder = T5Stack(decoder_config, self.shared)
56
-
57
- self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
58
-
59
- # Initialize weights and apply final processing
60
- self.post_init()
61
-
62
- # Model parallel
63
- self.model_parallel = False
64
- self.device_map = None
65
-
66
- def forward(
67
- self,
68
- input_ids: Optional[torch.LongTensor] = None,
69
- image_ids=None,
70
- attention_mask: Optional[torch.FloatTensor] = None,
71
- decoder_input_ids: Optional[torch.LongTensor] = None,
72
- decoder_attention_mask: Optional[torch.BoolTensor] = None,
73
- head_mask: Optional[torch.FloatTensor] = None,
74
- decoder_head_mask: Optional[torch.FloatTensor] = None,
75
- cross_attn_head_mask: Optional[torch.Tensor] = None,
76
- encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
77
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
78
- inputs_embeds: Optional[torch.FloatTensor] = None,
79
- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
80
- labels: Optional[torch.LongTensor] = None,
81
- use_cache: Optional[bool] = None,
82
- output_attentions: Optional[bool] = None,
83
- output_hidden_states: Optional[bool] = None,
84
- return_dict: Optional[bool] = None,
85
- ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
86
- use_cache = use_cache if use_cache is not None else self.config.use_cache
87
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
88
-
89
- # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
90
- if head_mask is not None and decoder_head_mask is None:
91
- if self.config.num_layers == self.config.num_decoder_layers:
92
- warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
93
- decoder_head_mask = head_mask
94
-
95
- # Encode if needed (training, first prediction pass)
96
- if encoder_outputs is None:
97
- # Convert encoder inputs in embeddings if needed
98
- encoder_outputs = self.encoder(
99
- input_ids=input_ids,
100
- attention_mask=attention_mask,
101
- inputs_embeds=inputs_embeds,
102
- head_mask=head_mask,
103
- output_attentions=output_attentions,
104
- output_hidden_states=output_hidden_states,
105
- return_dict=return_dict,
106
- )
107
-
108
- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
109
- encoder_outputs = BaseModelOutput(
110
- last_hidden_state=encoder_outputs[0],
111
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
112
- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
113
- )
114
-
115
-
116
- hidden_states = encoder_outputs[0]
117
-
118
- image_embedding = self.image_dense(image_ids)
119
- image_att, _ = self.mha_layer(hidden_states, image_embedding, image_embedding)
120
-
121
- merge = torch.cat([hidden_states, image_att], dim=-1)
122
- gate = self.sigmoid(self.gate_dense(merge))
123
- hidden_states = (1 - gate) * hidden_states + gate * image_att
124
-
125
- if self.model_parallel:
126
- torch.cuda.set_device(self.decoder.first_device)
127
-
128
- if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
129
- # get decoder inputs from shifting lm labels to the right
130
- decoder_input_ids = self._shift_right(labels)
131
-
132
- # Set device for model parallelism
133
- if self.model_parallel:
134
- torch.cuda.set_device(self.decoder.first_device)
135
- hidden_states = hidden_states.to(self.decoder.first_device)
136
- if decoder_input_ids is not None:
137
- decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
138
- if attention_mask is not None:
139
- attention_mask = attention_mask.to(self.decoder.first_device)
140
- if decoder_attention_mask is not None:
141
- decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
142
-
143
- # Decode
144
- decoder_outputs = self.decoder(
145
- input_ids=decoder_input_ids,
146
- attention_mask=decoder_attention_mask,
147
- inputs_embeds=decoder_inputs_embeds,
148
- past_key_values=past_key_values,
149
- encoder_hidden_states=hidden_states,
150
- encoder_attention_mask=attention_mask,
151
- head_mask=decoder_head_mask,
152
- cross_attn_head_mask=cross_attn_head_mask,
153
- use_cache=use_cache,
154
- output_attentions=output_attentions,
155
- output_hidden_states=output_hidden_states,
156
- return_dict=return_dict,
157
- )
158
-
159
- sequence_output = decoder_outputs[0]
160
-
161
- # Set device for model parallelism
162
- if self.model_parallel:
163
- torch.cuda.set_device(self.encoder.first_device)
164
- self.lm_head = self.lm_head.to(self.encoder.first_device)
165
- sequence_output = sequence_output.to(self.lm_head.weight.device)
166
-
167
- if self.config.tie_word_embeddings:
168
- # Rescale output before projecting on vocab
169
- # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
170
- sequence_output = sequence_output * (self.model_dim**-0.5)
171
-
172
- lm_logits = self.lm_head(sequence_output)
173
-
174
- loss = None
175
- if labels is not None:
176
- loss_fct = CrossEntropyLoss(ignore_index=-100)
177
- loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
178
- # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
179
-
180
- if not return_dict:
181
- output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
182
- return ((loss,) + output) if loss is not None else output
183
-
184
- return Seq2SeqLMOutput(
185
- loss=loss,
186
- logits=lm_logits,
187
- past_key_values=decoder_outputs.past_key_values,
188
- decoder_hidden_states=decoder_outputs.hidden_states,
189
- decoder_attentions=decoder_outputs.attentions,
190
- cross_attentions=decoder_outputs.cross_attentions,
191
- encoder_last_hidden_state=encoder_outputs.last_hidden_state,
192
- encoder_hidden_states=encoder_outputs.hidden_states,
193
- encoder_attentions=encoder_outputs.attentions,
194
- )
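
The multimodal piece of `T5ForMultimodalGeneration` above sits between the encoder and decoder: projected image patches are attended over by the text hidden states, and a learned sigmoid gate mixes the attended image signal back into the text representation. Isolated here as a standalone module for clarity (dimensions are illustrative; the DETR feature shape (100, 256) comes from `utils_data.py`):

```python
# Standalone sketch of the gated cross-attention fusion used in the model above.
import torch
from torch import nn

class GatedFusion(nn.Module):
    def __init__(self, hidden_size: int, patch_dim: int):
        super().__init__()
        self.image_dense = nn.Linear(patch_dim, hidden_size)         # project patches to d_model
        self.mha = nn.MultiheadAttention(hidden_size, num_heads=1, batch_first=True)
        self.gate_dense = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, text_states: torch.Tensor, image_feats: torch.Tensor) -> torch.Tensor:
        image_emb = self.image_dense(image_feats)
        image_att, _ = self.mha(text_states, image_emb, image_emb)   # text queries attend over patches
        gate = torch.sigmoid(self.gate_dense(torch.cat([text_states, image_att], dim=-1)))
        return (1 - gate) * text_states + gate * image_att           # gated residual mix

fusion = GatedFusion(hidden_size=768, patch_dim=256)
out = fusion(torch.randn(2, 16, 768), torch.randn(2, 100, 256))
print(out.shape)  # torch.Size([2, 16, 768])
```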
 
mm-cot/mm-cot/requirements.txt DELETED
@@ -1,11 +0,0 @@
1
- huggingface-hub==0.0.12
2
- numpy==1.23.2
3
- openai==0.23.0
4
- pandas==1.4.3
5
- rouge==1.0.1
6
- sentence-transformers==2.2.2
7
- transformers==4.21.1
8
- nltk==3.6.6
9
- evaluate==0.4.0
10
- rouge==1.0.1
11
- rouge_score==0.1.2
 
mm-cot/mm-cot/run_inference.sh DELETED
@@ -1,17 +0,0 @@
1
- # rationale generation
2
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
3
- --model allenai/unifiedqa-t5-base \
4
- --user_msg rationale --img_type detr \
5
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 512 \
6
- --final_eval --prompt_format QCM-LE \
7
- --evaluate_dir models/rationale
8
-
9
- # answer inference
10
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
11
- --model allenai/unifiedqa-t5-base \
12
- --user_msg answer --img_type detr \
13
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 64 \
14
- --final_eval --prompt_format QCMG-A \
15
- --eval_le models/rationale/predictions_ans_eval.json \
16
- --test_le models/rationale/predictions_ans_test.json \
17
- --evaluate_dir models/answer
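
These commands rely on the `--evaluate_dir` switch in `main.py`: when it is set, the model path is re-pointed at the saved checkpoint directory and training is skipped, so only evaluation and test-set prediction run. A minimal sketch of that switch (the `Args` stand-in is for illustration only):

```python
# Sketch of the --evaluate_dir behaviour in main.py that run_inference.sh relies on.
class Args:                                   # stand-in for the argparse namespace
    model = "allenai/unifiedqa-t5-base"
    evaluate_dir = "models/rationale"

args = Args()
if args.evaluate_dir is not None:
    args.model = args.evaluate_dir            # load weights from the checkpoint directory
do_train = args.evaluate_dir is None          # False: the Trainer never calls trainer.train()
print(args.model, do_train)                   # models/rationale False
```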
 
mm-cot/mm-cot/run_training.sh DELETED
@@ -1,15 +0,0 @@
1
- # rationale generation
2
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
3
- --model allenai/unifiedqa-t5-base \
4
- --user_msg rationale --img_type detr \
5
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 512 \
6
- --final_eval --prompt_format QCM-LE
7
-
8
- # answer inference
9
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
10
- --model allenai/unifiedqa-t5-base \
11
- --user_msg answer --img_type detr \
12
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 64 \
13
- --final_eval --prompt_format QCMG-A \
14
- --eval_le experiments/rationale_allenai-unifiedqa-t5-base_detr_QCM-LE_lr5e-05_bs16_op512_ep20/predictions_ans_eval.json \
15
- --test_le experiments/rationale_allenai-unifiedqa-t5-base_detr_QCM-LE_lr5e-05_bs16_op512_ep20/predictions_ans_test.json
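
The long `--eval_le`/`--test_le` paths in the second command are not arbitrary: they are the experiment directory that `main.py` derives from the hyperparameters of the first command. A sketch of that derivation (`gpu_count` is assumed to be 2, matching `CUDA_VISIBLE_DEVICES=0,1`):

```python
# Sketch of how main.py builds the save_dir that run_training.sh points at.
output_dir, user_msg, img_type, prompt_format = "experiments", "rationale", "detr", "QCM-LE"
model_name = "allenai/unifiedqa-t5-base".replace("/", "-")
lr, bs, gpu_count, output_len, epoch = 5e-05, 8, 2, 512, 20

save_dir = (f"{output_dir}/{user_msg}_{model_name}_{img_type}_{prompt_format}"
            f"_lr{lr}_bs{bs * gpu_count}_op{output_len}_ep{epoch}")
print(save_dir)
# experiments/rationale_allenai-unifiedqa-t5-base_detr_QCM-LE_lr5e-05_bs16_op512_ep20
```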
 
mm-cot/mm-cot/utils_data.py DELETED
@@ -1,228 +0,0 @@
1
- import os
2
- from torch.utils.data import Dataset
3
- import os
4
- import json
5
- import numpy as np
6
- import torch
7
- from utils_prompt import *
8
-
9
- img_shape = {
10
- "resnet": (512, 2048),
11
- "clip": (49, 2048),
12
- "detr": (100, 256),
13
- }
14
-
15
- def load_data_std(args):
16
- problems = json.load(open(os.path.join(args.data_root, 'scienceqa/problems.json')))
17
- pid_splits = json.load(open(os.path.join(args.data_root, 'scienceqa/pid_splits.json')))
18
- captions = json.load(open(args.caption_file))["captions"]
19
-
20
- for qid in problems:
21
- problems[qid]['caption'] = captions[qid] if qid in captions else ""
22
-
23
- train_qids = pid_splits['%s' % (args.train_split)]
24
- val_qids = pid_splits['%s' % (args.val_split)]
25
- test_qids = pid_splits['%s' % (args.test_split)]
26
- print(f"number of train problems: {len(train_qids)}\n")
27
- print(f"number of val problems: {len(val_qids)}\n")
28
- print(f"number of test problems: {len(test_qids)}\n")
29
-
30
- qids = {'train': train_qids, 'val':val_qids,'test':test_qids}
31
- return problems, qids,
32
-
33
- def load_data_img(args):
34
- problems = json.load(open(os.path.join(args.data_root, 'scienceqa/problems.json')))
35
- pid_splits = json.load(open(os.path.join(args.data_root, 'scienceqa/pid_splits.json')))
36
- captions = json.load(open(args.caption_file))["captions"]
37
- name_maps = json.load(open('vision_features/name_map.json'))
38
-
39
- # check
40
- if args.img_type == "resnet":
41
- image_features = np.load('vision_features/resnet.npy')
42
- image_features = np.expand_dims(image_features, axis=1)
43
- image_features = image_features.repeat(512, axis=1)
44
- elif args.img_type == "clip":
45
- image_features = np.load('vision_features/clip.npy')
46
- elif args.img_type == "detr":
47
- image_features = np.load('vision_features/detr.npy')
48
- else:
49
- image_features = np.load('vision_features/detr.npy')
50
- print("img_features size: ", image_features.shape)
51
-
52
- for qid in problems:
53
- problems[qid]['caption'] = captions[qid] if qid in captions else ""
54
-
55
- train_qids = pid_splits['%s' % (args.train_split)]
56
- val_qids = pid_splits['%s' % (args.val_split)]
57
- test_qids = pid_splits['%s' % (args.test_split)]
58
- print(f"number of train problems: {len(train_qids)}\n")
59
- print(f"number of val problems: {len(val_qids)}\n")
60
- print(f"number of test problems: {len(test_qids)}\n")
61
-
62
- qids = {'train': train_qids, 'val':val_qids,'test':test_qids}
63
- return problems, qids, name_maps, image_features
64
-
65
- class ScienceQADatasetStd(Dataset):
66
- """
67
- Creating a custom dataset for reading the dataset and
68
- loading it into the dataloader to pass it to the
69
- neural network for finetuning the model
70
-
71
- """
72
-
73
- def __init__(
74
- self, problems, qids, tokenizer, source_len, target_len, args, test_le=None
75
- ):
76
- self.tokenizer = tokenizer
77
- self.data = {qid : problems[qid] for qid in qids}
78
- self.source_len = source_len
79
- self.summ_len = target_len
80
- self.target_text = []
81
- self.source_text = []
82
- if test_le is not None:
83
- test_le_data =json.load(open(test_le))["preds"]
84
- else:
85
- test_le_data = None
86
- idx = 0
87
- for qid in self.data:
88
- if test_le_data is not None:
89
- curr_le_data = test_le_data[idx]
90
- idx += 1
91
- else:
92
- curr_le_data = None
93
- prompt, target = build_train_pair(problems, qid, args, curr_le_data)
94
- self.target_text.append(target)
95
- self.source_text.append(prompt)
96
-
97
- def __len__(self):
98
- return len(self.target_text)
99
-
100
- def __getitem__(self, index):
101
- source_text = str(self.source_text[index])
102
- target_text = str(self.target_text[index])
103
-
104
- # cleaning data so as to ensure data is in string type
105
- source_text = " ".join(source_text.split())
106
- target_text = " ".join(target_text.split())
107
-
108
- source = self.tokenizer.batch_encode_plus(
109
- [source_text],
110
- max_length=self.source_len,
111
- pad_to_max_length=True,
112
- truncation=True,
113
- padding="max_length",
114
- return_tensors="pt",
115
- )
116
- target = self.tokenizer.batch_encode_plus(
117
- [target_text],
118
- max_length=self.summ_len,
119
- pad_to_max_length=True,
120
- truncation=True,
121
- padding="max_length",
122
- return_tensors="pt",
123
- )
124
- source_ids = source["input_ids"].squeeze()
125
- source_mask = source["attention_mask"].squeeze()
126
- target_ids = target["input_ids"].squeeze().tolist()
127
-
128
- return {
129
- "input_ids": source_ids,
130
- "attention_mask": source_mask,
131
- "labels": target_ids,
132
- }
133
-
134
-
135
- class ScienceQADatasetImg(Dataset):
136
- """
137
- Creating a custom dataset for reading the dataset and
138
- loading it into the dataloader to pass it to the
139
- neural network for finetuning the model
140
-
141
- """
142
-
143
- def __init__(
144
- self, problems, qids, name_maps, tokenizer, source_len, target_len, args, image_features, test_le=None
145
- ):
146
- """
147
- Initializes a Dataset class
148
-
149
- Args:
150
- dataframe (pandas.DataFrame): Input dataframe
151
- tokenizer (transformers.tokenizer): Transformers tokenizer
152
- source_len (int): Max length of source text
153
- target_len (int): Max length of target text
154
- source_text (str): column name of source text
155
- target_text (str): column name of target text
156
- """
157
- self.tokenizer = tokenizer
158
- self.data = {qid : problems[qid] for qid in qids}
159
- self.source_len = source_len
160
- self.summ_len = target_len
161
- self.target_text = []
162
- self.source_text = []
163
- self.image_ids = []
164
- if test_le is not None:
165
- test_le_data =json.load(open(test_le))["preds"]
166
- else:
167
- test_le_data = None
168
- idx = 0
169
- for qid in self.data:
170
- if test_le_data is not None:
171
- curr_le_data = test_le_data[idx]
172
- idx += 1
173
- else:
174
- curr_le_data = None
175
- prompt, target = build_train_pair(problems, qid, args, curr_le_data)
176
- self.target_text.append(target)
177
- self.source_text.append(prompt)
178
- if str(qid) in name_maps:
179
- i_vectors = image_features[int(name_maps[str(qid)])]
180
- self.image_ids.append(i_vectors)
181
- else:
182
- shape = img_shape[args.img_type]
183
- self.image_ids.append(np.zeros(shape))
184
-
185
- def __len__(self):
186
- """returns the length of dataframe"""
187
-
188
- return len(self.target_text)
189
-
190
- def __getitem__(self, index):
191
- """return the input ids, attention masks and target ids"""
192
-
193
- source_text = str(self.source_text[index])
194
- target_text = str(self.target_text[index])
195
- image_ids = self.image_ids[index]
196
-
197
- # cleaning data so as to ensure data is in string type
198
- source_text = " ".join(source_text.split())
199
- target_text = " ".join(target_text.split())
200
-
201
- source = self.tokenizer.batch_encode_plus(
202
- [source_text],
203
- max_length=self.source_len,
204
- pad_to_max_length=True,
205
- truncation=True,
206
- padding="max_length",
207
- return_tensors="pt",
208
- )
209
- target = self.tokenizer.batch_encode_plus(
210
- [target_text],
211
- max_length=self.summ_len,
212
- pad_to_max_length=True,
213
- truncation=True,
214
- padding="max_length",
215
- return_tensors="pt",
216
- )
217
- source_ids = source["input_ids"].squeeze()
218
- source_mask = source["attention_mask"].squeeze()
219
- target_ids = target["input_ids"].squeeze().tolist()
220
-
221
- image_ids = torch.tensor(image_ids).squeeze()
222
-
223
- return {
224
- "input_ids": source_ids,
225
- "attention_mask": source_mask,
226
- "image_ids": image_ids,
227
- "labels": target_ids,
228
- }
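
One behaviour of `ScienceQADatasetImg` above is worth calling out: questions whose id is missing from `name_maps` get an all-zero feature tensor shaped like the chosen feature type, so text-only questions still pass through the multimodal model. A small sketch of that fallback (array contents are dummies):

```python
# Sketch of the zero-feature fallback used by ScienceQADatasetImg above.
import numpy as np
import torch

img_shape = {"resnet": (512, 2048), "clip": (49, 2048), "detr": (100, 256)}

def image_features_for(qid, name_maps, image_features, img_type="detr"):
    if str(qid) in name_maps:
        return torch.tensor(image_features[int(name_maps[str(qid)])])
    return torch.tensor(np.zeros(img_shape[img_type]))  # no image: zero patches

feats = image_features_for("123", {}, np.zeros((1, 100, 256)))
print(feats.shape)  # torch.Size([100, 256])
```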
 
mm-cot/mm-cot/utils_evaluate.py DELETED
@@ -1,108 +0,0 @@
1
- '''
2
- Adapted from https://github.com/lupantech/ScienceQA
3
- '''
4
-
5
- import os
6
- import json
7
- import argparse
8
- import warnings
9
- import pandas as pd
10
- from sentence_transformers import SentenceTransformer
11
- from evaluations import caculate_bleu, caculate_rouge, caculate_similariry
12
-
13
- warnings.filterwarnings('ignore')
14
-
15
- def get_acc_with_contion(res_pd, key, values):
16
- if isinstance(values, list):
17
- total_pd = res_pd[res_pd[key].isin(values)]
18
- else:
19
- total_pd = res_pd[res_pd[key] == values]
20
- correct_pd = total_pd[total_pd['true_false'] == True]
21
- acc = "{:.2f}".format(len(correct_pd) / len(total_pd) * 100)
22
- return acc
23
-
24
-
25
- def get_scores(result_data, rationale_data, results_reference, data_file):
26
- # read result file
27
- results = result_data
28
- num = len(results)
29
- assert num == 4241
30
- #print("number of questions:", num)
31
-
32
- # read data file
33
- sqa_data = json.load(open(data_file))
34
-
35
- # construct pandas data
36
- sqa_pd = pd.DataFrame(sqa_data).T
37
- res_pd = sqa_pd[sqa_pd['split'] == 'test'] # test set
38
-
39
- # update data
40
- for index, row in res_pd.iterrows():
41
-
42
- res_pd.loc[index, 'no_context'] = True if (not row['hint'] and not row['image']) else False
43
- res_pd.loc[index, 'has_text'] = True if row['hint'] else False
44
- res_pd.loc[index, 'has_image'] = True if row['image'] else False
45
- res_pd.loc[index, 'has_text_image'] = True if (row['hint'] and row['image']) else False
46
-
47
- label = row['answer']
48
- pred = int(results[index])
49
- res_pd.loc[index, 'pred'] = pred
50
- res_pd.loc[index, 'true_false'] = (label == pred)
51
-
52
- # accuracy scores
53
- acc_average = len(res_pd[res_pd['true_false'] == True]) / num * 100
54
- #assert result_file.split('_')[-1] == "{:.3f}.json".format(acc_average)
55
-
56
-
57
- # rationale quality
58
-
59
- ## BLEU
60
- bleu1 = caculate_bleu(rationale_data, results_reference, gram=1)
61
- bleu4 = caculate_bleu(rationale_data, results_reference, gram=4)
62
-
63
- ## Rouge-L
64
- rouge = caculate_rouge(rationale_data, results_reference)
65
-
66
- ## Similarity
67
- model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').cuda()
68
- similariry = caculate_similariry(rationale_data, results_reference, model)
69
-
70
- scores = {
71
- "answer":{
72
- 'acc_natural':
73
- get_acc_with_contion(res_pd, 'subject', 'natural science'),
74
- 'acc_social':
75
- get_acc_with_contion(res_pd, 'subject', 'social science'),
76
- 'acc_language':
77
- get_acc_with_contion(res_pd, 'subject', 'language science'),
78
- 'acc_has_text':
79
- get_acc_with_contion(res_pd, 'has_text', True),
80
- 'acc_has_image':
81
- get_acc_with_contion(res_pd, 'has_image', True),
82
- 'acc_no_context':
83
- get_acc_with_contion(res_pd, 'no_context', True),
84
- 'acc_grade_1_6':
85
- get_acc_with_contion(res_pd, 'grade', ['grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6']),
86
- 'acc_grade_7_12':
87
- get_acc_with_contion(res_pd, 'grade', ['grade7', 'grade8', 'grade9', 'grade10', 'grade11', 'grade12']),
88
- 'acc_average':
89
- "{:.2f}".format(acc_average),
90
- },
91
- "rationale":{
92
- 'bleu1': bleu1 * 100,
93
- 'bleu4': bleu4 * 100,
94
- 'rouge': rouge * 100,
95
- 'similariry': similariry * 100,
96
- }
97
- }
98
-
99
- return scores
100
-
101
-
102
- def print_scores(scores):
103
- latex_output = ""
104
- for key, score in scores.items():
105
- print(f"{key[4:]}: \t{score}")
106
- latex_output += f"& {score} "
107
- latex_output += "\\\\"
108
- print(latex_output)
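
`get_acc_with_contion` above slices the per-question results DataFrame by a column value (or a list of values) and reports accuracy over the `true_false` flag as a formatted percentage string; that is how the per-subject and per-grade numbers in `get_scores` are produced. A toy, self-contained illustration with invented data:

```python
# Toy illustration of the conditional-accuracy helper above (invented data).
import pandas as pd

res_pd = pd.DataFrame({
    "subject": ["natural science", "social science", "natural science"],
    "true_false": [True, False, True],
})

def get_acc_with_contion(res_pd, key, values):
    total_pd = res_pd[res_pd[key].isin(values)] if isinstance(values, list) \
        else res_pd[res_pd[key] == values]
    correct_pd = total_pd[total_pd["true_false"] == True]
    return "{:.2f}".format(len(correct_pd) / len(total_pd) * 100)

print(get_acc_with_contion(res_pd, "subject", "natural science"))                      # 100.00
print(get_acc_with_contion(res_pd, "subject", ["natural science", "social science"]))  # 66.67
```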
 
mm-cot/mm-cot/utils_prompt.py DELETED
@@ -1,240 +0,0 @@
1
- '''
2
- Adapted from https://github.com/lupantech/ScienceQA
3
- '''
4
-
5
- from dataclasses import dataclass
6
- from typing import List, Optional
7
-
8
- def get_question_text(problem):
9
- question = problem['question']
10
- return question
11
-
12
-
13
- def get_context_text(problem, use_caption):
14
- txt_context = problem['hint']
15
- img_context = problem['caption'] if use_caption else ""
16
- context = " ".join([txt_context, img_context]).strip()
17
- if context == "":
18
- context = "N/A"
19
- return context
20
-
21
-
22
- def get_choice_text(probelm, options):
23
- choices = probelm['choices']
24
- choice_list = []
25
- for i, c in enumerate(choices):
26
- choice_list.append("({}) {}".format(options[i], c))
27
- choice_txt = " ".join(choice_list)
28
- #print(choice_txt)
29
- return choice_txt
30
-
31
- def get_origin_answer(problem, options):
32
- return problem['choices'][problem['answer']]
33
-
34
- def get_answer(problem, options):
35
- return options[problem['answer']]
36
-
37
-
38
- def get_lecture_text(problem):
39
- # \\n: GPT-3 can generate the lecture with more tokens.
40
- lecture = problem['lecture'].replace("\n", "\\n")
41
- return lecture
42
-
43
-
44
- def get_solution_text(problem):
45
- # \\n: GPT-3 can generate the solution with more tokens
46
- solution = problem['solution'].replace("\n", "\\n")
47
- return solution
48
-
49
-
50
- def create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True, WithOutput = False, curr_le_data=None):
51
-
52
- input_format, output_format = format.split("-")
53
-
54
- ## Inputs
55
- if input_format == "CQM":
56
- input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
57
- elif input_format == "QCM":
58
- input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
59
- elif input_format == "QM":
60
- input = f"Question: {question}\nOptions: {choice}\n"
61
- elif input_format == "QC":
62
- input = f"Question: {question}\nContext: {context}\n"
63
- elif input_format == "QCMG":
64
- if curr_le_data is not None:
65
- input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n{curr_le_data}\n"
66
- else:
67
- input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nSolution: {lecture} {solution}\n"
68
- elif input_format == "CQMG":
69
- if curr_le_data is not None:
70
- input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n{curr_le_data}\n"
71
- else:
72
- input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\nSolution: {lecture} {solution}\n"
73
- # upper bound experiment
74
- elif input_format == "QCML":
75
- input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
76
- elif input_format == "QCME":
77
- input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
78
- elif input_format == "QCMLE":
79
- input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
80
-
81
- elif input_format == "QCLM":
82
- input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
83
- elif input_format == "QCEM":
84
- input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
85
- elif input_format == "QCLEM":
86
- input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
87
- elif input_format == "QCMA":
88
- input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nAnswer: The answer is {answer}.\n"
89
- elif input_format == "QCA":
90
- input = f"Question: {question}\nContext: {context}\nAnswer: The answer is {answer}. \nBECAUSE:"
91
-
92
- # Outputs
93
- if test_example:
94
- if output_format == 'A':
95
- output = "Answer:"
96
- elif output_format == 'E':
97
- output = "Solution:"
98
- else:
99
- output = "Solution:"
100
- elif output_format == 'A':
101
- output = f"Answer: The answer is {answer}."
102
-
103
- elif output_format == 'AL':
104
- output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
105
- elif output_format == 'AE':
106
- output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
107
- elif output_format == 'ALE':
108
- output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
109
- elif output_format == 'AEL':
110
- output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
111
-
112
- elif output_format == 'LA':
113
- output = f"Answer: {lecture} The answer is {answer}."
114
- elif output_format == 'EA':
115
- output = f"Answer: {solution} The answer is {answer}."
116
- elif output_format == 'LEA':
117
- output = f"Answer: {lecture} {solution} The answer is {answer}."
118
- elif output_format == 'ELA':
119
- output = f"Answer: {solution} {lecture} The answer is {answer}."
120
-
121
- elif output_format == 'LE':
122
- output = f"Solution: {lecture} {solution}."
123
-
124
- elif output_format == 'E':
125
- output = f"Solution: {solution}"
126
-
127
-
128
- if WithOutput:
129
- if output.endswith("BECAUSE:"):
130
- output = output.replace("BECAUSE:", "").strip()
131
- if output_format == 'E':
132
- text = input + f'Solution:'
133
- elif output_format == 'A':
134
- text = input + f'Answer:'
135
- else:
136
- text = input + f'Solution:'
137
- text = text.replace(" ", " ").strip()
138
- output = output.replace(" ", " ").strip()
139
- return text, output
140
-
141
-
142
- text = input + output
143
- text = text.replace(" ", " ").strip()
144
- if text.endswith("BECAUSE:"):
145
- text = text.replace("BECAUSE:", "").strip()
146
- return text
147
-
148
-
149
- def build_prompt(problems, shot_qids, test_qid, args):
150
-
151
- examples = []
152
-
153
- # n-shot training examples
154
- for qid in shot_qids:
155
- question = get_question_text(problems[qid])
156
- context = get_context_text(problems[qid], args.use_caption)
157
- choice = get_choice_text(problems[qid], args.options)
158
- answer = get_answer(problems[qid], args.options)
159
- lecture = get_lecture_text(problems[qid])
160
- solution = get_solution_text(problems[qid])
161
-
162
- train_example = create_one_example(args.prompt_format,
163
- question,
164
- context,
165
- choice,
166
- answer,
167
- lecture,
168
- solution,
169
- test_example=False)
170
- examples.append(train_example)
171
-
172
- # test example
173
- question = get_question_text(problems[test_qid])
174
- context = get_context_text(problems[test_qid], args.use_caption)
175
- choice = get_choice_text(problems[test_qid], args.options)
176
- answer = get_answer(problems[test_qid], args.options)
177
- lecture = get_lecture_text(problems[test_qid])
178
- solution = get_solution_text(problems[test_qid])
179
-
180
- test_example = create_one_example(args.prompt_format,
181
- question,
182
- context,
183
- choice,
184
- answer,
185
- lecture,
186
- solution,
187
- test_example=True)
188
- examples.append(test_example)
189
-
190
- # create the prompt input
191
- prompt_input = '\n\n'.join(examples)
192
-
193
- return prompt_input
194
-
195
- def build_train_pair(problems, test_qid, args, curr_le_data=None):
196
-
197
- examples = []
198
-
199
- # test example
200
- question = get_question_text(problems[test_qid])
201
- context = get_context_text(problems[test_qid], args.use_caption)
202
- choice = get_choice_text(problems[test_qid], args.options)
203
-
204
- lecture = get_lecture_text(problems[test_qid])
205
- solution = get_solution_text(problems[test_qid])
206
-
207
- # answer_text = get_origin_answer(problems[test_qid], args.options)
208
- answer_option = get_answer(problems[test_qid], args.options)
209
- answer = "(" + answer_option + ")"
210
-
211
- test_example, target = create_one_example(args.prompt_format,
212
- question,
213
- context,
214
- choice,
215
- answer,
216
- lecture,
217
- solution,
218
- test_example=False,WithOutput = True, curr_le_data=curr_le_data)
219
- examples.append(test_example)
220
-
221
- target = target.replace("Answer:", "").strip()
222
- # create the prompt input
223
- prompt_input = '\n\n'.join(examples)
224
-
225
- return prompt_input, target
226
-
227
- @dataclass(frozen=True)
228
- class InputFeatures:
229
- """
230
- A single set of features of data.
231
- Property names are the same names as the corresponding inputs to a model.
232
- """
233
-
234
- input_ids: List[List[int]]
235
- attention_mask: Optional[List[List[int]]]
236
- token_type_ids: Optional[List[List[int]]]
237
- le_input_ids: List[List[int]]
238
- le_attention_mask: Optional[List[List[int]]]
239
- le_token_type_ids: Optional[List[List[int]]]
240
- label: Optional[int]
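
Concretely, the two prompt formats used by the training and inference scripts line up like this: `QCM-LE` builds a Question/Context/Options input and targets the rationale, while `QCMG-A` appends a previously generated rationale to the same input and targets only the answer. A hedged illustration with invented question text (the real dataset classes additionally collapse runs of whitespace, which is omitted here):

```python
# Illustration of the QCM-LE and QCMG-A formats built by create_one_example
# above (question, options and rationale strings are invented).
question, context, choice = "Which of these is a solid?", "N/A", "(A) water (B) ice"
lecture, solution, answer = "Solids keep their own shape.", "Ice keeps its shape.", "(B)"

# Stage 1, QCM-LE: input ends with "Solution:", target is the rationale.
qcm_le_input  = f"Question: {question}\nContext: {context}\nOptions: {choice}\nSolution:"
qcm_le_target = f"Solution: {lecture} {solution}."

# Stage 2, QCMG-A: the generated rationale is appended, target is the answer.
generated     = qcm_le_target
qcmg_a_input  = f"Question: {question}\nContext: {context}\nOptions: {choice}\n{generated}\nAnswer:"
qcmg_a_target = f"The answer is {answer}."

print(qcm_le_input, qcm_le_target, qcmg_a_input, qcmg_a_target, sep="\n---\n")
```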
 
mm-cot/mm-cot/vision_features/mm-cot.png DELETED
Binary file (893 kB)
 
mm-cot/model.py DELETED
@@ -1,194 +0,0 @@
1
- '''
2
- Adapted from https://github.com/huggingface/transformers
3
- '''
4
-
5
- from transformers import T5Config, T5ForConditionalGeneration
6
- from transformers.models.t5.modeling_t5 import T5Stack, __HEAD_MASK_WARNING_MSG, T5EncoderModel
7
- import copy
8
- import math
9
- import os
10
- import warnings
11
- from typing import Optional, Tuple, Union
12
- import torch
13
- from torch import nn
14
- from torch.nn import CrossEntropyLoss
15
- from transformers.modeling_outputs import (
16
- BaseModelOutput,
17
- Seq2SeqLMOutput,
18
- )
19
-
20
- class T5ForMultimodalGeneration(T5ForConditionalGeneration):
21
- _keys_to_ignore_on_load_missing = [
22
- r"encoder.embed_tokens.weight",
23
- r"decoder.embed_tokens.weight",
24
- r"lm_head.weight",
25
- ]
26
- _keys_to_ignore_on_load_unexpected = [
27
- r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
28
- ]
29
-
30
- def __init__(self, config: T5Config, patch_size, padding_idx, save_dir):
31
- super().__init__(config)
32
- self.model_dim = config.d_model
33
-
34
- self.padding_idx = padding_idx
35
- self.out = open(os.path.join(save_dir, 'gate.txt'), 'w')
36
-
37
- self.shared = nn.Embedding(config.vocab_size, config.d_model)
38
- self.patch_num, self.patch_dim = patch_size
39
-
40
- self.image_dense = nn.Linear(self.patch_dim, config.d_model)
41
- self.mha_layer = torch.nn.MultiheadAttention(embed_dim=config.hidden_size, kdim=config.hidden_size, vdim=config.hidden_size, num_heads=1, batch_first=True)
42
- self.gate_dense = nn.Linear(2*config.hidden_size, config.hidden_size)
43
- self.sigmoid = nn.Sigmoid()
44
-
45
- encoder_config = copy.deepcopy(config)
46
- encoder_config.is_decoder = False
47
- encoder_config.use_cache = False
48
- encoder_config.is_encoder_decoder = False
49
- self.encoder = T5Stack(encoder_config, self.shared)
50
-
51
- decoder_config = copy.deepcopy(config)
52
- decoder_config.is_decoder = True
53
- decoder_config.is_encoder_decoder = False
54
- decoder_config.num_layers = config.num_decoder_layers
55
- self.decoder = T5Stack(decoder_config, self.shared)
56
-
57
- self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
58
-
59
- # Initialize weights and apply final processing
60
- self.post_init()
61
-
62
- # Model parallel
63
- self.model_parallel = False
64
- self.device_map = None
65
-
66
- def forward(
67
- self,
68
- input_ids: Optional[torch.LongTensor] = None,
69
- image_ids=None,
70
- attention_mask: Optional[torch.FloatTensor] = None,
71
- decoder_input_ids: Optional[torch.LongTensor] = None,
72
- decoder_attention_mask: Optional[torch.BoolTensor] = None,
73
- head_mask: Optional[torch.FloatTensor] = None,
74
- decoder_head_mask: Optional[torch.FloatTensor] = None,
75
- cross_attn_head_mask: Optional[torch.Tensor] = None,
76
- encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
77
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
78
- inputs_embeds: Optional[torch.FloatTensor] = None,
79
- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
80
- labels: Optional[torch.LongTensor] = None,
81
- use_cache: Optional[bool] = None,
82
- output_attentions: Optional[bool] = None,
83
- output_hidden_states: Optional[bool] = None,
84
- return_dict: Optional[bool] = None,
85
- ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
86
- use_cache = use_cache if use_cache is not None else self.config.use_cache
87
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
88
-
89
- # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
90
- if head_mask is not None and decoder_head_mask is None:
91
- if self.config.num_layers == self.config.num_decoder_layers:
92
- warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
93
- decoder_head_mask = head_mask
94
-
95
- # Encode if needed (training, first prediction pass)
96
- if encoder_outputs is None:
97
- # Convert encoder inputs in embeddings if needed
98
- encoder_outputs = self.encoder(
99
- input_ids=input_ids,
100
- attention_mask=attention_mask,
101
- inputs_embeds=inputs_embeds,
102
- head_mask=head_mask,
103
- output_attentions=output_attentions,
104
- output_hidden_states=output_hidden_states,
105
- return_dict=return_dict,
106
- )
107
-
108
- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
109
- encoder_outputs = BaseModelOutput(
110
- last_hidden_state=encoder_outputs[0],
111
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
112
- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
113
- )
114
-
115
-
116
- hidden_states = encoder_outputs[0]
117
-
118
- image_embedding = self.image_dense(image_ids)
119
- image_att, _ = self.mha_layer(hidden_states, image_embedding, image_embedding)
120
-
121
- merge = torch.cat([hidden_states, image_att], dim=-1)
122
- gate = self.sigmoid(self.gate_dense(merge))
123
- hidden_states = (1 - gate) * hidden_states + gate * image_att
124
-
125
- if self.model_parallel:
126
- torch.cuda.set_device(self.decoder.first_device)
127
-
128
- if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
129
- # get decoder inputs from shifting lm labels to the right
130
- decoder_input_ids = self._shift_right(labels)
131
-
132
- # Set device for model parallelism
133
- if self.model_parallel:
134
- torch.cuda.set_device(self.decoder.first_device)
135
- hidden_states = hidden_states.to(self.decoder.first_device)
136
- if decoder_input_ids is not None:
137
- decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
138
- if attention_mask is not None:
139
- attention_mask = attention_mask.to(self.decoder.first_device)
140
- if decoder_attention_mask is not None:
141
- decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
142
-
143
- # Decode
144
- decoder_outputs = self.decoder(
145
- input_ids=decoder_input_ids,
146
- attention_mask=decoder_attention_mask,
147
- inputs_embeds=decoder_inputs_embeds,
148
- past_key_values=past_key_values,
149
- encoder_hidden_states=hidden_states,
150
- encoder_attention_mask=attention_mask,
151
- head_mask=decoder_head_mask,
152
- cross_attn_head_mask=cross_attn_head_mask,
153
- use_cache=use_cache,
154
- output_attentions=output_attentions,
155
- output_hidden_states=output_hidden_states,
156
- return_dict=return_dict,
157
- )
158
-
159
- sequence_output = decoder_outputs[0]
160
-
161
- # Set device for model parallelism
162
- if self.model_parallel:
163
- torch.cuda.set_device(self.encoder.first_device)
164
- self.lm_head = self.lm_head.to(self.encoder.first_device)
165
- sequence_output = sequence_output.to(self.lm_head.weight.device)
166
-
167
- if self.config.tie_word_embeddings:
168
- # Rescale output before projecting on vocab
169
- # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
170
- sequence_output = sequence_output * (self.model_dim**-0.5)
171
-
172
- lm_logits = self.lm_head(sequence_output)
173
-
174
- loss = None
175
- if labels is not None:
176
- loss_fct = CrossEntropyLoss(ignore_index=-100)
177
- loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
178
- # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
179
-
180
- if not return_dict:
181
- output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
182
- return ((loss,) + output) if loss is not None else output
183
-
184
- return Seq2SeqLMOutput(
185
- loss=loss,
186
- logits=lm_logits,
187
- past_key_values=decoder_outputs.past_key_values,
188
- decoder_hidden_states=decoder_outputs.hidden_states,
189
- decoder_attentions=decoder_outputs.attentions,
190
- cross_attentions=decoder_outputs.cross_attentions,
191
- encoder_last_hidden_state=encoder_outputs.last_hidden_state,
192
- encoder_hidden_states=encoder_outputs.hidden_states,
193
- encoder_attentions=encoder_outputs.attentions,
194
- )
 
mm-cot/requirements.txt DELETED
@@ -1,11 +0,0 @@
1
- huggingface-hub==0.0.12
2
- numpy==1.23.2
3
- openai==0.23.0
4
- pandas==1.4.3
5
- rouge==1.0.1
6
- sentence-transformers==2.2.2
7
- transformers==4.21.1
8
- nltk==3.6.6
9
- evaluate==0.4.0
10
- rouge==1.0.1
11
- rouge_score==0.1.2
 
mm-cot/run_inference.sh DELETED
@@ -1,17 +0,0 @@
1
- # rationale generation
2
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
3
- --model allenai/unifiedqa-t5-base \
4
- --user_msg rationale --img_type detr \
5
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 512 \
6
- --final_eval --prompt_format QCM-LE \
7
- --evaluate_dir models/rationale
8
-
9
- # answer inference
10
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
11
- --model allenai/unifiedqa-t5-base \
12
- --user_msg answer --img_type detr \
13
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 64 \
14
- --final_eval --prompt_format QCMG-A \
15
- --eval_le models/rationale/predictions_ans_eval.json \
16
- --test_le models/rationale/predictions_ans_test.json \
17
- --evaluate_dir models/answer
mm-cot/run_training.sh DELETED
@@ -1,15 +0,0 @@
1
- # rationale generation
2
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
3
- --model allenai/unifiedqa-t5-base \
4
- --user_msg rationale --img_type detr \
5
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 512 \
6
- --final_eval --prompt_format QCM-LE
7
-
8
- # answer inference
9
- CUDA_VISIBLE_DEVICES=0,1 python main.py \
10
- --model allenai/unifiedqa-t5-base \
11
- --user_msg answer --img_type detr \
12
- --bs 8 --eval_bs 4 --eval_acc 10 --output_len 64 \
13
- --final_eval --prompt_format QCMG-A \
14
- --eval_le experiments/rationale_allenai-unifiedqa-t5-base_detr_QCM-LE_lr5e-05_bs16_op512_ep20/predictions_ans_eval.json \
15
- --test_le experiments/rationale_allenai-unifiedqa-t5-base_detr_QCM-LE_lr5e-05_bs16_op512_ep20/predictions_ans_test.json
mm-cot/utils_data.py DELETED
@@ -1,228 +0,0 @@
1
- import os
2
- from torch.utils.data import Dataset
3
- import os
4
- import json
5
- import numpy as np
6
- import torch
7
- from utils_prompt import *
8
-
9
- img_shape = {
10
- "resnet": (512, 2048),
11
- "clip": (49, 2048),
12
- "detr": (100, 256),
13
- }
14
-
15
- def load_data_std(args):
16
- problems = json.load(open(os.path.join(args.data_root, 'scienceqa/problems.json')))
17
- pid_splits = json.load(open(os.path.join(args.data_root, 'scienceqa/pid_splits.json')))
18
- captions = json.load(open(args.caption_file))["captions"]
19
-
20
- for qid in problems:
21
- problems[qid]['caption'] = captions[qid] if qid in captions else ""
22
-
23
- train_qids = pid_splits['%s' % (args.train_split)]
24
- val_qids = pid_splits['%s' % (args.val_split)]
25
- test_qids = pid_splits['%s' % (args.test_split)]
26
- print(f"number of train problems: {len(train_qids)}\n")
27
- print(f"number of val problems: {len(val_qids)}\n")
28
- print(f"number of test problems: {len(test_qids)}\n")
29
-
30
- qids = {'train': train_qids, 'val':val_qids,'test':test_qids}
31
- return problems, qids,
32
-
33
- def load_data_img(args):
34
- problems = json.load(open(os.path.join(args.data_root, 'scienceqa/problems.json')))
35
- pid_splits = json.load(open(os.path.join(args.data_root, 'scienceqa/pid_splits.json')))
36
- captions = json.load(open(args.caption_file))["captions"]
37
- name_maps = json.load(open('vision_features/name_map.json'))
38
-
39
- # check
40
- if args.img_type == "resnet":
41
- image_features = np.load('vision_features/resnet.npy')
42
- image_features = np.expand_dims(image_features, axis=1)
43
- image_features = image_features.repeat(512, axis=1)
44
- elif args.img_type == "clip":
45
- image_features = np.load('vision_features/clip.npy')
46
- elif args.img_type == "detr":
47
- image_features = np.load('vision_features/detr.npy')
48
- else:
49
- image_features = np.load('vision_features/detr.npy')
50
- print("img_features size: ", image_features.shape)
51
-
52
- for qid in problems:
53
- problems[qid]['caption'] = captions[qid] if qid in captions else ""
54
-
55
- train_qids = pid_splits['%s' % (args.train_split)]
56
- val_qids = pid_splits['%s' % (args.val_split)]
57
- test_qids = pid_splits['%s' % (args.test_split)]
58
- print(f"number of train problems: {len(train_qids)}\n")
59
- print(f"number of val problems: {len(val_qids)}\n")
60
- print(f"number of test problems: {len(test_qids)}\n")
61
-
62
- qids = {'train': train_qids, 'val':val_qids,'test':test_qids}
63
- return problems, qids, name_maps, image_features
64
-
65
- class ScienceQADatasetStd(Dataset):
66
- """
67
- Creating a custom dataset for reading the dataset and
68
- loading it into the dataloader to pass it to the
69
- neural network for finetuning the model
70
-
71
- """
72
-
73
- def __init__(
74
- self, problems, qids, tokenizer, source_len, target_len, args, test_le=None
75
- ):
76
- self.tokenizer = tokenizer
77
- self.data = {qid : problems[qid] for qid in qids}
78
- self.source_len = source_len
79
- self.summ_len = target_len
80
- self.target_text = []
81
- self.source_text = []
82
- if test_le is not None:
83
- test_le_data =json.load(open(test_le))["preds"]
84
- else:
85
- test_le_data = None
86
- idx = 0
87
- for qid in self.data:
88
- if test_le_data is not None:
89
- curr_le_data = test_le_data[idx]
90
- idx += 1
91
- else:
92
- curr_le_data = None
93
- prompt, target = build_train_pair(problems, qid, args, curr_le_data)
94
- self.target_text.append(target)
95
- self.source_text.append(prompt)
96
-
97
- def __len__(self):
98
- return len(self.target_text)
99
-
100
- def __getitem__(self, index):
101
- source_text = str(self.source_text[index])
102
- target_text = str(self.target_text[index])
103
-
104
- # cleaning data so as to ensure data is in string type
105
- source_text = " ".join(source_text.split())
106
- target_text = " ".join(target_text.split())
107
-
108
- source = self.tokenizer.batch_encode_plus(
109
- [source_text],
110
- max_length=self.source_len,
111
- pad_to_max_length=True,
112
- truncation=True,
113
- padding="max_length",
114
- return_tensors="pt",
115
- )
116
- target = self.tokenizer.batch_encode_plus(
117
- [target_text],
118
- max_length=self.summ_len,
119
- pad_to_max_length=True,
120
- truncation=True,
121
- padding="max_length",
122
- return_tensors="pt",
123
- )
124
- source_ids = source["input_ids"].squeeze()
125
- source_mask = source["attention_mask"].squeeze()
126
- target_ids = target["input_ids"].squeeze().tolist()
127
-
128
- return {
129
- "input_ids": source_ids,
130
- "attention_mask": source_mask,
131
- "labels": target_ids,
132
- }
133
-
134
-
135
- class ScienceQADatasetImg(Dataset):
136
- """
137
- Creating a custom dataset for reading the dataset and
138
- loading it into the dataloader to pass it to the
139
- neural network for finetuning the model
140
-
141
- """
142
-
143
- def __init__(
144
- self, problems, qids, name_maps, tokenizer, source_len, target_len, args, image_features, test_le=None
145
- ):
146
- """
147
- Initializes a Dataset class
148
-
149
- Args:
150
- dataframe (pandas.DataFrame): Input dataframe
151
- tokenizer (transformers.tokenizer): Transformers tokenizer
152
- source_len (int): Max length of source text
153
- target_len (int): Max length of target text
154
- source_text (str): column name of source text
155
- target_text (str): column name of target text
156
- """
157
- self.tokenizer = tokenizer
158
- self.data = {qid : problems[qid] for qid in qids}
159
- self.source_len = source_len
160
- self.summ_len = target_len
161
- self.target_text = []
162
- self.source_text = []
163
- self.image_ids = []
164
- if test_le is not None:
165
- test_le_data =json.load(open(test_le))["preds"]
166
- else:
167
- test_le_data = None
168
- idx = 0
169
- for qid in self.data:
170
- if test_le_data is not None:
171
- curr_le_data = test_le_data[idx]
172
- idx += 1
173
- else:
174
- curr_le_data = None
175
- prompt, target = build_train_pair(problems, qid, args, curr_le_data)
176
- self.target_text.append(target)
177
- self.source_text.append(prompt)
178
- if str(qid) in name_maps:
179
- i_vectors = image_features[int(name_maps[str(qid)])]
180
- self.image_ids.append(i_vectors)
181
- else:
182
- shape = img_shape[args.img_type]
183
- self.image_ids.append(np.zeros(shape))
184
-
185
- def __len__(self):
186
- """returns the length of dataframe"""
187
-
188
- return len(self.target_text)
189
-
190
- def __getitem__(self, index):
191
- """return the input ids, attention masks and target ids"""
192
-
193
- source_text = str(self.source_text[index])
194
- target_text = str(self.target_text[index])
195
- image_ids = self.image_ids[index]
196
-
197
- # cleaning data so as to ensure data is in string type
198
- source_text = " ".join(source_text.split())
199
- target_text = " ".join(target_text.split())
200
-
201
- source = self.tokenizer.batch_encode_plus(
202
- [source_text],
203
- max_length=self.source_len,
204
- pad_to_max_length=True,
205
- truncation=True,
206
- padding="max_length",
207
- return_tensors="pt",
208
- )
209
- target = self.tokenizer.batch_encode_plus(
210
- [target_text],
211
- max_length=self.summ_len,
212
- pad_to_max_length=True,
213
- truncation=True,
214
- padding="max_length",
215
- return_tensors="pt",
216
- )
217
- source_ids = source["input_ids"].squeeze()
218
- source_mask = source["attention_mask"].squeeze()
219
- target_ids = target["input_ids"].squeeze().tolist()
220
-
221
- image_ids = torch.tensor(image_ids).squeeze()
222
-
223
- return {
224
- "input_ids": source_ids,
225
- "attention_mask": source_mask,
226
- "image_ids": image_ids,
227
- "labels": target_ids,
228
- }
mm-cot/utils_evaluate.py DELETED
@@ -1,108 +0,0 @@
1
- '''
2
- Adapted from https://github.com/lupantech/ScienceQA
3
- '''
4
-
5
- import os
6
- import json
7
- import argparse
8
- import warnings
9
- import pandas as pd
10
- from sentence_transformers import SentenceTransformer
11
- from evaluations import caculate_bleu, caculate_rouge, caculate_similariry
12
-
13
- warnings.filterwarnings('ignore')
14
-
15
- def get_acc_with_contion(res_pd, key, values):
16
- if isinstance(values, list):
17
- total_pd = res_pd[res_pd[key].isin(values)]
18
- else:
19
- total_pd = res_pd[res_pd[key] == values]
20
- correct_pd = total_pd[total_pd['true_false'] == True]
21
- acc = "{:.2f}".format(len(correct_pd) / len(total_pd) * 100)
22
- return acc
23
-
24
-
25
- def get_scores(result_data, rationale_data, results_reference, data_file):
26
- # read result file
27
- results = result_data
28
- num = len(results)
29
- assert num == 4241
30
- #print("number of questions:", num)
31
-
32
- # read data file
33
- sqa_data = json.load(open(data_file))
34
-
35
- # construct pandas data
36
- sqa_pd = pd.DataFrame(sqa_data).T
37
- res_pd = sqa_pd[sqa_pd['split'] == 'test'] # test set
38
-
39
- # update data
40
- for index, row in res_pd.iterrows():
41
-
42
- res_pd.loc[index, 'no_context'] = True if (not row['hint'] and not row['image']) else False
43
- res_pd.loc[index, 'has_text'] = True if row['hint'] else False
44
- res_pd.loc[index, 'has_image'] = True if row['image'] else False
45
- res_pd.loc[index, 'has_text_image'] = True if (row['hint'] and row['image']) else False
46
-
47
- label = row['answer']
48
- pred = int(results[index])
49
- res_pd.loc[index, 'pred'] = pred
50
- res_pd.loc[index, 'true_false'] = (label == pred)
51
-
52
- # accuracy scores
53
- acc_average = len(res_pd[res_pd['true_false'] == True]) / num * 100
54
- #assert result_file.split('_')[-1] == "{:.3f}.json".format(acc_average)
55
-
56
-
57
- # rationale quality
58
-
59
- ## BLEU
60
- bleu1 = caculate_bleu(rationale_data, results_reference, gram=1)
61
- bleu4 = caculate_bleu(rationale_data, results_reference, gram=4)
62
-
63
- ## Rouge-L
64
- rouge = caculate_rouge(rationale_data, results_reference)
65
-
66
- ## Similarity
67
- model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').cuda()
68
- similariry = caculate_similariry(rationale_data, results_reference, model)
69
-
70
- scores = {
71
- "answer":{
72
- 'acc_natural':
73
- get_acc_with_contion(res_pd, 'subject', 'natural science'),
74
- 'acc_social':
75
- get_acc_with_contion(res_pd, 'subject', 'social science'),
76
- 'acc_language':
77
- get_acc_with_contion(res_pd, 'subject', 'language science'),
78
- 'acc_has_text':
79
- get_acc_with_contion(res_pd, 'has_text', True),
80
- 'acc_has_image':
81
- get_acc_with_contion(res_pd, 'has_image', True),
82
- 'acc_no_context':
83
- get_acc_with_contion(res_pd, 'no_context', True),
84
- 'acc_grade_1_6':
85
- get_acc_with_contion(res_pd, 'grade', ['grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6']),
86
- 'acc_grade_7_12':
87
- get_acc_with_contion(res_pd, 'grade', ['grade7', 'grade8', 'grade9', 'grade10', 'grade11', 'grade12']),
88
- 'acc_average':
89
- "{:.2f}".format(acc_average),
90
- },
91
- "rationale":{
92
- 'bleu1': bleu1 * 100,
93
- 'bleu4': bleu4 * 100,
94
- 'rouge': rouge * 100,
95
- 'similariry': similariry * 100,
96
- }
97
- }
98
-
99
- return scores
100
-
101
-
102
- def print_scores(scores):
103
- latex_output = ""
104
- for key, score in scores.items():
105
- print(f"{key[4:]}: \t{score}")
106
- latex_output += f"& {score} "
107
- latex_output += "\\\\"
108
- print(latex_output)
mm-cot/utils_prompt.py DELETED
@@ -1,240 +0,0 @@
1
- '''
2
- Adapted from https://github.com/lupantech/ScienceQA
3
- '''
4
-
5
- from dataclasses import dataclass
6
- from typing import List, Optional
7
-
8
- def get_question_text(problem):
9
- question = problem['question']
10
- return question
11
-
12
-
13
- def get_context_text(problem, use_caption):
14
- txt_context = problem['hint']
15
- img_context = problem['caption'] if use_caption else ""
16
- context = " ".join([txt_context, img_context]).strip()
17
- if context == "":
18
- context = "N/A"
19
- return context
20
-
21
-
22
- def get_choice_text(probelm, options):
23
- choices = probelm['choices']
24
- choice_list = []
25
- for i, c in enumerate(choices):
26
- choice_list.append("({}) {}".format(options[i], c))
27
- choice_txt = " ".join(choice_list)
28
- #print(choice_txt)
29
- return choice_txt
30
-
31
- def get_origin_answer(problem, options):
32
- return problem['choices'][problem['answer']]
33
-
34
- def get_answer(problem, options):
35
- return options[problem['answer']]
36
-
37
-
38
- def get_lecture_text(problem):
39
- # \\n: GPT-3 can generate the lecture with more tokens.
40
- lecture = problem['lecture'].replace("\n", "\\n")
41
- return lecture
42
-
43
-
44
- def get_solution_text(problem):
45
- # \\n: GPT-3 can generate the solution with more tokens
46
- solution = problem['solution'].replace("\n", "\\n")
47
- return solution
48
-
49
-
50
- def create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True, WithOutput = False, curr_le_data=None):
51
-
52
- input_format, output_format = format.split("-")
53
-
54
- ## Inputs
55
- if input_format == "CQM":
56
- input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
57
- elif input_format == "QCM":
58
- input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
59
- elif input_format == "QM":
60
- input = f"Question: {question}\nOptions: {choice}\n"
61
- elif input_format == "QC":
62
- input = f"Question: {question}\nContext: {context}\n"
63
- elif input_format == "QCMG":
64
- if curr_le_data is not None:
65
- input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n{curr_le_data}\n"
66
- else:
67
- input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nSolution: {lecture} {solution}\n"
68
- elif input_format == "CQMG":
69
- if curr_le_data is not None:
70
- input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n{curr_le_data}\n"
71
- else:
72
- input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\nSolution: {lecture} {solution}\n"
73
- # upper bound experiment
74
- elif input_format == "QCML":
75
- input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
76
- elif input_format == "QCME":
77
- input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
78
- elif input_format == "QCMLE":
79
- input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
80
-
81
- elif input_format == "QCLM":
82
- input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
83
- elif input_format == "QCEM":
84
- input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
85
- elif input_format == "QCLEM":
86
- input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
87
- elif input_format == "QCMA":
88
- input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nAnswer: The answer is {answer}.\n"
89
- elif input_format == "QCA":
90
- input = f"Question: {question}\nContext: {context}\nAnswer: The answer is {answer}. \nBECAUSE:"
91
-
92
- # Outputs
93
- if test_example:
94
- if output_format == 'A':
95
- output = "Answer:"
96
- elif output_format == 'E':
97
- output = "Solution:"
98
- else:
99
- output = "Solution:"
100
- elif output_format == 'A':
101
- output = f"Answer: The answer is {answer}."
102
-
103
- elif output_format == 'AL':
104
- output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
105
- elif output_format == 'AE':
106
- output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
107
- elif output_format == 'ALE':
108
- output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
109
- elif output_format == 'AEL':
110
- output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
111
-
112
- elif output_format == 'LA':
113
- output = f"Answer: {lecture} The answer is {answer}."
114
- elif output_format == 'EA':
115
- output = f"Answer: {solution} The answer is {answer}."
116
- elif output_format == 'LEA':
117
- output = f"Answer: {lecture} {solution} The answer is {answer}."
118
- elif output_format == 'ELA':
119
- output = f"Answer: {solution} {lecture} The answer is {answer}."
120
-
121
- elif output_format == 'LE':
122
- output = f"Solution: {lecture} {solution}."
123
-
124
- elif output_format == 'E':
125
- output = f"Solution: {solution}"
126
-
127
-
128
- if WithOutput:
129
- if output.endswith("BECAUSE:"):
130
- output = output.replace("BECAUSE:", "").strip()
131
- if output_format == 'E':
132
- text = input + f'Solution:'
133
- elif output_format == 'A':
134
- text = input + f'Answer:'
135
- else:
136
- text = input + f'Solution:'
137
- text = text.replace("  ", " ").strip()
138
- output = output.replace("  ", " ").strip()
139
- return text, output
140
-
141
-
142
- text = input + output
143
- text = text.replace("  ", " ").strip()
144
- if text.endswith("BECAUSE:"):
145
- text = text.replace("BECAUSE:", "").strip()
146
- return text
147
-
148
-
149
- def build_prompt(problems, shot_qids, test_qid, args):
150
-
151
- examples = []
152
-
153
- # n-shot training examples
154
- for qid in shot_qids:
155
- question = get_question_text(problems[qid])
156
- context = get_context_text(problems[qid], args.use_caption)
157
- choice = get_choice_text(problems[qid], args.options)
158
- answer = get_answer(problems[qid], args.options)
159
- lecture = get_lecture_text(problems[qid])
160
- solution = get_solution_text(problems[qid])
161
-
162
- train_example = create_one_example(args.prompt_format,
163
- question,
164
- context,
165
- choice,
166
- answer,
167
- lecture,
168
- solution,
169
- test_example=False)
170
- examples.append(train_example)
171
-
172
- # test example
173
- question = get_question_text(problems[test_qid])
174
- context = get_context_text(problems[test_qid], args.use_caption)
175
- choice = get_choice_text(problems[test_qid], args.options)
176
- answer = get_answer(problems[test_qid], args.options)
177
- lecture = get_lecture_text(problems[test_qid])
178
- solution = get_solution_text(problems[test_qid])
179
-
180
- test_example = create_one_example(args.prompt_format,
181
- question,
182
- context,
183
- choice,
184
- answer,
185
- lecture,
186
- solution,
187
- test_example=True)
188
- examples.append(test_example)
189
-
190
- # create the prompt input
191
- prompt_input = '\n\n'.join(examples)
192
-
193
- return prompt_input
194
-
195
- def build_train_pair(problems, test_qid, args, curr_le_data=None):
196
-
197
- examples = []
198
-
199
- # test example
200
- question = get_question_text(problems[test_qid])
201
- context = get_context_text(problems[test_qid], args.use_caption)
202
- choice = get_choice_text(problems[test_qid], args.options)
203
-
204
- lecture = get_lecture_text(problems[test_qid])
205
- solution = get_solution_text(problems[test_qid])
206
-
207
- # answer_text = get_origin_answer(problems[test_qid], args.options)
208
- answer_option = get_answer(problems[test_qid], args.options)
209
- answer = "(" + answer_option + ")"
210
-
211
- test_example, target = create_one_example(args.prompt_format,
212
- question,
213
- context,
214
- choice,
215
- answer,
216
- lecture,
217
- solution,
218
- test_example=False,WithOutput = True, curr_le_data=curr_le_data)
219
- examples.append(test_example)
220
-
221
- target = target.replace("Answer:", "").strip()
222
- # create the prompt input
223
- prompt_input = '\n\n'.join(examples)
224
-
225
- return prompt_input, target
226
-
227
- @dataclass(frozen=True)
228
- class InputFeatures:
229
- """
230
- A single set of features of data.
231
- Property names are the same names as the corresponding inputs to a model.
232
- """
233
-
234
- input_ids: List[List[int]]
235
- attention_mask: Optional[List[List[int]]]
236
- token_type_ids: Optional[List[List[int]]]
237
- le_input_ids: List[List[int]]
238
- le_attention_mask: Optional[List[List[int]]]
239
- le_token_type_ids: Optional[List[List[int]]]
240
- label: Optional[int]
mm-cot/vision_features/mm-cot.png DELETED
Binary file (893 kB)